// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.

package edu.wpi.first.math.estimator;

import edu.wpi.first.math.MathSharedStore;
import edu.wpi.first.math.MathUtil;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.VecBuilder;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Rotation2d;
import edu.wpi.first.math.geometry.Twist2d;
import edu.wpi.first.math.interpolation.Interpolatable;
import edu.wpi.first.math.interpolation.TimeInterpolatableBuffer;
import edu.wpi.first.math.kinematics.SwerveDriveKinematics;
import edu.wpi.first.math.kinematics.SwerveDriveOdometry;
import edu.wpi.first.math.kinematics.SwerveModulePosition;
import edu.wpi.first.math.numbers.N1;
import edu.wpi.first.math.numbers.N3;
import java.util.Arrays;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;

/**
 * This class wraps {@link SwerveDriveOdometry Swerve Drive Odometry} to fuse latency-compensated
 * vision measurements with swerve drive encoder distance measurements. It is intended to be a
 * drop-in replacement for {@link edu.wpi.first.math.kinematics.SwerveDriveOdometry}.
 *
 * <p>{@link SwerveDrivePoseEstimator#update} should be called every robot loop.
 *
 * <p>{@link SwerveDrivePoseEstimator#addVisionMeasurement} can be called as infrequently as you
 * want; if you never call it, then this class will behave as regular encoder odometry.
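 *
 * <p>As a rough usage sketch, assuming your own kinematics object, gyro, and a helper that reads
 * the module positions (none of these helpers are part of this class):
 *
 * <pre>{@code
 * var estimator =
 *     new SwerveDrivePoseEstimator(
 *         m_kinematics, m_gyro.getRotation2d(), getModulePositions(), new Pose2d());
 *
 * // Call every robot loop:
 * estimator.update(m_gyro.getRotation2d(), getModulePositions());
 *
 * // Call whenever a vision pose arrives:
 * estimator.addVisionMeasurement(visionPose, visionTimestampSeconds);
 * }</pre>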
 */
public class SwerveDrivePoseEstimator {
  private final SwerveDriveKinematics m_kinematics;
  private final SwerveDriveOdometry m_odometry;
  private final Matrix<N3, N1> m_q = new Matrix<>(Nat.N3(), Nat.N1());
  private final int m_numModules;
  private final Matrix<N3, N3> m_visionK = new Matrix<>(Nat.N3(), Nat.N3());

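  // Length of the pose buffer's history, in seconds.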
  private static final double kBufferDuration = 1.5;

  private final TimeInterpolatableBuffer<InterpolationRecord> m_poseBuffer =
      TimeInterpolatableBuffer.createBuffer(kBufferDuration);

  /**
   * Constructs a SwerveDrivePoseEstimator with default standard deviations for the model and vision
   * measurements.
   *
   * <p>The default standard deviations of the model states are 0.1 meters for x, 0.1 meters for y,
   * and 0.1 radians for heading. The default standard deviations of the vision measurements are 0.9
   * meters for x, 0.9 meters for y, and 0.9 radians for heading.
   *
   * @param kinematics A correctly-configured kinematics object for your drivetrain.
   * @param gyroAngle The current gyro angle.
   * @param modulePositions The current distance measurements and rotations of the swerve modules.
   * @param initialPoseMeters The starting pose estimate.
   */
  public SwerveDrivePoseEstimator(
      SwerveDriveKinematics kinematics,
      Rotation2d gyroAngle,
      SwerveModulePosition[] modulePositions,
      Pose2d initialPoseMeters) {
    this(
        kinematics,
        gyroAngle,
        modulePositions,
        initialPoseMeters,
        VecBuilder.fill(0.1, 0.1, 0.1),
        VecBuilder.fill(0.9, 0.9, 0.9));
  }

  /**
   * Constructs a SwerveDrivePoseEstimator.
   *
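   * <p>For example, to trust the odometry model more and vision measurements less than the
   * defaults (the values below are illustrative starting points, not tuned constants):
   *
   * <pre>{@code
   * var estimator =
   *     new SwerveDrivePoseEstimator(
   *         m_kinematics,
   *         m_gyro.getRotation2d(),
   *         getModulePositions(),
   *         new Pose2d(),
   *         VecBuilder.fill(0.05, 0.05, 0.01),
   *         VecBuilder.fill(0.5, 0.5, 0.5));
   * }</pre>
   *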
   * @param kinematics A correctly-configured kinematics object for your drivetrain.
   * @param gyroAngle The current gyro angle.
   * @param modulePositions The current distance and rotation measurements of the swerve modules.
   * @param initialPoseMeters The starting pose estimate.
   * @param stateStdDevs Standard deviations of the pose estimate (x position in meters, y position
   *     in meters, and heading in radians). Increase these numbers to trust your state estimate
   *     less.
   * @param visionMeasurementStdDevs Standard deviations of the vision pose measurement (x position
   *     in meters, y position in meters, and heading in radians). Increase these numbers to trust
   *     the vision pose measurement less.
   */
  public SwerveDrivePoseEstimator(
      SwerveDriveKinematics kinematics,
      Rotation2d gyroAngle,
      SwerveModulePosition[] modulePositions,
      Pose2d initialPoseMeters,
      Matrix<N3, N1> stateStdDevs,
      Matrix<N3, N1> visionMeasurementStdDevs) {
    m_kinematics = kinematics;
    m_odometry = new SwerveDriveOdometry(kinematics, gyroAngle, modulePositions, initialPoseMeters);

    for (int i = 0; i < 3; ++i) {
      m_q.set(i, 0, stateStdDevs.get(i, 0) * stateStdDevs.get(i, 0));
    }

    m_numModules = modulePositions.length;

    setVisionMeasurementStdDevs(visionMeasurementStdDevs);
  }

  /**
   * Sets the pose estimator's trust of global measurements. This might be used to change trust in
   * vision measurements after the autonomous period, or to change trust as distance to a vision
   * target increases.
   *
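   * <p>For example, to trust vision less as the distance to a target grows (the linear scaling and
   * the distance variable below are a hypothetical heuristic, not a tuned rule):
   *
   * <pre>{@code
   * double d = distanceToTargetMeters;
   * estimator.setVisionMeasurementStdDevs(VecBuilder.fill(0.1 * d, 0.1 * d, 0.1 * d));
   * }</pre>
   *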
   * @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase these
   *     numbers to trust global measurements from vision less. This matrix is in the form [x, y,
   *     theta]ᵀ, with units in meters and radians.
   */
  public void setVisionMeasurementStdDevs(Matrix<N3, N1> visionMeasurementStdDevs) {
    var r = new double[3];
    for (int i = 0; i < 3; ++i) {
      r[i] = visionMeasurementStdDevs.get(i, 0) * visionMeasurementStdDevs.get(i, 0);
    }

    // Solve for closed form Kalman gain for continuous Kalman filter with A = 0
    // and C = I. See wpimath/algorithms.md.
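    // Each diagonal gain is k = q / (q + sqrt(q * r)): as r approaches zero the gain approaches 1
    // (trust vision fully), and as r grows the gain falls toward 0 (ignore vision).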
    for (int row = 0; row < 3; ++row) {
      if (m_q.get(row, 0) == 0.0) {
        m_visionK.set(row, row, 0.0);
      } else {
        m_visionK.set(
            row, row, m_q.get(row, 0) / (m_q.get(row, 0) + Math.sqrt(m_q.get(row, 0) * r[row])));
      }
    }
  }

  /**
   * Resets the robot's position on the field.
   *
   * <p>The gyroscope angle does not need to be reset in the user's robot code. The library
   * automatically takes care of offsetting the gyro angle.
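   *
   * <p>For example, to seed the estimator with a known starting pose at the beginning of the
   * autonomous period (the pose below is illustrative):
   *
   * <pre>{@code
   * estimator.resetPosition(
   *     m_gyro.getRotation2d(), getModulePositions(), new Pose2d(1.5, 5.0, new Rotation2d()));
   * }</pre>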
   *
   * @param gyroAngle The angle reported by the gyroscope.
   * @param modulePositions The current distance measurements and rotations of the swerve modules.
   * @param poseMeters The position on the field that your robot is at.
   */
  public void resetPosition(
      Rotation2d gyroAngle, SwerveModulePosition[] modulePositions, Pose2d poseMeters) {
    // Reset the state estimate and clear the sample history
    m_odometry.resetPosition(gyroAngle, modulePositions, poseMeters);
    m_poseBuffer.clear();
  }

  /**
   * Gets the estimated robot pose.
   *
   * @return The estimated robot pose in meters.
   */
  public Pose2d getEstimatedPosition() {
    return m_odometry.getPoseMeters();
  }

  /**
   * Adds a vision measurement to the Kalman Filter. This will correct the odometry pose estimate
   * while still accounting for measurement noise.
   *
   * <p>This method can be called as infrequently as you want, as long as you are calling {@link
   * SwerveDrivePoseEstimator#update} every loop.
   *
   * <p>To promote stability of the pose estimate and make it robust to bad vision data, we
   * recommend only adding vision measurements that are already within one meter or so of the
   * current pose estimate.
   *
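   * <p>For example, with a camera pipeline that reports its own capture latency (the latency
   * variable below is a placeholder for whatever your vision system provides):
   *
   * <pre>{@code
   * estimator.addVisionMeasurement(
   *     visionPose, Timer.getFPGATimestamp() - pipelineLatencySeconds);
   * }</pre>
   *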
   * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
   * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
   *     don't use your own time source by calling {@link
   *     SwerveDrivePoseEstimator#updateWithTime(double,Rotation2d,SwerveModulePosition[])}, then
   *     you must use a timestamp with an epoch since FPGA startup (i.e., the epoch of this
   *     timestamp is the same epoch as {@link edu.wpi.first.wpilibj.Timer#getFPGATimestamp()}).
   *     This means that you should use {@link edu.wpi.first.wpilibj.Timer#getFPGATimestamp()} as
   *     your time source or sync the epochs.
   */
  public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
    // Step 0: If this measurement is old enough to be outside the pose buffer's timespan, skip.
    try {
      if (m_poseBuffer.getInternalBuffer().lastKey() - kBufferDuration > timestampSeconds) {
        return;
      }
    } catch (NoSuchElementException ex) {
      return;
    }

    // Step 1: Get the pose odometry measured at the moment the vision measurement was made.
    var sample = m_poseBuffer.getSample(timestampSeconds);

    if (sample.isEmpty()) {
      return;
    }

    // Step 2: Measure the twist between the odometry pose and the vision pose.
    var twist = sample.get().poseMeters.log(visionRobotPoseMeters);

    // Step 3: We should not trust the twist entirely, so instead we scale this twist by a Kalman
    // gain matrix representing how much we trust vision measurements compared to our current pose.
    var kTimesTwist = m_visionK.times(VecBuilder.fill(twist.dx, twist.dy, twist.dtheta));

    // Step 4: Convert back to Twist2d.
    var scaledTwist =
        new Twist2d(kTimesTwist.get(0, 0), kTimesTwist.get(1, 0), kTimesTwist.get(2, 0));

    // Step 5: Reset Odometry to state at sample with vision adjustment.
    m_odometry.resetPosition(
        sample.get().gyroAngle,
        sample.get().modulePositions,
        sample.get().poseMeters.exp(scaledTwist));

    // Step 6: Record the current pose to allow multiple measurements from the same timestamp.
    m_poseBuffer.addSample(
        timestampSeconds,
        new InterpolationRecord(
            getEstimatedPosition(), sample.get().gyroAngle, sample.get().modulePositions));

    // Step 7: Replay odometry inputs between sample time and latest recorded sample to update the
    // pose buffer and correct odometry.
    for (Map.Entry<Double, InterpolationRecord> entry :
        m_poseBuffer.getInternalBuffer().tailMap(timestampSeconds).entrySet()) {
      updateWithTime(entry.getKey(), entry.getValue().gyroAngle, entry.getValue().modulePositions);
    }
  }

  /**
   * Adds a vision measurement to the Kalman Filter. This will correct the odometry pose estimate
   * while still accounting for measurement noise.
   *
   * <p>This method can be called as infrequently as you want, as long as you are calling {@link
   * SwerveDrivePoseEstimator#update} every loop.
   *
   * <p>To promote stability of the pose estimate and make it robust to bad vision data, we
   * recommend only adding vision measurements that are already within one meter or so of the
   * current pose estimate.
   *
   * <p>Note that the vision measurement standard deviations passed into this method will continue
   * to apply to future measurements until a subsequent call to {@link
   * SwerveDrivePoseEstimator#setVisionMeasurementStdDevs(Matrix)} or this method.
   *
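   * <p>For example, to weight a single measurement by its own uncertainty estimate (the standard
   * deviations below are illustrative):
   *
   * <pre>{@code
   * estimator.addVisionMeasurement(
   *     visionPose, visionTimestampSeconds, VecBuilder.fill(0.5, 0.5, 0.5));
   * }</pre>
   *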
   * @param visionRobotPoseMeters The pose of the robot as measured by the vision camera.
   * @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if you
   *     don't use your own time source by calling {@link
   *     SwerveDrivePoseEstimator#updateWithTime(double,Rotation2d,SwerveModulePosition[])}, then
   *     you must use a timestamp with an epoch since FPGA startup (i.e., the epoch of this
   *     timestamp is the same epoch as {@link edu.wpi.first.wpilibj.Timer#getFPGATimestamp()}).
   *     This means that you should use {@link edu.wpi.first.wpilibj.Timer#getFPGATimestamp()} as
   *     your time source in this case.
   * @param visionMeasurementStdDevs Standard deviations of the vision pose measurement (x position
   *     in meters, y position in meters, and heading in radians). Increase these numbers to trust
   *     the vision pose measurement less.
   */
  public void addVisionMeasurement(
      Pose2d visionRobotPoseMeters,
      double timestampSeconds,
      Matrix<N3, N1> visionMeasurementStdDevs) {
    setVisionMeasurementStdDevs(visionMeasurementStdDevs);
    addVisionMeasurement(visionRobotPoseMeters, timestampSeconds);
  }

  /**
   * Updates the pose estimator with wheel encoder and gyro information. This should be called every
   * loop.
   *
   * @param gyroAngle The current gyro angle.
   * @param modulePositions The current distance measurements and rotations of the swerve modules.
   * @return The estimated pose of the robot in meters.
   */
  public Pose2d update(Rotation2d gyroAngle, SwerveModulePosition[] modulePositions) {
    return updateWithTime(MathSharedStore.getTimestamp(), gyroAngle, modulePositions);
  }

  /**
   * Updates the pose estimator with wheel encoder and gyro information. This should be called every
   * loop.
   *
   * @param currentTimeSeconds Time at which this method was called, in seconds.
   * @param gyroAngle The current gyroscope angle.
   * @param modulePositions The current distance measurements and rotations of the swerve modules.
   * @return The estimated pose of the robot in meters.
   */
  public Pose2d updateWithTime(
      double currentTimeSeconds, Rotation2d gyroAngle, SwerveModulePosition[] modulePositions) {
    if (modulePositions.length != m_numModules) {
      throw new IllegalArgumentException(
          "Number of modules is not consistent with number of wheel locations provided in "
              + "constructor");
    }

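    // Defensively copy the module positions so later mutation by the caller can't corrupt the
    // records stored in the pose buffer.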
    var internalModulePositions = new SwerveModulePosition[m_numModules];

    for (int i = 0; i < m_numModules; i++) {
      internalModulePositions[i] =
          new SwerveModulePosition(modulePositions[i].distanceMeters, modulePositions[i].angle);
    }

    m_odometry.update(gyroAngle, internalModulePositions);

    m_poseBuffer.addSample(
        currentTimeSeconds,
        new InterpolationRecord(getEstimatedPosition(), gyroAngle, internalModulePositions));

    return getEstimatedPosition();
  }

  /**
   * Represents an odometry record. The record contains the inputs provided as well as the pose
   * that was observed based on these inputs.
   */
  private class InterpolationRecord implements Interpolatable<InterpolationRecord> {
    // The pose observed given the current sensor inputs and the previous pose.
    private final Pose2d poseMeters;

    // The current gyro angle.
    private final Rotation2d gyroAngle;

    // The distances and rotations measured at each module.
    private final SwerveModulePosition[] modulePositions;

    /**
     * Constructs an InterpolationRecord with the specified parameters.
     *
     * @param poseMeters The pose observed given the current sensor inputs and the previous pose.
     * @param gyro The current gyro angle.
     * @param modulePositions The distances and rotations measured at each module.
     */
    private InterpolationRecord(
        Pose2d poseMeters, Rotation2d gyro, SwerveModulePosition[] modulePositions) {
      this.poseMeters = poseMeters;
      this.gyroAngle = gyro;
      this.modulePositions = modulePositions;
    }

    /**
     * Returns the interpolated record. This object is assumed to be the starting position, or
     * lower bound.
     *
     * @param endValue The upper bound, or end.
     * @param t How far between the lower and upper bound we are. This should be bounded in [0, 1].
     * @return The interpolated value.
     */
    @Override
    public InterpolationRecord interpolate(InterpolationRecord endValue, double t) {
      if (t < 0) {
        return this;
      } else if (t >= 1) {
        return endValue;
      } else {
        // Find the new wheel distances.
        var modulePositions = new SwerveModulePosition[m_numModules];

        // Find the distance traveled between this measurement and the interpolated measurement.
        var moduleDeltas = new SwerveModulePosition[m_numModules];

        for (int i = 0; i < m_numModules; i++) {
          double ds =
              MathUtil.interpolate(
                  this.modulePositions[i].distanceMeters,
                  endValue.modulePositions[i].distanceMeters,
                  t);
          Rotation2d theta =
              this.modulePositions[i].angle.interpolate(endValue.modulePositions[i].angle, t);
          modulePositions[i] = new SwerveModulePosition(ds, theta);
          moduleDeltas[i] =
              new SwerveModulePosition(ds - this.modulePositions[i].distanceMeters, theta);
        }

        // Find the new gyro angle.
        var gyroLerp = gyroAngle.interpolate(endValue.gyroAngle, t);

        // Create a twist to represent this change based on the interpolated sensor inputs.
        Twist2d twist = m_kinematics.toTwist2d(moduleDeltas);
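        // Use the gyro for the heading change; it is more accurate than the wheel-derived dtheta.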
        twist.dtheta = gyroLerp.minus(gyroAngle).getRadians();

        return new InterpolationRecord(poseMeters.exp(twist), gyroLerp, modulePositions);
      }
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }
      if (!(obj instanceof InterpolationRecord)) {
        return false;
      }
      InterpolationRecord record = (InterpolationRecord) obj;
      return Objects.equals(gyroAngle, record.gyroAngle)
          && Arrays.equals(modulePositions, record.modulePositions)
          && Objects.equals(poseMeters, record.poseMeters);
    }

    @Override
    public int hashCode() {
      return Objects.hash(gyroAngle, Arrays.hashCode(modulePositions), poseMeters);
    }
  }
}