001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math; 006 007import java.util.function.BiFunction; 008import org.ejml.data.DMatrixRMaj; 009import org.ejml.dense.row.NormOps_DDRM; 010import org.ejml.dense.row.factory.DecompositionFactory_DDRM; 011import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64; 012import org.ejml.simple.SimpleBase; 013import org.ejml.simple.SimpleMatrix; 014 015public final class SimpleMatrixUtils { 016 private SimpleMatrixUtils() {} 017 018 /** 019 * Compute the matrix exponential, e^M of the given matrix. 020 * 021 * @param matrix The matrix to compute the exponential of. 022 * @return The resultant matrix. 023 */ 024 public static SimpleMatrix expm(SimpleMatrix matrix) { 025 BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider = SimpleBase::solve; 026 SimpleMatrix A = matrix; 027 double A_L1 = NormOps_DDRM.inducedP1(matrix.getDDRM()); 028 int nSquarings = 0; 029 030 if (A_L1 < 1.495585217958292e-002) { 031 Pair<SimpleMatrix, SimpleMatrix> pair = pade3(A); 032 return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider); 033 } else if (A_L1 < 2.539398330063230e-001) { 034 Pair<SimpleMatrix, SimpleMatrix> pair = pade5(A); 035 return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider); 036 } else if (A_L1 < 9.504178996162932e-001) { 037 Pair<SimpleMatrix, SimpleMatrix> pair = pade7(A); 038 return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider); 039 } else if (A_L1 < 2.097847961257068e+000) { 040 Pair<SimpleMatrix, SimpleMatrix> pair = pade9(A); 041 return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider); 042 } else { 043 double maxNorm = 5.371920351148152; 044 double log = Math.log(A_L1 / maxNorm) / Math.log(2); // logb(2, arg) 045 nSquarings = (int) Math.max(0, Math.ceil(log)); 046 A = A.divide(Math.pow(2.0, nSquarings)); 047 Pair<SimpleMatrix, SimpleMatrix> pair = pade13(A); 048 return dispatchPade(pair.getFirst(), pair.getSecond(), nSquarings, solveProvider); 049 } 050 } 051 052 private static SimpleMatrix dispatchPade( 053 SimpleMatrix U, 054 SimpleMatrix V, 055 int nSquarings, 056 BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider) { 057 SimpleMatrix P = U.plus(V); 058 SimpleMatrix Q = U.negative().plus(V); 059 060 SimpleMatrix R = solveProvider.apply(Q, P); 061 062 for (int i = 0; i < nSquarings; i++) { 063 R = R.mult(R); 064 } 065 066 return R; 067 } 068 069 private static Pair<SimpleMatrix, SimpleMatrix> pade3(SimpleMatrix A) { 070 double[] b = new double[] {120, 60, 12, 1}; 071 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 072 073 SimpleMatrix A2 = A.mult(A); 074 SimpleMatrix U = A.mult(A2.mult(ident.scale(b[1]).plus(b[3]))); 075 SimpleMatrix V = A2.scale(b[2]).plus(ident.scale(b[0])); 076 return new Pair<>(U, V); 077 } 078 079 private static Pair<SimpleMatrix, SimpleMatrix> pade5(SimpleMatrix A) { 080 double[] b = new double[] {30240, 15120, 3360, 420, 30, 1}; 081 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 082 SimpleMatrix A2 = A.mult(A); 083 SimpleMatrix A4 = A2.mult(A2); 084 085 SimpleMatrix U = A.mult(A4.scale(b[5]).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); 086 SimpleMatrix V = A4.scale(b[4]).plus(A2.scale(b[2])).plus(ident.scale(b[0])); 087 088 return new Pair<>(U, V); 089 } 090 091 private static Pair<SimpleMatrix, SimpleMatrix> pade7(SimpleMatrix A) { 092 double[] b = new double[] {17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1}; 093 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 094 SimpleMatrix A2 = A.mult(A); 095 SimpleMatrix A4 = A2.mult(A2); 096 SimpleMatrix A6 = A4.mult(A2); 097 098 SimpleMatrix U = 099 A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); 100 SimpleMatrix V = 101 A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])); 102 103 return new Pair<>(U, V); 104 } 105 106 private static Pair<SimpleMatrix, SimpleMatrix> pade9(SimpleMatrix A) { 107 double[] b = 108 new double[] { 109 17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1 110 }; 111 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 112 SimpleMatrix A2 = A.mult(A); 113 SimpleMatrix A4 = A2.mult(A2); 114 SimpleMatrix A6 = A4.mult(A2); 115 SimpleMatrix A8 = A6.mult(A2); 116 117 SimpleMatrix U = 118 A.mult( 119 A8.scale(b[9]) 120 .plus(A6.scale(b[7])) 121 .plus(A4.scale(b[5])) 122 .plus(A2.scale(b[3])) 123 .plus(ident.scale(b[1]))); 124 SimpleMatrix V = 125 A8.scale(b[8]) 126 .plus(A6.scale(b[6])) 127 .plus(A4.scale(b[4])) 128 .plus(A2.scale(b[2])) 129 .plus(ident.scale(b[0])); 130 131 return new Pair<>(U, V); 132 } 133 134 private static Pair<SimpleMatrix, SimpleMatrix> pade13(SimpleMatrix A) { 135 double[] b = 136 new double[] { 137 64764752532480000.0, 138 32382376266240000.0, 139 7771770303897600.0, 140 1187353796428800.0, 141 129060195264000.0, 142 10559470521600.0, 143 670442572800.0, 144 33522128640.0, 145 1323241920, 146 40840800, 147 960960, 148 16380, 149 182, 150 1 151 }; 152 SimpleMatrix ident = eye(A.numRows(), A.numCols()); 153 154 SimpleMatrix A2 = A.mult(A); 155 SimpleMatrix A4 = A2.mult(A2); 156 SimpleMatrix A6 = A4.mult(A2); 157 158 SimpleMatrix U = 159 A.mult( 160 A6.scale(b[13]) 161 .plus(A4.scale(b[11])) 162 .plus(A2.scale(b[9])) 163 .plus(A6.scale(b[7])) 164 .plus(A4.scale(b[5])) 165 .plus(A2.scale(b[3])) 166 .plus(ident.scale(b[1]))); 167 SimpleMatrix V = 168 A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8]))) 169 .plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]))); 170 171 return new Pair<>(U, V); 172 } 173 174 private static SimpleMatrix eye(int rows, int cols) { 175 return SimpleMatrix.identity(Math.min(rows, cols)); 176 } 177 178 /** 179 * The identy of a square matrix. 180 * 181 * @param rows the number of rows (and columns) 182 * @return the identiy matrix, rows x rows. 183 */ 184 public static SimpleMatrix eye(int rows) { 185 return SimpleMatrix.identity(rows); 186 } 187 188 /** 189 * Decompose the given matrix using Cholesky Decomposition and return a view of the upper 190 * triangular matrix (if you want lower triangular see the other overload of this method.) If the 191 * input matrix is zeros, this will return the zero matrix. 192 * 193 * @param src The matrix to decompose. 194 * @return The decomposed matrix. 195 * @throws RuntimeException if the matrix could not be decomposed (i.e. is not positive 196 * semidefinite). 197 */ 198 public static SimpleMatrix lltDecompose(SimpleMatrix src) { 199 return lltDecompose(src, false); 200 } 201 202 /** 203 * Decompose the given matrix using Cholesky Decomposition. If the input matrix is zeros, this 204 * will return the zero matrix. 205 * 206 * @param src The matrix to decompose. 207 * @param lowerTriangular if we want to decompose to the lower triangular Cholesky matrix. 208 * @return The decomposed matrix. 209 * @throws RuntimeException if the matrix could not be decomposed (i.e. is not positive 210 * semidefinite). 211 */ 212 public static SimpleMatrix lltDecompose(SimpleMatrix src, boolean lowerTriangular) { 213 SimpleMatrix temp = src.copy(); 214 215 CholeskyDecomposition_F64<DMatrixRMaj> chol = 216 DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular); 217 if (!chol.decompose(temp.getMatrix())) { 218 // check that the input is not all zeros -- if they are, we special case and return all 219 // zeros. 220 var matData = temp.getDDRM().data; 221 var isZeros = true; 222 for (double matDatum : matData) { 223 isZeros &= Math.abs(matDatum) < 1e-6; 224 } 225 if (isZeros) { 226 return new SimpleMatrix(temp.numRows(), temp.numCols()); 227 } 228 229 throw new RuntimeException("Cholesky decomposition failed! Input matrix:\n" + src.toString()); 230 } 231 232 return SimpleMatrix.wrap(chol.getT(null)); 233 } 234 235 /** 236 * Computes the matrix exponential using Eigen's solver. 237 * 238 * @param A the matrix to exponentiate. 239 * @return the exponential of A. 240 */ 241 public static SimpleMatrix exp(SimpleMatrix A) { 242 SimpleMatrix toReturn = new SimpleMatrix(A.numRows(), A.numRows()); 243 WPIMathJNI.exp(A.getDDRM().getData(), A.numRows(), toReturn.getDDRM().getData()); 244 return toReturn; 245 } 246}