BatchLSEstimator.java

/* Copyright 2002-2016 CS Systèmes d'Information
 * Licensed to CS Systèmes d'Information (CS) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * CS licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.orekit.estimation.leastsquares;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.linear.RealVector;
import org.hipparchus.optim.ConvergenceChecker;
import org.hipparchus.optim.nonlinear.vector.leastsquares.EvaluationRmsChecker;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer.Optimum;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
import org.hipparchus.optim.nonlinear.vector.leastsquares.ParameterValidator;
import org.hipparchus.util.Incrementor;
import org.orekit.errors.OrekitException;
import org.orekit.errors.OrekitExceptionWrapper;
import org.orekit.estimation.measurements.EstimatedMeasurement;
import org.orekit.estimation.measurements.EstimationsProvider;
import org.orekit.estimation.measurements.ObservedMeasurement;
import org.orekit.orbits.Orbit;
import org.orekit.propagation.conversion.NumericalPropagatorBuilder;
import org.orekit.propagation.numerical.NumericalPropagator;
import org.orekit.time.ChronologicalComparator;
import org.orekit.utils.ParameterDriver;
import org.orekit.utils.ParameterDriversList;
import org.orekit.utils.ParameterDriversList.DelegatingDriver;


/** Least squares estimator for orbit determination.
 * @author Luc Maisonobe
 * @since 8.0
 */
public class BatchLSEstimator {

    /** Builder for propagator. */
    private final NumericalPropagatorBuilder propagatorBuilder;

    /** Measurements. */
    private final List<ObservedMeasurement<?>> measurements;

    /** Solver for least squares problem. */
    private final LeastSquaresOptimizer optimizer;

    /** Convergence threshold on normalized parameters. */
    private double parametersConvergenceThreshold;

    /** Builder for the least squares problem. */
    private final LeastSquaresBuilder lsBuilder;

    /** Observer for iterations. */
    private BatchLSObserver observer;

    /** Last estimations. */
    private Map<ObservedMeasurement<?>, EstimatedMeasurement<?>> estimations;

    /** Last orbit. */
    private Orbit orbit;

    /** Optimum found. */
    private Optimum optimum;

    /** Counter for the evaluations. */
    private Incrementor evaluationsCounter;

    /** Counter for the iterations. */
    private Incrementor iterationsCounter;

    /** Build an estimator from a propagator builder and a least squares solver.
     * <p>
     * The new estimator holds no measurements, has no observer and its
     * parameters convergence threshold is left undefined ({@code NaN}).
     * </p>
     * @param propagatorBuilder builder to use for propagation
     * @param optimizer solver for least squares problem
     * @exception OrekitException if some propagator parameter cannot be retrieved
     */
    public BatchLSEstimator(final NumericalPropagatorBuilder propagatorBuilder,
                            final LeastSquaresOptimizer optimizer)
        throws OrekitException {

        // collaborators supplied by the caller
        this.propagatorBuilder = propagatorBuilder;
        this.optimizer         = optimizer;

        // internal state, filled in later by the setters and by estimate()
        this.measurements                   = new ArrayList<ObservedMeasurement<?>>();
        this.lsBuilder                      = new LeastSquaresBuilder();
        this.parametersConvergenceThreshold = Double.NaN;
        this.observer                       = null;
        this.estimations                    = null;

        // value and Jacobian are produced together by a single model call,
        // hence lazy evaluation would bring nothing
        this.lsBuilder.lazyEvaluation(false);

        // weights are handled on our side because they evolve during the
        // iterations (identified outliers get a 0 weight), so the least
        // squares problem itself must not be given any weight
        this.lsBuilder.weight(null);

    }

    /** Register the observer that follows the estimation progress.
     * <p>
     * The observer is called back after each evaluation of the least
     * squares problem; passing {@code null} disables notifications.
     * </p>
     * @param observer observer to be notified, or {@code null} for none
     */
    public void setObserver(final BatchLSObserver observer) {
        this.observer = observer;
    }

    /** Append a measurement to the set used for the estimation.
     * @param measurement measurement to add
     * @exception OrekitException declared as part of the public signature;
     * the current implementation only appends the measurement to the
     * internal list and does not throw it
     */
    public void addMeasurement(final ObservedMeasurement<?> measurement)
        throws OrekitException {
        this.measurements.add(measurement);
    }

    /** Set the maximum number of iterations.
     * <p>
     * Iterations are counted at the top level of the
     * {@link LeastSquaresOptimizer least squares optimizer}. The limit is
     * forwarded to the builder of the least squares problem.
     * </p>
     * @param maxIterations maximum number of iterations
     * @see #setMaxEvaluations(int)
     * @see #getIterationsCount()
     */
    public void setMaxIterations(final int maxIterations) {
        this.lsBuilder.maxIterations(maxIterations);
    }

    /** Set the maximum number of model evaluations.
     * <p>
     * One evaluation corresponds to one orbit propagation together with
     * the estimation of the measurements, for a given set of estimated
     * parameters.
     * </p>
     * <p>
     * With the {@link org.hipparchus.optim.nonlinear.vector.leastsquares.GaussNewtonOptimizer
     * Gauss-Newton optimizer} each iteration performs exactly one evaluation,
     * so both limits may share the same value. With the {@link
     * org.hipparchus.optim.nonlinear.vector.leastsquares.LevenbergMarquardtOptimizer
     * Levenberg-Marquardt optimizer}, some iterations (typically the first
     * couple of them) may need several evaluations, so the maximum number of
     * evaluations may be set higher than the maximum number of iterations.
     * </p>
     * @param maxEvaluations maximum number of model evaluations
     * @see #setMaxIterations(int)
     * @see #getEvaluationsCount()
     */
    public void setMaxEvaluations(final int maxEvaluations) {
        this.lsBuilder.maxEvaluations(maxEvaluations);
    }

    /** Get the orbital parameters supported by this estimator.
     * <p>
     * When all parameters are requested, the list held by the propagator
     * builder is returned directly. When only estimated parameters are
     * requested, a new list is built from the raw drivers of every
     * selected delegating driver.
     * </p>
     * @param estimatedOnly if true, only estimated parameters are returned
     * @return orbital parameters supported by this estimator
     * @exception OrekitException if different parameters have the same name
     */
    public ParameterDriversList getOrbitalParametersDrivers(final boolean estimatedOnly)
        throws OrekitException {

        final ParameterDriversList allOrbital = propagatorBuilder.getOrbitalParametersDrivers();
        if (!estimatedOnly) {
            // no filtering requested
            return allOrbital;
        }

        // keep only the drivers flagged as selected for estimation
        final ParameterDriversList selected = new ParameterDriversList();
        for (final DelegatingDriver candidate : allOrbital.getDrivers()) {
            if (!candidate.isSelected()) {
                continue;
            }
            for (final ParameterDriver raw : candidate.getRawDrivers()) {
                selected.add(raw);
            }
        }
        return selected;

    }

    /** Get the propagator parameters supported by this estimator.
     * <p>
     * When all parameters are requested, the list held by the propagator
     * builder is returned directly. When only estimated parameters are
     * requested, a new list is built from the raw drivers of every
     * selected delegating driver.
     * </p>
     * @param estimatedOnly if true, only estimated parameters are returned
     * @return propagator parameters supported by this estimator
     * @exception OrekitException if different parameters have the same name
     */
    public ParameterDriversList getPropagatorParametersDrivers(final boolean estimatedOnly)
        throws OrekitException {

        final ParameterDriversList allPropagation = propagatorBuilder.getPropagationParametersDrivers();
        if (!estimatedOnly) {
            // no filtering requested
            return allPropagation;
        }

        // keep only the drivers flagged as selected for estimation
        final ParameterDriversList selected = new ParameterDriversList();
        for (final DelegatingDriver candidate : allPropagation.getDrivers()) {
            if (!candidate.isSelected()) {
                continue;
            }
            for (final ParameterDriver raw : candidate.getRawDrivers()) {
                selected.add(raw);
            }
        }
        return selected;

    }

    /** Get the measurements parameters supported by this estimator (including measurements and modifiers).
     * <p>
     * The drivers of all measurements added so far are gathered in a new
     * list, which is sorted before being returned.
     * </p>
     * @param estimatedOnly if true, only estimated parameters are returned
     * @return measurements parameters supported by this estimator
     * @exception OrekitException if different parameters have the same name
     */
    public ParameterDriversList getMeasurementsParametersDrivers(final boolean estimatedOnly)
        throws OrekitException {

        final ParameterDriversList collected = new ParameterDriversList();

        for (final ObservedMeasurement<?> measurement : measurements) {
            for (final ParameterDriver candidate : measurement.getParametersDrivers()) {
                if (estimatedOnly && !candidate.isSelected()) {
                    // caller wants estimated parameters only, and this one is fixed
                    continue;
                }
                collected.add(candidate);
            }
        }

        collected.sort();
        return collected;

    }

    /**
     * Set convergence threshold.
     * <p>
     * Convergence of the estimation is checked on the
     * {@link ParameterDriver#getNormalizedValue() normalized values} of the
     * estimated parameters. It is considered reached when, for every
     * parameter, the change in normalized value between the previous and
     * the current evaluation does not exceed the threshold. A single value
     * is used for all parameters, as normalized values are dimensionless.
     * </p>
     * <p>
     * Normalized values are computed as {@code (current - reference)/scale},
     * so convergence is reached when the following condition holds for
     * all estimated parameters:
     * {@code |current[i] - previous[i]| <= threshold * scale[i]}
     * </p>
     * <p>
     * The threshold can therefore be seen as a multiplication factor applied
     * to the scale of each parameter. As scales are often small (typically
     * about 1 m for orbital positions for example), the threshold should not
     * be too small. A value of 10⁻³ is often quite accurate.
     * </p>
     * <p>
     * Until this method is called the threshold is {@code NaN}, and the
     * convergence test set up by {@link #estimate()} compares against it,
     * so that test cannot succeed.
     * </p>
     *
     * @param parametersConvergenceThreshold convergence threshold on
     * normalized parameters (dimensionless, related to parameters scales)
     * @see EvaluationRmsChecker
     */
    public void setParametersConvergenceThreshold(final double parametersConvergenceThreshold) {
        this.parametersConvergenceThreshold = parametersConvergenceThreshold;
    }

    /** Estimate the orbital, propagation and measurements parameters.
     * <p>
     * Before calling this method, the initial guess of every parameter must have
     * been set by retrieving the drivers with {@link #getOrbitalParametersDrivers(boolean)},
     * {@link #getPropagatorParametersDrivers(boolean)} and
     * {@link #getMeasurementsParametersDrivers(boolean)}, and then
     * {@link ParameterDriver#setValue(double) setting the values} of the parameters.
     * </p>
     * <p>
     * Once this method has returned, the estimated parameters are available from the
     * same three methods, by {@link ParameterDriver#getValue() getting the values}
     * of the parameters.
     * </p>
     * <p>
     * As a convenience, the method also returns a fully configured and ready to use
     * propagator set up with all the estimated values.
     * </p>
     * @return propagator configured with estimated orbit as initial state, and all
     * propagator estimated parameters also set
     * @exception OrekitException if there is a conflict in parameters names
     * or if orbit cannot be determined
     */
    public NumericalPropagator estimate() throws OrekitException {

        // the three families of estimated parameters
        final ParameterDriversList orbitalDrivers      = getOrbitalParametersDrivers(true);
        final ParameterDriversList propagatorDrivers   = getPropagatorParametersDrivers(true);
        final ParameterDriversList measurementsDrivers = getMeasurementsParametersDrivers(true);

        // start point: normalized values laid out as orbital, then propagator,
        // then measurements parameters
        final int nbEstimated = orbitalDrivers.getNbParams() +
                                propagatorDrivers.getNbParams() +
                                measurementsDrivers.getNbParams();
        final double[] initialGuess = new double[nbEstimated];
        int index = 0;
        for (final ParameterDriversList family : Arrays.asList(orbitalDrivers,
                                                               propagatorDrivers,
                                                               measurementsDrivers)) {
            for (final ParameterDriver driver : family.getDrivers()) {
                initialGuess[index] = driver.getNormalizedValue();
                ++index;
            }
        }
        lsBuilder.start(initialGuess);

        // target: one zero entry per component of each enabled measurement,
        // as weighted residuals are computed on our side
        int observationsDimension = 0;
        for (final ObservedMeasurement<?> measurement : measurements) {
            observationsDimension += measurement.isEnabled() ? measurement.getDimension() : 0;
        }
        lsBuilder.target(new double[observationsDimension]);

        // model, reporting back the last orbit and estimations it computed
        final Model model = new Model(propagatorBuilder, measurements, measurementsDrivers,
                                      new ModelObserver() {
                                          /** {@inheritDoc} */
                                          @Override
                                          public void modelCalled(final Orbit newOrbit,
                                                                  final Map<ObservedMeasurement<?>, EstimatedMeasurement<?>> newEstimations) {
                                              BatchLSEstimator.this.orbit       = newOrbit;
                                              BatchLSEstimator.this.estimations = newEstimations;
                                          }
                                      });
        lsBuilder.model(model);

        // validator letting the parameters drivers clip the values
        lsBuilder.parameterValidator(new Validator(orbitalDrivers,
                                                   propagatorDrivers,
                                                   measurementsDrivers));

        // convergence is reached when no normalized parameter moved by more than the threshold
        lsBuilder.checker(new ConvergenceChecker<LeastSquaresProblem.Evaluation>() {
            /** {@inheritDoc} */
            @Override
            public boolean converged(final int iteration,
                                     final LeastSquaresProblem.Evaluation previous,
                                     final LeastSquaresProblem.Evaluation current) {
                return current.getPoint().getLInfDistance(previous.getPoint()) <= parametersConvergenceThreshold;
            }
        });

        // wrap the built problem so counters and observer are tapped
        final LeastSquaresProblem tapped = new TappedLSProblem(lsBuilder.build(),
                                                               model,
                                                               orbitalDrivers,
                                                               propagatorDrivers,
                                                               measurementsDrivers);

        try {

            // run the solver and remember its result
            optimum = optimizer.optimize(tapped);

            // hand back a propagator configured with the estimated point
            return model.createPropagator(optimum.getPoint());

        } catch (MathRuntimeException mrte) {
            throw new OrekitException(mrte);
        } catch (OrekitExceptionWrapper oew) {
            // unwrap the checked exception smuggled through the optimizer
            throw oew.getException();
        }

    }

    /** Get the last estimations performed.
     * <p>
     * The estimations are recorded each time the model is called during
     * {@link #estimate()}. The returned map is unmodifiable. If no
     * estimation has been recorded yet, an empty map is returned.
     * </p>
     * @return last estimations performed (never {@code null}, possibly empty)
     */
    public Map<ObservedMeasurement<?>, EstimatedMeasurement<?>> getLastEstimations() {
        if (estimations == null) {
            // nothing recorded yet: wrapping null would throw a NullPointerException
            return Collections.<ObservedMeasurement<?>, EstimatedMeasurement<?>>emptyMap();
        }
        return Collections.unmodifiableMap(estimations);
    }

    /** Get the optimum found.
     * @return optimum stored by the last successful call to {@link #estimate()},
     * or {@code null} if no optimum has been stored yet
     */
    public Optimum getOptimum() {
        return this.optimum;
    }

    /** Get the number of iterations used for last estimation.
     * <p>
     * The underlying counter is only known once the optimizer has requested
     * it from the least squares problem during {@link #estimate()}. Before
     * that, 0 is returned.
     * </p>
     * @return number of iterations used for last estimation, or 0 if the
     * iterations counter is not available yet
     * @see #setMaxIterations(int)
     */
    public int getIterationsCount() {
        if (iterationsCounter == null) {
            // counter not tapped yet, typically because estimate() was never called
            return 0;
        }
        return iterationsCounter.getCount();
    }

    /** Get the number of evaluations used for last estimation.
     * <p>
     * The underlying counter is only known once the optimizer has requested
     * it from the least squares problem during {@link #estimate()}. Before
     * that, 0 is returned.
     * </p>
     * @return number of evaluations used for last estimation, or 0 if the
     * evaluations counter is not available yet
     * @see #setMaxEvaluations(int)
     */
    public int getEvaluationsCount() {
        if (evaluationsCounter == null) {
            // counter not tapped yet, typically because estimate() was never called
            return 0;
        }
        return evaluationsCounter.getCount();
    }

    /** Wrapper used to tap the various counters.
     * <p>
     * All calls are forwarded to an underlying problem. On the way, the
     * evaluations and iterations counters are recorded in the enclosing
     * estimator and handed to the model, and the estimator observer (if any)
     * is notified after each evaluation.
     * </p>
     */
    private class TappedLSProblem implements LeastSquaresProblem {

        /** Underlying problem, to which every call is forwarded. */
        private final LeastSquaresProblem delegate;

        /** Multivariate function model. */
        private final Model function;

        /** Estimated orbital parameters. */
        private final ParameterDriversList orbitalDrivers;

        /** Estimated propagator parameters. */
        private final ParameterDriversList propagatorDrivers;

        /** Estimated measurements parameters. */
        private final ParameterDriversList measurementsDrivers;

        /** Simple constructor.
         * @param problem underlying problem
         * @param model multivariate function model
         * @param estimatedOrbitalParameters estimated orbital parameters
         * @param estimatedPropagatorParameters estimated propagator parameters
         * @param estimatedMeasurementsParameters estimated measurements parameters
         */
        TappedLSProblem(final LeastSquaresProblem problem,
                        final Model model,
                        final ParameterDriversList estimatedOrbitalParameters,
                        final ParameterDriversList estimatedPropagatorParameters,
                        final ParameterDriversList estimatedMeasurementsParameters) {
            delegate            = problem;
            function            = model;
            orbitalDrivers      = estimatedOrbitalParameters;
            propagatorDrivers   = estimatedPropagatorParameters;
            measurementsDrivers = estimatedMeasurementsParameters;
        }

        /** {@inheritDoc} */
        @Override
        public Incrementor getEvaluationCounter() {
            // intercept the evaluations counter, sharing it with estimator and model
            final Incrementor counter = delegate.getEvaluationCounter();
            BatchLSEstimator.this.evaluationsCounter = counter;
            function.setEvaluationsCounter(counter);
            return counter;
        }

        /** {@inheritDoc} */
        @Override
        public Incrementor getIterationCounter() {
            // intercept the iterations counter, sharing it with estimator and model
            final Incrementor counter = delegate.getIterationCounter();
            BatchLSEstimator.this.iterationsCounter = counter;
            function.setIterationsCounter(counter);
            return counter;
        }

        /** {@inheritDoc} */
        @Override
        public ConvergenceChecker<Evaluation> getConvergenceChecker() {
            return delegate.getConvergenceChecker();
        }

        /** {@inheritDoc} */
        @Override
        public RealVector getStart() {
            return delegate.getStart();
        }

        /** {@inheritDoc} */
        @Override
        public int getObservationSize() {
            return delegate.getObservationSize();
        }

        /** {@inheritDoc} */
        @Override
        public int getParameterSize() {
            return delegate.getParameterSize();
        }

        /** {@inheritDoc} */
        @Override
        public Evaluation evaluate(final RealVector point) {

            // let the underlying problem do the real work
            final Evaluation result = delegate.evaluate(point);

            if (observer == null) {
                // nobody to notify
                return result;
            }

            try {
                observer.evaluationPerformed(iterationsCounter.getCount(),
                                             evaluationsCounter.getCount(),
                                             orbit,
                                             orbitalDrivers,
                                             propagatorDrivers,
                                             measurementsDrivers,
                                             new Provider(),
                                             result);
            } catch (OrekitException oe) {
                // the interface does not allow checked exceptions, wrap it
                throw new OrekitExceptionWrapper(oe);
            }

            return result;

        }

    }

    /** Provider for evaluations.
     * <p>
     * Gives indexed access, in chronological order, to the estimated
     * measurements last recorded by the enclosing estimator.
     * </p>
     */
    private class Provider implements EstimationsProvider {

        /** Chronologically sorted estimations, built on first access. */
        private EstimatedMeasurement<?>[] sortedEstimations;

        /** {@inheritDoc} */
        @Override
        public int getNumber() {
            return estimations.size();
        }

        /** {@inheritDoc} */
        @Override
        public EstimatedMeasurement<?> getEstimatedMeasurement(final int index)
            throws OrekitException {

            // reject indices outside of the available estimations
            final int available = estimations.size();
            if (index < 0 || index >= available) {
                throw new OrekitException(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE,
                                          index, 0, available);
            }

            if (sortedEstimations == null) {
                // first access: copy the estimated values, then order them by date
                sortedEstimations = estimations.values().toArray(new EstimatedMeasurement<?>[available]);
                Arrays.sort(sortedEstimations, new ChronologicalComparator());
            }

            return sortedEstimations[index];

        }

    }

    /** Validator for estimated parameters.
     * <p>
     * Each entry of the parameters vector is pushed into its driver as a
     * normalized value and read back, so any adjustment made by the driver
     * is reflected in the vector. Entries are laid out as orbital, then
     * propagator, then measurements parameters.
     * </p>
     */
    private static class Validator implements ParameterValidator {

        /** Estimated orbital parameters. */
        private final ParameterDriversList estimatedOrbitalParameters;

        /** Estimated propagator parameters. */
        private final ParameterDriversList estimatedPropagatorParameters;

        /** Estimated measurements parameters. */
        private final ParameterDriversList estimatedMeasurementsParameters;

        /** Simple constructor.
         * @param estimatedOrbitalParameters estimated orbital parameters
         * @param estimatedPropagatorParameters estimated propagator parameters
         * @param estimatedMeasurementsParameters estimated measurements parameters
         */
        Validator(final ParameterDriversList estimatedOrbitalParameters,
                  final ParameterDriversList estimatedPropagatorParameters,
                  final ParameterDriversList estimatedMeasurementsParameters) {
            this.estimatedOrbitalParameters      = estimatedOrbitalParameters;
            this.estimatedPropagatorParameters   = estimatedPropagatorParameters;
            this.estimatedMeasurementsParameters = estimatedMeasurementsParameters;
        }

        /** {@inheritDoc} */
        @Override
        public RealVector validate(final RealVector params)
            throws OrekitExceptionWrapper {
            try {
                // each family continues where the previous one stopped
                int next = clip(estimatedOrbitalParameters, params, 0);
                next = clip(estimatedPropagatorParameters, params, next);
                clip(estimatedMeasurementsParameters, params, next);
                return params;
            } catch (OrekitException oe) {
                throw new OrekitExceptionWrapper(oe);
            }
        }

        /** Round-trip a slice of the parameters vector through its drivers.
         * @param drivers drivers associated with consecutive entries of the vector
         * @param params parameters vector, updated in place
         * @param first index of the entry associated with the first driver
         * @return index of the entry following the last one processed
         * @exception OrekitException if a driver rejects its new value
         */
        private static int clip(final ParameterDriversList drivers,
                                final RealVector params, final int first)
            throws OrekitException {
            int index = first;
            for (final ParameterDriver driver : drivers.getDrivers()) {
                // let the parameter handle min/max clipping
                driver.setNormalizedValue(params.getEntry(index));
                params.setEntry(index, driver.getNormalizedValue());
                ++index;
            }
            return index;
        }

    }

}