001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.trajectory;
006
007import com.fasterxml.jackson.annotation.JsonProperty;
008import edu.wpi.first.math.geometry.Pose2d;
009import edu.wpi.first.math.geometry.Transform2d;
010import java.util.ArrayList;
011import java.util.List;
012import java.util.Objects;
013import java.util.stream.Collectors;
014
015/**
016 * Represents a time-parameterized trajectory. The trajectory contains of various States that
017 * represent the pose, curvature, time elapsed, velocity, and acceleration at that point.
018 */
019public class Trajectory {
020  private final double m_totalTimeSeconds;
021  private final List<State> m_states;
022
023  /** Constructs an empty trajectory. */
024  public Trajectory() {
025    m_states = new ArrayList<>();
026    m_totalTimeSeconds = 0.0;
027  }
028
029  /**
030   * Constructs a trajectory from a vector of states.
031   *
032   * @param states A vector of states.
033   */
034  public Trajectory(final List<State> states) {
035    m_states = states;
036    m_totalTimeSeconds = m_states.get(m_states.size() - 1).timeSeconds;
037  }
038
039  /**
040   * Linearly interpolates between two values.
041   *
042   * @param startValue The start value.
043   * @param endValue The end value.
044   * @param t The fraction for interpolation.
045   * @return The interpolated value.
046   */
047  private static double lerp(double startValue, double endValue, double t) {
048    return startValue + (endValue - startValue) * t;
049  }
050
051  /**
052   * Linearly interpolates between two poses.
053   *
054   * @param startValue The start pose.
055   * @param endValue The end pose.
056   * @param t The fraction for interpolation.
057   * @return The interpolated pose.
058   */
059  private static Pose2d lerp(Pose2d startValue, Pose2d endValue, double t) {
060    return startValue.plus((endValue.minus(startValue)).times(t));
061  }
062
063  /**
064   * Returns the initial pose of the trajectory.
065   *
066   * @return The initial pose of the trajectory.
067   */
068  public Pose2d getInitialPose() {
069    return sample(0).poseMeters;
070  }
071
072  /**
073   * Returns the overall duration of the trajectory.
074   *
075   * @return The duration of the trajectory.
076   */
077  public double getTotalTimeSeconds() {
078    return m_totalTimeSeconds;
079  }
080
081  /**
082   * Return the states of the trajectory.
083   *
084   * @return The states of the trajectory.
085   */
086  public List<State> getStates() {
087    return m_states;
088  }
089
090  /**
091   * Sample the trajectory at a point in time.
092   *
093   * @param timeSeconds The point in time since the beginning of the trajectory to sample.
094   * @return The state at that point in time.
095   */
096  public State sample(double timeSeconds) {
097    if (timeSeconds <= m_states.get(0).timeSeconds) {
098      return m_states.get(0);
099    }
100    if (timeSeconds >= m_totalTimeSeconds) {
101      return m_states.get(m_states.size() - 1);
102    }
103
104    // To get the element that we want, we will use a binary search algorithm
105    // instead of iterating over a for-loop. A binary search is O(std::log(n))
106    // whereas searching using a loop is O(n).
107
108    // This starts at 1 because we use the previous state later on for
109    // interpolation.
110    int low = 1;
111    int high = m_states.size() - 1;
112
113    while (low != high) {
114      int mid = (low + high) / 2;
115      if (m_states.get(mid).timeSeconds < timeSeconds) {
116        // This index and everything under it are less than the requested
117        // timestamp. Therefore, we can discard them.
118        low = mid + 1;
119      } else {
120        // t is at least as large as the element at this index. This means that
121        // anything after it cannot be what we are looking for.
122        high = mid;
123      }
124    }
125
126    // High and Low should be the same.
127
128    // The sample's timestamp is now greater than or equal to the requested
129    // timestamp. If it is greater, we need to interpolate between the
130    // previous state and the current state to get the exact state that we
131    // want.
132    final State sample = m_states.get(low);
133    final State prevSample = m_states.get(low - 1);
134
135    // If the difference in states is negligible, then we are spot on!
136    if (Math.abs(sample.timeSeconds - prevSample.timeSeconds) < 1E-9) {
137      return sample;
138    }
139    // Interpolate between the two states for the state that we want.
140    return prevSample.interpolate(
141        sample,
142        (timeSeconds - prevSample.timeSeconds) / (sample.timeSeconds - prevSample.timeSeconds));
143  }
144
145  /**
146   * Transforms all poses in the trajectory by the given transform. This is useful for converting a
147   * robot-relative trajectory into a field-relative trajectory. This works with respect to the
148   * first pose in the trajectory.
149   *
150   * @param transform The transform to transform the trajectory by.
151   * @return The transformed trajectory.
152   */
153  public Trajectory transformBy(Transform2d transform) {
154    var firstState = m_states.get(0);
155    var firstPose = firstState.poseMeters;
156
157    // Calculate the transformed first pose.
158    var newFirstPose = firstPose.plus(transform);
159    List<State> newStates = new ArrayList<>();
160
161    newStates.add(
162        new State(
163            firstState.timeSeconds,
164            firstState.velocityMetersPerSecond,
165            firstState.accelerationMetersPerSecondSq,
166            newFirstPose,
167            firstState.curvatureRadPerMeter));
168
169    for (int i = 1; i < m_states.size(); i++) {
170      var state = m_states.get(i);
171      // We are transforming relative to the coordinate frame of the new initial pose.
172      newStates.add(
173          new State(
174              state.timeSeconds,
175              state.velocityMetersPerSecond,
176              state.accelerationMetersPerSecondSq,
177              newFirstPose.plus(state.poseMeters.minus(firstPose)),
178              state.curvatureRadPerMeter));
179    }
180
181    return new Trajectory(newStates);
182  }
183
184  /**
185   * Transforms all poses in the trajectory so that they are relative to the given pose. This is
186   * useful for converting a field-relative trajectory into a robot-relative trajectory.
187   *
188   * @param pose The pose that is the origin of the coordinate frame that the current trajectory
189   *     will be transformed into.
190   * @return The transformed trajectory.
191   */
192  public Trajectory relativeTo(Pose2d pose) {
193    return new Trajectory(
194        m_states.stream()
195            .map(
196                state ->
197                    new State(
198                        state.timeSeconds,
199                        state.velocityMetersPerSecond,
200                        state.accelerationMetersPerSecondSq,
201                        state.poseMeters.relativeTo(pose),
202                        state.curvatureRadPerMeter))
203            .collect(Collectors.toList()));
204  }
205
206  /**
207   * Concatenates another trajectory to the current trajectory. The user is responsible for making
208   * sure that the end pose of this trajectory and the start pose of the other trajectory match (if
209   * that is the desired behavior).
210   *
211   * @param other The trajectory to concatenate.
212   * @return The concatenated trajectory.
213   */
214  public Trajectory concatenate(Trajectory other) {
215    // If this is a default constructed trajectory with no states, then we can
216    // simply return the rhs trajectory.
217    if (m_states.isEmpty()) {
218      return other;
219    }
220
221    // Deep copy the current states.
222    List<State> states =
223        m_states.stream()
224            .map(
225                state ->
226                    new State(
227                        state.timeSeconds,
228                        state.velocityMetersPerSecond,
229                        state.accelerationMetersPerSecondSq,
230                        state.poseMeters,
231                        state.curvatureRadPerMeter))
232            .collect(Collectors.toList());
233
234    // Here we omit the first state of the other trajectory because we don't want
235    // two time points with different states. Sample() will automatically
236    // interpolate between the end of this trajectory and the second state of the
237    // other trajectory.
238    for (int i = 1; i < other.getStates().size(); ++i) {
239      var s = other.getStates().get(i);
240      states.add(
241          new State(
242              s.timeSeconds + m_totalTimeSeconds,
243              s.velocityMetersPerSecond,
244              s.accelerationMetersPerSecondSq,
245              s.poseMeters,
246              s.curvatureRadPerMeter));
247    }
248    return new Trajectory(states);
249  }
250
251  /**
252   * Represents a time-parameterized trajectory. The trajectory contains of various States that
253   * represent the pose, curvature, time elapsed, velocity, and acceleration at that point.
254   */
255  public static class State {
256    // The time elapsed since the beginning of the trajectory.
257    @JsonProperty("time")
258    public double timeSeconds;
259
260    // The speed at that point of the trajectory.
261    @JsonProperty("velocity")
262    public double velocityMetersPerSecond;
263
264    // The acceleration at that point of the trajectory.
265    @JsonProperty("acceleration")
266    public double accelerationMetersPerSecondSq;
267
268    // The pose at that point of the trajectory.
269    @JsonProperty("pose")
270    public Pose2d poseMeters;
271
272    // The curvature at that point of the trajectory.
273    @JsonProperty("curvature")
274    public double curvatureRadPerMeter;
275
276    public State() {
277      poseMeters = new Pose2d();
278    }
279
280    /**
281     * Constructs a State with the specified parameters.
282     *
283     * @param timeSeconds The time elapsed since the beginning of the trajectory.
284     * @param velocityMetersPerSecond The speed at that point of the trajectory.
285     * @param accelerationMetersPerSecondSq The acceleration at that point of the trajectory.
286     * @param poseMeters The pose at that point of the trajectory.
287     * @param curvatureRadPerMeter The curvature at that point of the trajectory.
288     */
289    public State(
290        double timeSeconds,
291        double velocityMetersPerSecond,
292        double accelerationMetersPerSecondSq,
293        Pose2d poseMeters,
294        double curvatureRadPerMeter) {
295      this.timeSeconds = timeSeconds;
296      this.velocityMetersPerSecond = velocityMetersPerSecond;
297      this.accelerationMetersPerSecondSq = accelerationMetersPerSecondSq;
298      this.poseMeters = poseMeters;
299      this.curvatureRadPerMeter = curvatureRadPerMeter;
300    }
301
302    /**
303     * Interpolates between two States.
304     *
305     * @param endValue The end value for the interpolation.
306     * @param i The interpolant (fraction).
307     * @return The interpolated state.
308     */
309    State interpolate(State endValue, double i) {
310      // Find the new t value.
311      final double newT = lerp(timeSeconds, endValue.timeSeconds, i);
312
313      // Find the delta time between the current state and the interpolated state.
314      final double deltaT = newT - timeSeconds;
315
316      // If delta time is negative, flip the order of interpolation.
317      if (deltaT < 0) {
318        return endValue.interpolate(this, 1 - i);
319      }
320
321      // Check whether the robot is reversing at this stage.
322      final boolean reversing =
323          velocityMetersPerSecond < 0
324              || Math.abs(velocityMetersPerSecond) < 1E-9 && accelerationMetersPerSecondSq < 0;
325
326      // Calculate the new velocity
327      // v_f = v_0 + at
328      final double newV = velocityMetersPerSecond + (accelerationMetersPerSecondSq * deltaT);
329
330      // Calculate the change in position.
331      // delta_s = v_0 t + 0.5at²
332      final double newS =
333          (velocityMetersPerSecond * deltaT
334                  + 0.5 * accelerationMetersPerSecondSq * Math.pow(deltaT, 2))
335              * (reversing ? -1.0 : 1.0);
336
337      // Return the new state. To find the new position for the new state, we need
338      // to interpolate between the two endpoint poses. The fraction for
339      // interpolation is the change in position (delta s) divided by the total
340      // distance between the two endpoints.
341      final double interpolationFrac =
342          newS / endValue.poseMeters.getTranslation().getDistance(poseMeters.getTranslation());
343
344      return new State(
345          newT,
346          newV,
347          accelerationMetersPerSecondSq,
348          lerp(poseMeters, endValue.poseMeters, interpolationFrac),
349          lerp(curvatureRadPerMeter, endValue.curvatureRadPerMeter, interpolationFrac));
350    }
351
352    @Override
353    public String toString() {
354      return String.format(
355          "State(Sec: %.2f, Vel m/s: %.2f, Accel m/s/s: %.2f, Pose: %s, Curvature: %.2f)",
356          timeSeconds,
357          velocityMetersPerSecond,
358          accelerationMetersPerSecondSq,
359          poseMeters,
360          curvatureRadPerMeter);
361    }
362
363    @Override
364    public boolean equals(Object obj) {
365      if (this == obj) {
366        return true;
367      }
368      if (!(obj instanceof State)) {
369        return false;
370      }
371      State state = (State) obj;
372      return Double.compare(state.timeSeconds, timeSeconds) == 0
373          && Double.compare(state.velocityMetersPerSecond, velocityMetersPerSecond) == 0
374          && Double.compare(state.accelerationMetersPerSecondSq, accelerationMetersPerSecondSq) == 0
375          && Double.compare(state.curvatureRadPerMeter, curvatureRadPerMeter) == 0
376          && Objects.equals(poseMeters, state.poseMeters);
377    }
378
379    @Override
380    public int hashCode() {
381      return Objects.hash(
382          timeSeconds,
383          velocityMetersPerSecond,
384          accelerationMetersPerSecondSq,
385          poseMeters,
386          curvatureRadPerMeter);
387    }
388  }
389
390  @Override
391  public String toString() {
392    String stateList = m_states.stream().map(State::toString).collect(Collectors.joining(", \n"));
393    return String.format("Trajectory - Seconds: %.2f, States:\n%s", m_totalTimeSeconds, stateList);
394  }
395
396  @Override
397  public int hashCode() {
398    return m_states.hashCode();
399  }
400
401  @Override
402  public boolean equals(Object obj) {
403    return obj instanceof Trajectory && m_states.equals(((Trajectory) obj).getStates());
404  }
405}