001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math;
006
007import java.util.function.BiFunction;
008import org.ejml.data.DMatrixRMaj;
009import org.ejml.dense.row.NormOps_DDRM;
010import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
011import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
012import org.ejml.simple.SimpleBase;
013import org.ejml.simple.SimpleMatrix;
014
015public final class SimpleMatrixUtils {
016  private SimpleMatrixUtils() {}
017
018  /**
019   * Compute the matrix exponential, e^M of the given matrix.
020   *
021   * @param matrix The matrix to compute the exponential of.
022   * @return The resultant matrix.
023   */
024  public static SimpleMatrix expm(SimpleMatrix matrix) {
025    BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider = SimpleBase::solve;
026    SimpleMatrix A = matrix;
027    double A_L1 = NormOps_DDRM.inducedP1(matrix.getDDRM());
028    int nSquarings = 0;
029
030    if (A_L1 < 1.495585217958292e-002) {
031      Pair<SimpleMatrix, SimpleMatrix> pair = pade3(A);
032      return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider);
033    } else if (A_L1 < 2.539398330063230e-001) {
034      Pair<SimpleMatrix, SimpleMatrix> pair = pade5(A);
035      return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider);
036    } else if (A_L1 < 9.504178996162932e-001) {
037      Pair<SimpleMatrix, SimpleMatrix> pair = pade7(A);
038      return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider);
039    } else if (A_L1 < 2.097847961257068e+000) {
040      Pair<SimpleMatrix, SimpleMatrix> pair = pade9(A);
041      return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider);
042    } else {
043      double maxNorm = 5.371920351148152;
044      double log = Math.log(A_L1 / maxNorm) / Math.log(2); // logb(2, arg)
045      nSquarings = (int) Math.max(0, Math.ceil(log));
046      A = A.divide(Math.pow(2.0, nSquarings));
047      Pair<SimpleMatrix, SimpleMatrix> pair = pade13(A);
048      return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider);
049    }
050  }
051
052  private static SimpleMatrix dispatchPade(
053      SimpleMatrix U,
054      SimpleMatrix V,
055      int nSquarings,
056      BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider) {
057    SimpleMatrix P = U.plus(V);
058    SimpleMatrix Q = U.negative().plus(V);
059
060    SimpleMatrix R = solveProvider.apply(Q, P);
061
062    for (int i = 0; i < nSquarings; i++) {
063      R = R.mult(R);
064    }
065
066    return R;
067  }
068
069  private static Pair<SimpleMatrix, SimpleMatrix> pade3(SimpleMatrix A) {
070    double[] b = new double[] {120, 60, 12, 1};
071    SimpleMatrix ident = eye(A.numRows(), A.numCols());
072
073    SimpleMatrix A2 = A.mult(A);
074    SimpleMatrix U = A.mult(A2.mult(ident.scale(b[1]).plus(b[3])));
075    SimpleMatrix V = A2.scale(b[2]).plus(ident.scale(b[0]));
076    return new Pair<>(U, V);
077  }
078
079  private static Pair<SimpleMatrix, SimpleMatrix> pade5(SimpleMatrix A) {
080    double[] b = new double[] {30240, 15120, 3360, 420, 30, 1};
081    SimpleMatrix ident = eye(A.numRows(), A.numCols());
082    SimpleMatrix A2 = A.mult(A);
083    SimpleMatrix A4 = A2.mult(A2);
084
085    SimpleMatrix U = A.mult(A4.scale(b[5]).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
086    SimpleMatrix V = A4.scale(b[4]).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
087
088    return new Pair<>(U, V);
089  }
090
091  private static Pair<SimpleMatrix, SimpleMatrix> pade7(SimpleMatrix A) {
092    double[] b = new double[] {17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1};
093    SimpleMatrix ident = eye(A.numRows(), A.numCols());
094    SimpleMatrix A2 = A.mult(A);
095    SimpleMatrix A4 = A2.mult(A2);
096    SimpleMatrix A6 = A4.mult(A2);
097
098    SimpleMatrix U =
099        A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
100    SimpleMatrix V =
101        A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
102
103    return new Pair<>(U, V);
104  }
105
106  private static Pair<SimpleMatrix, SimpleMatrix> pade9(SimpleMatrix A) {
107    double[] b =
108        new double[] {
109          17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1
110        };
111    SimpleMatrix ident = eye(A.numRows(), A.numCols());
112    SimpleMatrix A2 = A.mult(A);
113    SimpleMatrix A4 = A2.mult(A2);
114    SimpleMatrix A6 = A4.mult(A2);
115    SimpleMatrix A8 = A6.mult(A2);
116
117    SimpleMatrix U =
118        A.mult(
119            A8.scale(b[9])
120                .plus(A6.scale(b[7]))
121                .plus(A4.scale(b[5]))
122                .plus(A2.scale(b[3]))
123                .plus(ident.scale(b[1])));
124    SimpleMatrix V =
125        A8.scale(b[8])
126            .plus(A6.scale(b[6]))
127            .plus(A4.scale(b[4]))
128            .plus(A2.scale(b[2]))
129            .plus(ident.scale(b[0]));
130
131    return new Pair<>(U, V);
132  }
133
134  private static Pair<SimpleMatrix, SimpleMatrix> pade13(SimpleMatrix A) {
135    double[] b =
136        new double[] {
137          64764752532480000.0,
138          32382376266240000.0,
139          7771770303897600.0,
140          1187353796428800.0,
141          129060195264000.0,
142          10559470521600.0,
143          670442572800.0,
144          33522128640.0,
145          1323241920,
146          40840800,
147          960960,
148          16380,
149          182,
150          1
151        };
152    SimpleMatrix ident = eye(A.numRows(), A.numCols());
153
154    SimpleMatrix A2 = A.mult(A);
155    SimpleMatrix A4 = A2.mult(A2);
156    SimpleMatrix A6 = A4.mult(A2);
157
158    SimpleMatrix U =
159        A.mult(
160            A6.scale(b[13])
161                .plus(A4.scale(b[11]))
162                .plus(A2.scale(b[9]))
163                .plus(A6.scale(b[7]))
164                .plus(A4.scale(b[5]))
165                .plus(A2.scale(b[3]))
166                .plus(ident.scale(b[1])));
167    SimpleMatrix V =
168        A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8])))
169            .plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])));
170
171    return new Pair<>(U, V);
172  }
173
174  private static SimpleMatrix eye(int rows, int cols) {
175    return SimpleMatrix.identity(Math.min(rows, cols));
176  }
177
178  /**
179   * The identy of a square matrix.
180   *
181   * @param rows the number of rows (and columns)
182   * @return the identiy matrix, rows x rows.
183   */
184  public static SimpleMatrix eye(int rows) {
185    return SimpleMatrix.identity(rows);
186  }
187
188  /**
189   * Decompose the given matrix using Cholesky Decomposition and return a view of the upper
190   * triangular matrix (if you want lower triangular see the other overload of this method.) If the
191   * input matrix is zeros, this will return the zero matrix.
192   *
193   * @param src The matrix to decompose.
194   * @return The decomposed matrix.
195   * @throws RuntimeException if the matrix could not be decomposed (i.e. is not positive
196   *     semidefinite).
197   */
198  public static SimpleMatrix lltDecompose(SimpleMatrix src) {
199    return lltDecompose(src, false);
200  }
201
202  /**
203   * Decompose the given matrix using Cholesky Decomposition. If the input matrix is zeros, this
204   * will return the zero matrix.
205   *
206   * @param src The matrix to decompose.
207   * @param lowerTriangular if we want to decompose to the lower triangular Cholesky matrix.
208   * @return The decomposed matrix.
209   * @throws RuntimeException if the matrix could not be decomposed (i.e. is not positive
210   *     semidefinite).
211   */
212  public static SimpleMatrix lltDecompose(SimpleMatrix src, boolean lowerTriangular) {
213    SimpleMatrix temp = src.copy();
214
215    CholeskyDecomposition_F64<DMatrixRMaj> chol =
216        DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
217    if (!chol.decompose(temp.getMatrix())) {
218      // check that the input is not all zeros -- if they are, we special case and return all
219      // zeros.
220      var matData = temp.getDDRM().data;
221      var isZeros = true;
222      for (double matDatum : matData) {
223        isZeros &= Math.abs(matDatum) < 1e-6;
224      }
225      if (isZeros) {
226        return new SimpleMatrix(temp.numRows(), temp.numCols());
227      }
228
229      throw new RuntimeException("Cholesky decomposition failed! Input matrix:\n" + src.toString());
230    }
231
232    return SimpleMatrix.wrap(chol.getT(null));
233  }
234
235  /**
236   * Computes the matrix exponential using Eigen's solver.
237   *
238   * @param A the matrix to exponentiate.
239   * @return the exponential of A.
240   */
241  public static SimpleMatrix exp(SimpleMatrix A) {
242    SimpleMatrix toReturn = new SimpleMatrix(A.numRows(), A.numRows());
243    WPIMathJNI.exp(A.getDDRM().getData(), A.numRows(), toReturn.getDDRM().getData());
244    return toReturn;
245  }
246}