001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math;
006
007import org.ejml.simple.SimpleMatrix;
008
009public final class DARE {
010  private DARE() {
011    throw new UnsupportedOperationException("This is a utility class!");
012  }
013
014  /**
015   * Solves the discrete algebraic Riccati equation.
016   *
017   * @param A System matrix.
018   * @param B Input matrix.
019   * @param Q State cost matrix.
020   * @param R Input cost matrix.
021   * @return Solution of DARE.
022   */
023  public static SimpleMatrix dare(SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R) {
024    var S = new SimpleMatrix(A.numRows(), A.numCols());
025    WPIMathJNI.dare(
026        A.getDDRM().getData(),
027        B.getDDRM().getData(),
028        Q.getDDRM().getData(),
029        R.getDDRM().getData(),
030        A.numCols(),
031        B.numCols(),
032        S.getDDRM().getData());
033    return S;
034  }
035
036  /**
037   * Solves the discrete algebraic Riccati equation.
038   *
039   * @param <States> Number of states.
040   * @param <Inputs> Number of inputs.
041   * @param A System matrix.
042   * @param B Input matrix.
043   * @param Q State cost matrix.
044   * @param R Input cost matrix.
045   * @return Solution of DARE.
046   */
047  public static <States extends Num, Inputs extends Num> Matrix<States, States> dare(
048      Matrix<States, States> A,
049      Matrix<States, Inputs> B,
050      Matrix<States, States> Q,
051      Matrix<Inputs, Inputs> R) {
052    return new Matrix<>(dare(A.getStorage(), B.getStorage(), Q.getStorage(), R.getStorage()));
053  }
054
055  /**
056   * Solves the discrete algebraic Riccati equation.
057   *
058   * @param A System matrix.
059   * @param B Input matrix.
060   * @param Q State cost matrix.
061   * @param R Input cost matrix.
062   * @param N State-input cross-term cost matrix.
063   * @return Solution of DARE.
064   */
065  public static SimpleMatrix dare(
066      SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R, SimpleMatrix N) {
067    // See
068    // https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR
069    // for the change of variables used here.
070    var scrA = A.minus(B.mult(R.solve(N.transpose())));
071    var scrQ = Q.minus(N.mult(R.solve(N.transpose())));
072
073    var S = new SimpleMatrix(A.numRows(), A.numCols());
074    WPIMathJNI.dare(
075        scrA.getDDRM().getData(),
076        B.getDDRM().getData(),
077        scrQ.getDDRM().getData(),
078        R.getDDRM().getData(),
079        A.numCols(),
080        B.numCols(),
081        S.getDDRM().getData());
082    return S;
083  }
084
085  /**
086   * Solves the discrete algebraic Riccati equation.
087   *
088   * @param <States> Number of states.
089   * @param <Inputs> Number of inputs.
090   * @param A System matrix.
091   * @param B Input matrix.
092   * @param Q State cost matrix.
093   * @param R Input cost matrix.
094   * @param N State-input cross-term cost matrix.
095   * @return Solution of DARE.
096   */
097  public static <States extends Num, Inputs extends Num> Matrix<States, States> dare(
098      Matrix<States, States> A,
099      Matrix<States, Inputs> B,
100      Matrix<States, States> Q,
101      Matrix<Inputs, Inputs> R,
102      Matrix<States, Inputs> N) {
103    // This is a change of variables to make the DARE that includes Q, R, and N
104    // cost matrices fit the form of the DARE that includes only Q and R cost
105    // matrices.
106    //
107    // This is equivalent to solving the original DARE:
108    //
109    //   A₂ᵀXA₂ − X − A₂ᵀXB(BᵀXB + R)⁻¹BᵀXA₂ + Q₂ = 0
110    //
111    // where A₂ and Q₂ are a change of variables:
112    //
113    //   A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ
114    return new Matrix<>(
115        dare(
116            A.minus(B.times(R.solve(N.transpose()))).getStorage(),
117            B.getStorage(),
118            Q.minus(N.times(R.solve(N.transpose()))).getStorage(),
119            R.getStorage()));
120  }
121}