001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.system; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Num; 009import edu.wpi.first.math.numbers.N1; 010import java.util.function.BiFunction; 011import java.util.function.DoubleFunction; 012import java.util.function.Function; 013 014public final class NumericalIntegration { 015 private NumericalIntegration() { 016 // utility Class 017 } 018 019 /** 020 * Performs Runge Kutta integration (4th order). 021 * 022 * @param f The function to integrate, which takes one argument x. 023 * @param x The initial value of x. 024 * @param dtSeconds The time over which to integrate. 025 * @return the integration of dx/dt = f(x) for dt. 026 */ 027 @SuppressWarnings("overloads") 028 public static double rk4(DoubleFunction<Double> f, double x, double dtSeconds) { 029 final var h = dtSeconds; 030 final var k1 = f.apply(x); 031 final var k2 = f.apply(x + h * k1 * 0.5); 032 final var k3 = f.apply(x + h * k2 * 0.5); 033 final var k4 = f.apply(x + h * k3); 034 035 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 036 } 037 038 /** 039 * Performs Runge Kutta integration (4th order). 040 * 041 * @param f The function to integrate. It must take two arguments x and u. 042 * @param x The initial value of x. 043 * @param u The value u held constant over the integration period. 044 * @param dtSeconds The time over which to integrate. 045 * @return The result of Runge Kutta integration (4th order). 046 */ 047 @SuppressWarnings("overloads") 048 public static double rk4( 049 BiFunction<Double, Double, Double> f, double x, Double u, double dtSeconds) { 050 final var h = dtSeconds; 051 052 final var k1 = f.apply(x, u); 053 final var k2 = f.apply(x + h * k1 * 0.5, u); 054 final var k3 = f.apply(x + h * k2 * 0.5, u); 055 final var k4 = f.apply(x + h * k3, u); 056 057 return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); 058 } 059 060 /** 061 * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. 062 * 063 * @param <States> A Num representing the states of the system to integrate. 064 * @param <Inputs> A Num representing the inputs of the system to integrate. 065 * @param f The function to integrate. It must take two arguments x and u. 066 * @param x The initial value of x. 067 * @param u The value u held constant over the integration period. 068 * @param dtSeconds The time over which to integrate. 069 * @return the integration of dx/dt = f(x, u) for dt. 070 */ 071 @SuppressWarnings("overloads") 072 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rk4( 073 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 074 Matrix<States, N1> x, 075 Matrix<Inputs, N1> u, 076 double dtSeconds) { 077 final var h = dtSeconds; 078 079 Matrix<States, N1> k1 = f.apply(x, u); 080 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5)), u); 081 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5)), u); 082 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h)), u); 083 084 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 085 } 086 087 /** 088 * Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 089 * 090 * @param <States> A Num prepresenting the states of the system. 091 * @param f The function to integrate. It must take one argument x. 092 * @param x The initial value of x. 093 * @param dtSeconds The time over which to integrate. 094 * @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt. 095 */ 096 @SuppressWarnings("overloads") 097 public static <States extends Num> Matrix<States, N1> rk4( 098 Function<Matrix<States, N1>, Matrix<States, N1>> f, Matrix<States, N1> x, double dtSeconds) { 099 final var h = dtSeconds; 100 101 Matrix<States, N1> k1 = f.apply(x); 102 Matrix<States, N1> k2 = f.apply(x.plus(k1.times(h * 0.5))); 103 Matrix<States, N1> k3 = f.apply(x.plus(k2.times(h * 0.5))); 104 Matrix<States, N1> k4 = f.apply(x.plus(k3.times(h))); 105 106 return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(h / 6.0)); 107 } 108 109 /** 110 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. By default, the max 111 * error is 1e-6. 112 * 113 * @param <States> A Num representing the states of the system to integrate. 114 * @param <Inputs> A Num representing the inputs of the system to integrate. 115 * @param f The function to integrate. It must take two arguments x and u. 116 * @param x The initial value of x. 117 * @param u The value u held constant over the integration period. 118 * @param dtSeconds The time over which to integrate. 119 * @return the integration of dx/dt = f(x, u) for dt. 120 */ 121 @SuppressWarnings("overloads") 122 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 123 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 124 Matrix<States, N1> x, 125 Matrix<Inputs, N1> u, 126 double dtSeconds) { 127 return rkdp(f, x, u, dtSeconds, 1e-6); 128 } 129 130 /** 131 * Performs adaptive Dormand-Prince integration of dx/dt = f(x, u) for dt. 132 * 133 * @param <States> A Num representing the states of the system to integrate. 134 * @param <Inputs> A Num representing the inputs of the system to integrate. 135 * @param f The function to integrate. It must take two arguments x and u. 136 * @param x The initial value of x. 137 * @param u The value u held constant over the integration period. 138 * @param dtSeconds The time over which to integrate. 139 * @param maxError The maximum acceptable truncation error. Usually a small number like 1e-6. 140 * @return the integration of dx/dt = f(x, u) for dt. 141 */ 142 @SuppressWarnings("overloads") 143 public static <States extends Num, Inputs extends Num> Matrix<States, N1> rkdp( 144 BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f, 145 Matrix<States, N1> x, 146 Matrix<Inputs, N1> u, 147 double dtSeconds, 148 double maxError) { 149 // See https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method for the 150 // Butcher tableau the following arrays came from. 151 152 // final double[6][6] 153 final double[][] A = { 154 {1.0 / 5.0}, 155 {3.0 / 40.0, 9.0 / 40.0}, 156 {44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0}, 157 {19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0}, 158 {9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0}, 159 {35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0} 160 }; 161 162 // final double[7] 163 final double[] b1 = { 164 35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0 165 }; 166 167 // final double[7] 168 final double[] b2 = { 169 5179.0 / 57600.0, 170 0.0, 171 7571.0 / 16695.0, 172 393.0 / 640.0, 173 -92097.0 / 339200.0, 174 187.0 / 2100.0, 175 1.0 / 40.0 176 }; 177 178 Matrix<States, N1> newX; 179 double truncationError; 180 181 double dtElapsed = 0.0; 182 double h = dtSeconds; 183 184 // Loop until we've gotten to our desired dt 185 while (dtElapsed < dtSeconds) { 186 do { 187 // Only allow us to advance up to the dt remaining 188 h = Math.min(h, dtSeconds - dtElapsed); 189 190 var k1 = f.apply(x, u); 191 var k2 = f.apply(x.plus(k1.times(A[0][0]).times(h)), u); 192 var k3 = f.apply(x.plus(k1.times(A[1][0]).plus(k2.times(A[1][1])).times(h)), u); 193 var k4 = 194 f.apply( 195 x.plus(k1.times(A[2][0]).plus(k2.times(A[2][1])).plus(k3.times(A[2][2])).times(h)), 196 u); 197 var k5 = 198 f.apply( 199 x.plus( 200 k1.times(A[3][0]) 201 .plus(k2.times(A[3][1])) 202 .plus(k3.times(A[3][2])) 203 .plus(k4.times(A[3][3])) 204 .times(h)), 205 u); 206 var k6 = 207 f.apply( 208 x.plus( 209 k1.times(A[4][0]) 210 .plus(k2.times(A[4][1])) 211 .plus(k3.times(A[4][2])) 212 .plus(k4.times(A[4][3])) 213 .plus(k5.times(A[4][4])) 214 .times(h)), 215 u); 216 217 // Since the final row of A and the array b1 have the same coefficients 218 // and k7 has no effect on newX, we can reuse the calculation. 219 newX = 220 x.plus( 221 k1.times(A[5][0]) 222 .plus(k2.times(A[5][1])) 223 .plus(k3.times(A[5][2])) 224 .plus(k4.times(A[5][3])) 225 .plus(k5.times(A[5][4])) 226 .plus(k6.times(A[5][5])) 227 .times(h)); 228 var k7 = f.apply(newX, u); 229 230 truncationError = 231 (k1.times(b1[0] - b2[0]) 232 .plus(k2.times(b1[1] - b2[1])) 233 .plus(k3.times(b1[2] - b2[2])) 234 .plus(k4.times(b1[3] - b2[3])) 235 .plus(k5.times(b1[4] - b2[4])) 236 .plus(k6.times(b1[5] - b2[5])) 237 .plus(k7.times(b1[6] - b2[6])) 238 .times(h)) 239 .normF(); 240 241 if (truncationError == 0.0) { 242 h = dtSeconds - dtElapsed; 243 } else { 244 h *= 0.9 * Math.pow(maxError / truncationError, 1.0 / 5.0); 245 } 246 } while (truncationError > maxError); 247 248 dtElapsed += h; 249 x = newX; 250 } 251 252 return x; 253 } 254}