001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.spline;
006
007import org.ejml.simple.SimpleMatrix;
008
009public class CubicHermiteSpline extends Spline {
010  private static SimpleMatrix hermiteBasis;
011  private final SimpleMatrix m_coefficients;
012
013  /**
014   * Constructs a cubic hermite spline with the specified control vectors. Each control vector
015   * contains info about the location of the point and its first derivative.
016   *
017   * @param xInitialControlVector The control vector for the initial point in the x dimension.
018   * @param xFinalControlVector The control vector for the final point in the x dimension.
019   * @param yInitialControlVector The control vector for the initial point in the y dimension.
020   * @param yFinalControlVector The control vector for the final point in the y dimension.
021   */
022  public CubicHermiteSpline(
023      double[] xInitialControlVector,
024      double[] xFinalControlVector,
025      double[] yInitialControlVector,
026      double[] yFinalControlVector) {
027    super(3);
028
029    // Populate the coefficients for the actual spline equations.
030    // Row 0 is x coefficients
031    // Row 1 is y coefficients
032    final var hermite = makeHermiteBasis();
033    final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
034    final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
035
036    final var xCoeffs = (hermite.mult(x)).transpose();
037    final var yCoeffs = (hermite.mult(y)).transpose();
038
039    m_coefficients = new SimpleMatrix(6, 4);
040
041    for (int i = 0; i < 4; i++) {
042      m_coefficients.set(0, i, xCoeffs.get(0, i));
043      m_coefficients.set(1, i, yCoeffs.get(0, i));
044
045      // Populate Row 2 and Row 3 with the derivatives of the equations above.
046      // Then populate row 4 and 5 with the second derivatives.
047      // Here, we are multiplying by (3 - i) to manually take the derivative. The
048      // power of the term in index 0 is 3, index 1 is 2 and so on. To find the
049      // coefficient of the derivative, we can use the power rule and multiply
050      // the existing coefficient by its power.
051      m_coefficients.set(2, i, m_coefficients.get(0, i) * (3 - i));
052      m_coefficients.set(3, i, m_coefficients.get(1, i) * (3 - i));
053    }
054
055    for (int i = 0; i < 3; i++) {
056      // Here, we are multiplying by (2 - i) to manually take the derivative. The
057      // power of the term in index 0 is 2, index 1 is 1 and so on. To find the
058      // coefficient of the derivative, we can use the power rule and multiply
059      // the existing coefficient by its power.
060      m_coefficients.set(4, i, m_coefficients.get(2, i) * (2 - i));
061      m_coefficients.set(5, i, m_coefficients.get(3, i) * (2 - i));
062    }
063  }
064
065  /**
066   * Returns the coefficients matrix.
067   *
068   * @return The coefficients matrix.
069   */
070  @Override
071  protected SimpleMatrix getCoefficients() {
072    return m_coefficients;
073  }
074
075  /**
076   * Returns the hermite basis matrix for cubic hermite spline interpolation.
077   *
078   * @return The hermite basis matrix for cubic hermite spline interpolation.
079   */
080  private SimpleMatrix makeHermiteBasis() {
081    if (hermiteBasis == null) {
082      // Given P(i), P'(i), P(i+1), P'(i+1), the control vectors, we want to find
083      // the coefficients of the spline P(t) = a₃t³ + a₂t² + a₁t + a₀.
084      //
085      // P(i)    = P(0)  = a₀
086      // P'(i)   = P'(0) = a₁
087      // P(i+1)  = P(1)  = a₃ + a₂ + a₁ + a₀
088      // P'(i+1) = P'(1) = 3a₃ + 2a₂ + a₁
089      //
090      // [P(i)   ] = [0 0 0 1][a₃]
091      // [P'(i)  ] = [0 0 1 0][a₂]
092      // [P(i+1) ] = [1 1 1 1][a₁]
093      // [P'(i+1)] = [3 2 1 0][a₀]
094      //
095      // To solve for the coefficients, we can invert the 4x4 matrix and move it
096      // to the other side of the equation.
097      //
098      // [a₃] = [ 2  1 -2  1][P(i)   ]
099      // [a₂] = [-3 -2  3 -1][P'(i)  ]
100      // [a₁] = [ 0  1  0  0][P(i+1) ]
101      // [a₀] = [ 1  0  0  0][P'(i+1)]
102      hermiteBasis =
103          new SimpleMatrix(
104              4,
105              4,
106              true,
107              new double[] {
108                +2.0, +1.0, -2.0, +1.0, -3.0, -2.0, +3.0, -1.0, +0.0, +1.0, +0.0, +0.0, +1.0, +0.0,
109                +0.0, +0.0
110              });
111    }
112    return hermiteBasis;
113  }
114
115  /**
116   * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
117   * constructor.
118   *
119   * @param initialVector The control vector for the initial point.
120   * @param finalVector The control vector for the final point.
121   * @return The control vector matrix for a dimension.
122   */
123  private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
124    if (initialVector.length != 2 || finalVector.length != 2) {
125      throw new IllegalArgumentException("Size of vectors must be 2");
126    }
127    return new SimpleMatrix(
128        4,
129        1,
130        true,
131        new double[] {
132          initialVector[0], initialVector[1],
133          finalVector[0], finalVector[1]
134        });
135  }
136}