1   /* Copyright 2002-2025 CS GROUP
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.propagation.numerical;
18  
19  import org.hipparchus.analysis.differentiation.Gradient;
20  import org.hipparchus.exception.LocalizedCoreFormats;
21  import org.hipparchus.linear.DecompositionSolver;
22  import org.hipparchus.linear.MatrixUtils;
23  import org.hipparchus.linear.QRDecomposition;
24  import org.hipparchus.linear.RealMatrix;
25  import org.hipparchus.util.Precision;
26  import org.orekit.attitudes.AttitudeProvider;
27  import org.orekit.attitudes.AttitudeProviderModifier;
28  import org.orekit.errors.OrekitException;
29  import org.orekit.forces.ForceModel;
30  import org.orekit.orbits.Orbit;
31  import org.orekit.orbits.OrbitType;
32  import org.orekit.orbits.PositionAngleType;
33  import org.orekit.propagation.FieldSpacecraftState;
34  import org.orekit.propagation.SpacecraftState;
35  import org.orekit.propagation.integration.AdditionalDerivativesProvider;
36  import org.orekit.propagation.integration.CombinedDerivatives;
37  import org.orekit.utils.DataDictionary;
38  import org.orekit.utils.ParameterDriver;
39  import org.orekit.utils.TimeSpanMap;
40  
41  import java.util.HashMap;
42  import java.util.List;
43  import java.util.Map;
44  
45  /** Abstract generator for numerical State Transition Matrix.
46   * @author Luc Maisonobe
47   * @author Melina Vanel
48   * @author Romain Serra
49   * @since 13.1
50   */
51  abstract class AbstractStateTransitionMatrixGenerator implements AdditionalDerivativesProvider {
52  
53      /** Space dimension. */
54      protected static final int SPACE_DIMENSION = 3;
55  
56      /** Threshold for matrix solving. */
57      private static final double THRESHOLD = Precision.SAFE_MIN;
58  
59      /** Name of the Cartesian STM additional state. */
60      private final String stmName;
61  
62      /** Force models used in propagation. */
63      private final List<ForceModel> forceModels;
64  
65      /** Attitude provider used in propagation. */
66      private final AttitudeProvider attitudeProvider;
67  
68      /** Observers for partial derivatives. */
69      private final Map<String, PartialsObserver> partialsObservers;
70  
71      /** Number of state variables. */
72      private final int stateDimension;
73  
74      /** Dimension of flatten STM. */
75      private final int dimension;
76  
77      /** Simple constructor.
78       * @param stmName name of the Cartesian STM additional state
79       * @param forceModels force models used in propagation
80       * @param attitudeProvider attitude provider used in propagation
81       * @param stateDimension dimension of state vector
82       */
83      AbstractStateTransitionMatrixGenerator(final String stmName, final List<ForceModel> forceModels,
84                                             final AttitudeProvider attitudeProvider, final int stateDimension) {
85          this.stmName           = stmName;
86          this.forceModels       = forceModels;
87          this.attitudeProvider  = attitudeProvider;
88          this.stateDimension    = stateDimension;
89          this.dimension         = stateDimension * stateDimension;
90          this.partialsObservers = new HashMap<>();
91      }
92  
93      /** Register an observer for partial derivatives.
94       * <p>
95       * The observer {@link PartialsObserver#partialsComputed(SpacecraftState, double[], double[])} partialsComputed}
96       * method will be called when partial derivatives are computed, as a side effect of
97       * calling {@link #computePartials(SpacecraftState)} (SpacecraftState)}
98       * </p>
99       * @param name name of the parameter driver this observer is interested in (may be null)
100      * @param observer observer to register
101      */
102     void addObserver(final String name, final PartialsObserver observer) {
103         partialsObservers.put(name, observer);
104     }
105 
106     /** {@inheritDoc} */
107     @Override
108     public String getName() {
109         return stmName;
110     }
111 
112     /** {@inheritDoc} */
113     @Override
114     public int getDimension() {
115         return dimension;
116     }
117 
118     /**
119      * Getter for the number of state variables.
120      * @return state vector dimension
121      */
122     public int getStateDimension() {
123         return stateDimension;
124     }
125 
126     /**
127      * Protected getter for the force models.
128      * @return forces
129      */
130     protected List<ForceModel> getForceModels() {
131         return forceModels;
132     }
133 
134     /**
135      * Protected getter for the partials observers map.
136      * @return map
137      */
138     protected Map<String, PartialsObserver> getPartialsObservers() {
139         return partialsObservers;
140     }
141 
142     /**
143      * Method to build a linear system solver.
144      * @param matrix equations matrix
145      * @return solver
146      */
147     private DecompositionSolver getDecompositionSolver(final RealMatrix matrix) {
148         return new QRDecomposition(matrix, THRESHOLD).getSolver();
149     }
150 
151     /** Set the initial value of the State Transition Matrix.
152      * <p>
153      * The returned state must be added to the propagator.
154      * </p>
155      * @param state initial state
156      * @param dYdY0 initial State Transition Matrix ∂Y/∂Y₀,
157      * if null (which is the most frequent case), assumed to be 6x6 identity
158      * @param orbitType orbit type used for states Y and Y₀ in {@code dYdY0}
159      * @param positionAngleType position angle used states Y and Y₀ in {@code dYdY0}
160      * @return state with initial STM (converted to Cartesian ∂C/∂Y₀) added
161      */
162     SpacecraftState setInitialStateTransitionMatrix(final SpacecraftState state, final RealMatrix dYdY0,
163                                                     final OrbitType orbitType,
164                                                     final PositionAngleType positionAngleType) {
165 
166         final RealMatrix nonNullDYdY0;
167         if (dYdY0 == null) {
168             nonNullDYdY0 = MatrixUtils.createRealIdentityMatrix(getStateDimension());
169         } else {
170             if (dYdY0.getRowDimension() != getStateDimension() ||
171                     dYdY0.getColumnDimension() != getStateDimension()) {
172                 throw new OrekitException(LocalizedCoreFormats.DIMENSIONS_MISMATCH_2x2,
173                         dYdY0.getRowDimension(), dYdY0.getColumnDimension(),
174                         getStateDimension(), getStateDimension());
175             }
176             nonNullDYdY0 = dYdY0;
177         }
178 
179         // convert to Cartesian STM
180         final RealMatrix dCdY0;
181         if (state.isOrbitDefined()) {
182             final RealMatrix dYdC = MatrixUtils.createRealIdentityMatrix(getStateDimension());
183             final Orbit orbit = orbitType.convertType(state.getOrbit());
184             final double[][] jacobian = new double[6][6];
185             orbit.getJacobianWrtCartesian(positionAngleType, jacobian);
186             dYdC.setSubMatrix(jacobian, 0, 0);
187             final DecompositionSolver decomposition = getDecompositionSolver(dYdC);
188             dCdY0 = decomposition.solve(nonNullDYdY0);
189         } else {
190             dCdY0 = nonNullDYdY0;
191         }
192 
193         // set additional state
194         return state.addAdditionalData(getName(), flatten(dCdY0));
195 
196     }
197 
198     /**
199      * Flattens a matrix into an 1-D array.
200      * @param matrix matrix to be flatten
201      * @return array
202      */
203     double[] flatten(final RealMatrix matrix) {
204         final double[] flat = new double[getDimension()];
205         int k = 0;
206         for (int i = 0; i < getStateDimension(); ++i) {
207             for (int j = 0; j < getStateDimension(); ++j) {
208                 flat[k++] = matrix.getEntry(i, j);
209             }
210         }
211         return flat;
212     }
213 
214     /** {@inheritDoc} */
215     @Override
216     public boolean yields(final SpacecraftState state) {
217         return !state.hasAdditionalData(getName());
218     }
219 
220     /** {@inheritDoc} */
221     public CombinedDerivatives combinedDerivatives(final SpacecraftState state) {
222         final double[] factor = computePartials(state);
223 
224         // retrieve current State Transition Matrix
225         final double[] p    = state.getAdditionalState(getName());
226         final double[] pDot = new double[p.length];
227 
228         // perform multiplication
229         multiplyMatrix(factor, p, pDot, getStateDimension());
230 
231         return new CombinedDerivatives(pDot, null);
232 
233     }
234 
235     /** Compute evolution matrix product.
236      * @param factor factor matrix
237      * @param x right factor of the multiplication, as a flatten array in row major order
238      * @param y placeholder where to put the result, as a flatten array in row major order
239      * @param columns number of columns of both x and y (so their dimensions are the state one times the columns)
240      */
241     abstract void multiplyMatrix(double[] factor, double[] x, double[] y, int columns);
242 
243     /** Compute the various partial derivatives.
244      * @param state current spacecraft state
245      * @return factor matrix
246      */
247     double[] computePartials(final SpacecraftState state) {
248 
249         // set up containers for partial derivatives
250         final double[]              factor               = new double[(stateDimension - SPACE_DIMENSION) * stateDimension];
251         final Map<String, double[]> partialsDictionary = new HashMap<>();
252 
253         // evaluate contribution of all force models
254         final AttitudeProvider equivalentAttitudeProvider = wrapAttitudeProviderIfPossible();
255         final boolean isThereAnyForceNotDependingOnlyOnPosition = getForceModels().stream().anyMatch(force -> !force.dependsOnPositionOnly());
256         final NumericalGradientConverter posOnlyConverter = new NumericalGradientConverter(state, SPACE_DIMENSION, equivalentAttitudeProvider);
257         final NumericalGradientConverter fullConverter = isThereAnyForceNotDependingOnlyOnPosition ?
258                 new NumericalGradientConverter(state, getStateDimension(), equivalentAttitudeProvider) : posOnlyConverter;
259         final SpacecraftState stateForParameters = state.withAdditionalData(new LocalDoubleArrayDictionary(state.getAdditionalDataValues()));
260 
261         for (final ForceModel forceModel : getForceModels()) {
262 
263             final NumericalGradientConverter     converter    = forceModel.dependsOnPositionOnly() ? posOnlyConverter : fullConverter;
264             final FieldSpacecraftState<Gradient> dsState      = converter.getState(forceModel);
265             final Gradient[]                     parameters   = converter.getParametersAtStateDate(dsState, forceModel);
266 
267             // update partial derivatives w.r.t. state variables
268             final Gradient[] ratesPartials = computeRatesPartialsAndUpdateFactor(forceModel, dsState, parameters, factor);
269 
270             // partials derivatives with respect to parameters
271             updateFactorForParameters(forceModel, converter, ratesPartials, partialsDictionary, stateForParameters, factor);
272 
273         }
274 
275         return factor;
276 
277     }
278 
279     /**
280      * Compute with automatic differentiation the partial derivatives of state variables' rate
281      * that are not part of the position vector.
282      * @param forceModel force model
283      * @param fieldState state in Taylor differential algebra
284      * @param parameters force parameters in Taylor differential algebra
285      * @param factor factor matrix to update
286      * @return array of rates in Taylor differential algebra
287      */
288     abstract Gradient[] computeRatesPartialsAndUpdateFactor(ForceModel forceModel,
289                                                             FieldSpacecraftState<Gradient> fieldState,
290                                                             Gradient[] parameters, double[] factor);
291 
292     /**
293      * Update factor regarding partials of force model parameters.
294      * @param forceModel force
295      * @param converter gradient converter
296      * @param ratesPartials state variables' rates evaluated in the Taylor differential algebra
297      * @param partialsDictionary dictionary storing the partials
298      * @param state spacecraft state
299      * @param factor factor matrix (flattened)
300      */
301     private void updateFactorForParameters(final ForceModel forceModel, final NumericalGradientConverter converter,
302                                            final Gradient[] ratesPartials, final Map<String, double[]> partialsDictionary,
303                                            final SpacecraftState state, final double[] factor) {
304         int paramsIndex = converter.getFreeStateParameters();
305         for (ParameterDriver driver : forceModel.getParametersDrivers()) {
306             if (driver.isSelected()) {
307 
308                 // for each span (for each estimated value) corresponding name is added
309                 for (TimeSpanMap.Span<String> span = driver.getNamesSpanMap().getFirstSpan(); span != null; span = span.next()) {
310                     updateDictionaryEntry(partialsDictionary, span, ratesPartials, paramsIndex);
311                     ++paramsIndex;
312                 }
313             }
314         }
315 
316         // notify observers
317         for (Map.Entry<String, PartialsObserver> observersEntry : getPartialsObservers().entrySet()) {
318             observersEntry.getValue().partialsComputed(state, factor,
319                     partialsDictionary.getOrDefault(observersEntry.getKey(), new double[ratesPartials.length]));
320         }
321     }
322 
323     /**
324      * Update entry of dictionary with derivative information.
325      * @param partialsDictionary dictionary
326      * @param span time span
327      * @param ratesPartials state variables' rates evaluated in the Taylor differential algebra
328      * @param paramsIndex index of parameter as an independent variable of the differential algebra
329      */
330     private void updateDictionaryEntry(final Map<String, double[]> partialsDictionary, final TimeSpanMap.Span<String> span,
331                                        final Gradient[] ratesPartials, final int paramsIndex) {
332         // get the partials derivatives for this driver
333         partialsDictionary.putIfAbsent(span.getData(), new double[ratesPartials.length]);
334 
335         // add the contribution of the current force model
336         final double[] increment = partialsDictionary.get(span.getData());
337         for (int i = 0; i < ratesPartials.length; ++i) {
338             increment[i] += ratesPartials[i].getGradient()[paramsIndex];
339         }
340         partialsDictionary.replace(span.getData(), increment);
341     }
342 
343     /**
344      * Method that first checks if it is possible to replace the attitude provider with a computationally cheaper one
345      * to evaluate. If applicable, the new provider only computes the rotation and uses dummy rate and acceleration,
346      * since they should not be used later on.
347      * @return same provider if at least one forces used attitude derivatives, otherwise one wrapping the old one for
348      * the rotation
349      */
350     AttitudeProvider wrapAttitudeProviderIfPossible() {
351         if (forceModels.stream().anyMatch(ForceModel::dependsOnAttitudeRate)) {
352             // at least one force uses an attitude rate, need to keep the original provider
353             return attitudeProvider;
354         } else {
355             // the original provider can be replaced by a lighter one for performance
356             return AttitudeProviderModifier.getFrozenAttitudeProvider(attitudeProvider);
357         }
358     }
359 
360     /** Interface for observing partials derivatives. */
361     @FunctionalInterface
362     public interface PartialsObserver {
363 
364         /** Callback called when partial derivatives have been computed.
365          * @param state current spacecraft state
366          * @param factor factor matrix, flattened along rows
367          * @param partials partials derivatives of all state variables' rates (except from position) w.r.t. the parameter driver
368          * that was registered (zero if no parameters were not selected or parameter is unknown)
369          */
370         void partialsComputed(SpacecraftState state, double[] factor, double[] partials);
371 
372     }
373 
374     /**
375      * Local override of data dictionary using HashMap for performance.
376      */
377     private static class LocalDoubleArrayDictionary extends DataDictionary {
378 
379         /** Serialization UID. */
380         private static final long serialVersionUID = 1L;
381 
382         /** Map for quick access. */
383         private final transient Map<String, Object> objectMap;
384 
385         /**
386          * Constructor.
387          * @param inputDictionary dictionary whose content is to reproduce
388          */
389         LocalDoubleArrayDictionary(final DataDictionary inputDictionary) {
390             super(inputDictionary);
391             objectMap = toMap();
392         }
393 
394         @Override
395         public Object get(final String key) {
396             return objectMap.get(key);
397         }
398     }
399 }
400