001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.controller;
006
007import edu.wpi.first.math.InterpolatingMatrixTreeMap;
008import edu.wpi.first.math.MatBuilder;
009import edu.wpi.first.math.MathUtil;
010import edu.wpi.first.math.Matrix;
011import edu.wpi.first.math.Nat;
012import edu.wpi.first.math.StateSpaceUtil;
013import edu.wpi.first.math.Vector;
014import edu.wpi.first.math.geometry.Pose2d;
015import edu.wpi.first.math.numbers.N1;
016import edu.wpi.first.math.numbers.N2;
017import edu.wpi.first.math.numbers.N5;
018import edu.wpi.first.math.system.LinearSystem;
019import edu.wpi.first.math.trajectory.Trajectory;
020
021/**
022 * The linear time-varying differential drive controller has a similar form to the LQR, but the
023 * model used to compute the controller gain is the nonlinear model linearized around the
024 * drivetrain's current state. We precomputed gains for important places in our state-space, then
025 * interpolated between them with a LUT to save computational resources.
026 *
027 * <p>See section 8.7 in Controls Engineering in FRC for a derivation of the control law we used
028 * shown in theorem 8.7.4.
029 */
030public class LTVDifferentialDriveController {
031  private final double m_trackwidth;
032
033  // LUT from drivetrain linear velocity to LQR gain
034  private final InterpolatingMatrixTreeMap<Double, N2, N5> m_table =
035      new InterpolatingMatrixTreeMap<>();
036
037  private Matrix<N5, N1> m_error = new Matrix<>(Nat.N5(), Nat.N1());
038  private Matrix<N5, N1> m_tolerance = new Matrix<>(Nat.N5(), Nat.N1());
039
040  /** States of the drivetrain system. */
041  private enum State {
042    kX(0),
043    kY(1),
044    kHeading(2),
045    kLeftVelocity(3),
046    kRightVelocity(4);
047
048    public final int value;
049
050    State(int i) {
051      this.value = i;
052    }
053  }
054
055  /**
056   * Constructs a linear time-varying differential drive controller.
057   *
058   * @param plant The differential drive velocity plant.
059   * @param trackwidth The distance between the differential drive's left and right wheels in
060   *     meters.
061   * @param qelems The maximum desired error tolerance for each state.
062   * @param relems The maximum desired control effort for each input.
063   * @param dt Discretization timestep in seconds.
064   * @throws IllegalArgumentException if max velocity of plant with 12 V input &lt;= 0.
065   */
066  public LTVDifferentialDriveController(
067      LinearSystem<N2, N2, N2> plant,
068      double trackwidth,
069      Vector<N5> qelems,
070      Vector<N2> relems,
071      double dt) {
072    m_trackwidth = trackwidth;
073
074    // Control law derivation is in section 8.7 of
075    // https://file.tavsys.net/control/controls-engineering-in-frc.pdf
076    var A =
077        new MatBuilder<>(Nat.N5(), Nat.N5())
078            .fill(
079                0.0,
080                0.0,
081                0.0,
082                0.5,
083                0.5,
084                0.0,
085                0.0,
086                0.0,
087                0.0,
088                0.0,
089                0.0,
090                0.0,
091                0.0,
092                -1.0 / m_trackwidth,
093                1.0 / m_trackwidth,
094                0.0,
095                0.0,
096                0.0,
097                plant.getA(0, 0),
098                plant.getA(0, 1),
099                0.0,
100                0.0,
101                0.0,
102                plant.getA(1, 0),
103                plant.getA(1, 1));
104    var B =
105        new MatBuilder<>(Nat.N5(), Nat.N2())
106            .fill(
107                0.0,
108                0.0,
109                0.0,
110                0.0,
111                0.0,
112                0.0,
113                plant.getB(0, 0),
114                plant.getB(0, 1),
115                plant.getB(1, 0),
116                plant.getB(1, 1));
117    var Q = StateSpaceUtil.makeCostMatrix(qelems);
118    var R = StateSpaceUtil.makeCostMatrix(relems);
119
120    // dx/dt = Ax + Bu
121    // 0 = Ax + Bu
122    // Ax = -Bu
123    // x = -A⁻¹Bu
124    double maxV =
125        plant
126            .getA()
127            .solve(plant.getB().times(new MatBuilder<>(Nat.N2(), Nat.N1()).fill(12.0, 12.0)))
128            .times(-1.0)
129            .get(0, 0);
130
131    if (maxV <= 0.0) {
132      throw new IllegalArgumentException(
133          "Max velocity of plant with 12 V input must be greater than zero.");
134    }
135
136    for (double velocity = -maxV; velocity < maxV; velocity += 0.01) {
137      // The DARE is ill-conditioned if the velocity is close to zero, so don't
138      // let the system stop.
139      if (Math.abs(velocity) < 1e-4) {
140        m_table.put(velocity, new Matrix<>(Nat.N2(), Nat.N5()));
141      } else {
142        A.set(State.kY.value, State.kHeading.value, velocity);
143        m_table.put(velocity, new LinearQuadraticRegulator<N5, N2, N5>(A, B, Q, R, dt).getK());
144      }
145    }
146  }
147
148  /**
149   * Returns true if the pose error is within tolerance of the reference.
150   *
151   * @return True if the pose error is within tolerance of the reference.
152   */
153  public boolean atReference() {
154    return Math.abs(m_error.get(0, 0)) < m_tolerance.get(0, 0)
155        && Math.abs(m_error.get(1, 0)) < m_tolerance.get(1, 0)
156        && Math.abs(m_error.get(2, 0)) < m_tolerance.get(2, 0)
157        && Math.abs(m_error.get(3, 0)) < m_tolerance.get(3, 0)
158        && Math.abs(m_error.get(4, 0)) < m_tolerance.get(4, 0);
159  }
160
161  /**
162   * Sets the pose error which is considered tolerable for use with atReference().
163   *
164   * @param poseTolerance Pose error which is tolerable.
165   * @param leftVelocityTolerance Left velocity error which is tolerable in meters per second.
166   * @param rightVelocityTolerance Right velocity error which is tolerable in meters per second.
167   */
168  public void setTolerance(
169      Pose2d poseTolerance, double leftVelocityTolerance, double rightVelocityTolerance) {
170    m_tolerance =
171        new MatBuilder<>(Nat.N5(), Nat.N1())
172            .fill(
173                poseTolerance.getX(),
174                poseTolerance.getY(),
175                poseTolerance.getRotation().getRadians(),
176                leftVelocityTolerance,
177                rightVelocityTolerance);
178  }
179
180  /**
181   * Returns the left and right output voltages of the LTV controller.
182   *
183   * <p>The reference pose, linear velocity, and angular velocity should come from a drivetrain
184   * trajectory.
185   *
186   * @param currentPose The current pose.
187   * @param leftVelocity The current left velocity in meters per second.
188   * @param rightVelocity The current right velocity in meters per second.
189   * @param poseRef The desired pose.
190   * @param leftVelocityRef The desired left velocity in meters per second.
191   * @param rightVelocityRef The desired right velocity in meters per second.
192   * @return Left and right output voltages of the LTV controller.
193   */
194  public DifferentialDriveWheelVoltages calculate(
195      Pose2d currentPose,
196      double leftVelocity,
197      double rightVelocity,
198      Pose2d poseRef,
199      double leftVelocityRef,
200      double rightVelocityRef) {
201    // This implements the linear time-varying differential drive controller in
202    // theorem 9.6.3 of https://tavsys.net/controls-in-frc.
203    var x =
204        new MatBuilder<>(Nat.N5(), Nat.N1())
205            .fill(
206                currentPose.getX(),
207                currentPose.getY(),
208                currentPose.getRotation().getRadians(),
209                leftVelocity,
210                rightVelocity);
211
212    var inRobotFrame = Matrix.eye(Nat.N5());
213    inRobotFrame.set(0, 0, Math.cos(x.get(State.kHeading.value, 0)));
214    inRobotFrame.set(0, 1, Math.sin(x.get(State.kHeading.value, 0)));
215    inRobotFrame.set(1, 0, -Math.sin(x.get(State.kHeading.value, 0)));
216    inRobotFrame.set(1, 1, Math.cos(x.get(State.kHeading.value, 0)));
217
218    var r =
219        new MatBuilder<>(Nat.N5(), Nat.N1())
220            .fill(
221                poseRef.getX(),
222                poseRef.getY(),
223                poseRef.getRotation().getRadians(),
224                leftVelocityRef,
225                rightVelocityRef);
226    m_error = r.minus(x);
227    m_error.set(
228        State.kHeading.value, 0, MathUtil.angleModulus(m_error.get(State.kHeading.value, 0)));
229
230    double velocity = (leftVelocity + rightVelocity) / 2.0;
231    var K = m_table.get(velocity);
232
233    var u = K.times(inRobotFrame).times(m_error);
234
235    return new DifferentialDriveWheelVoltages(u.get(0, 0), u.get(1, 0));
236  }
237
238  /**
239   * Returns the left and right output voltages of the LTV controller.
240   *
241   * <p>The reference pose, linear velocity, and angular velocity should come from a drivetrain
242   * trajectory.
243   *
244   * @param currentPose The current pose.
245   * @param leftVelocity The left velocity in meters per second.
246   * @param rightVelocity The right velocity in meters per second.
247   * @param desiredState The desired pose, linear velocity, and angular velocity from a trajectory.
248   * @return The left and right output voltages of the LTV controller.
249   */
250  public DifferentialDriveWheelVoltages calculate(
251      Pose2d currentPose,
252      double leftVelocity,
253      double rightVelocity,
254      Trajectory.State desiredState) {
255    // v = (v_r + v_l) / 2     (1)
256    // w = (v_r - v_l) / (2r)  (2)
257    // k = w / v               (3)
258    //
259    // v_l = v - wr
260    // v_l = v - (vk)r
261    // v_l = v(1 - kr)
262    //
263    // v_r = v + wr
264    // v_r = v + (vk)r
265    // v_r = v(1 + kr)
266    return calculate(
267        currentPose,
268        leftVelocity,
269        rightVelocity,
270        desiredState.poseMeters,
271        desiredState.velocityMetersPerSecond
272            * (1 - (desiredState.curvatureRadPerMeter * m_trackwidth / 2.0)),
273        desiredState.velocityMetersPerSecond
274            * (1 + (desiredState.curvatureRadPerMeter * m_trackwidth / 2.0)));
275  }
276}