001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math;
006
007import org.ejml.simple.SimpleMatrix;
008
009public final class Drake {
010  private Drake() {}
011
012  /**
013   * Solves the discrete algebraic Riccati equation.
014   *
015   * @param A System matrix.
016   * @param B Input matrix.
017   * @param Q State cost matrix.
018   * @param R Input cost matrix.
019   * @return Solution of DARE.
020   */
021  public static SimpleMatrix discreteAlgebraicRiccatiEquation(
022      SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R) {
023    var S = new SimpleMatrix(A.numRows(), A.numCols());
024    WPIMathJNI.discreteAlgebraicRiccatiEquation(
025        A.getDDRM().getData(),
026        B.getDDRM().getData(),
027        Q.getDDRM().getData(),
028        R.getDDRM().getData(),
029        A.numCols(),
030        B.numCols(),
031        S.getDDRM().getData());
032    return S;
033  }
034
035  /**
036   * Solves the discrete algebraic Riccati equation.
037   *
038   * @param <States> Number of states.
039   * @param <Inputs> Number of inputs.
040   * @param A System matrix.
041   * @param B Input matrix.
042   * @param Q State cost matrix.
043   * @param R Input cost matrix.
044   * @return Solution of DARE.
045   */
046  public static <States extends Num, Inputs extends Num>
047      Matrix<States, States> discreteAlgebraicRiccatiEquation(
048          Matrix<States, States> A,
049          Matrix<States, Inputs> B,
050          Matrix<States, States> Q,
051          Matrix<Inputs, Inputs> R) {
052    return new Matrix<>(
053        discreteAlgebraicRiccatiEquation(
054            A.getStorage(), B.getStorage(), Q.getStorage(), R.getStorage()));
055  }
056
057  /**
058   * Solves the discrete algebraic Riccati equation.
059   *
060   * @param A System matrix.
061   * @param B Input matrix.
062   * @param Q State cost matrix.
063   * @param R Input cost matrix.
064   * @param N State-input cross-term cost matrix.
065   * @return Solution of DARE.
066   */
067  public static SimpleMatrix discreteAlgebraicRiccatiEquation(
068      SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R, SimpleMatrix N) {
069    // See
070    // https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR
071    // for the change of variables used here.
072    var scrA = A.minus(B.mult(R.solve(N.transpose())));
073    var scrQ = Q.minus(N.mult(R.solve(N.transpose())));
074
075    var S = new SimpleMatrix(A.numRows(), A.numCols());
076    WPIMathJNI.discreteAlgebraicRiccatiEquation(
077        scrA.getDDRM().getData(),
078        B.getDDRM().getData(),
079        scrQ.getDDRM().getData(),
080        R.getDDRM().getData(),
081        A.numCols(),
082        B.numCols(),
083        S.getDDRM().getData());
084    return S;
085  }
086
087  /**
088   * Solves the discrete algebraic Riccati equation.
089   *
090   * @param <States> Number of states.
091   * @param <Inputs> Number of inputs.
092   * @param A System matrix.
093   * @param B Input matrix.
094   * @param Q State cost matrix.
095   * @param R Input cost matrix.
096   * @param N State-input cross-term cost matrix.
097   * @return Solution of DARE.
098   */
099  public static <States extends Num, Inputs extends Num>
100      Matrix<States, States> discreteAlgebraicRiccatiEquation(
101          Matrix<States, States> A,
102          Matrix<States, Inputs> B,
103          Matrix<States, States> Q,
104          Matrix<Inputs, Inputs> R,
105          Matrix<States, Inputs> N) {
106    // See
107    // https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR
108    // for the change of variables used here.
109    var scrA = A.minus(B.times(R.solve(N.transpose())));
110    var scrQ = Q.minus(N.times(R.solve(N.transpose())));
111
112    return new Matrix<>(
113        discreteAlgebraicRiccatiEquation(
114            scrA.getStorage(), B.getStorage(), scrQ.getStorage(), R.getStorage()));
115  }
116}