001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.math.spline;
006
007import org.ejml.simple.SimpleMatrix;
008
009public class QuinticHermiteSpline extends Spline {
010  private static SimpleMatrix hermiteBasis;
011  private final SimpleMatrix m_coefficients;
012
013  /**
014   * Constructs a quintic hermite spline with the specified control vectors. Each control vector
015   * contains into about the location of the point, its first derivative, and its second derivative.
016   *
017   * @param xInitialControlVector The control vector for the initial point in the x dimension.
018   * @param xFinalControlVector The control vector for the final point in the x dimension.
019   * @param yInitialControlVector The control vector for the initial point in the y dimension.
020   * @param yFinalControlVector The control vector for the final point in the y dimension.
021   */
022  public QuinticHermiteSpline(
023      double[] xInitialControlVector,
024      double[] xFinalControlVector,
025      double[] yInitialControlVector,
026      double[] yFinalControlVector) {
027    super(5);
028
029    // Populate the coefficients for the actual spline equations.
030    // Row 0 is x coefficients
031    // Row 1 is y coefficients
032    final var hermite = makeHermiteBasis();
033    final var x = getControlVectorFromArrays(xInitialControlVector, xFinalControlVector);
034    final var y = getControlVectorFromArrays(yInitialControlVector, yFinalControlVector);
035
036    final var xCoeffs = (hermite.mult(x)).transpose();
037    final var yCoeffs = (hermite.mult(y)).transpose();
038
039    m_coefficients = new SimpleMatrix(6, 6);
040
041    for (int i = 0; i < 6; i++) {
042      m_coefficients.set(0, i, xCoeffs.get(0, i));
043      m_coefficients.set(1, i, yCoeffs.get(0, i));
044    }
045    for (int i = 0; i < 6; i++) {
046      // Populate Row 2 and Row 3 with the derivatives of the equations above.
047      // Here, we are multiplying by (5 - i) to manually take the derivative. The
048      // power of the term in index 0 is 5, index 1 is 4 and so on. To find the
049      // coefficient of the derivative, we can use the power rule and multiply
050      // the existing coefficient by its power.
051      m_coefficients.set(2, i, m_coefficients.get(0, i) * (5 - i));
052      m_coefficients.set(3, i, m_coefficients.get(1, i) * (5 - i));
053    }
054    for (int i = 0; i < 5; i++) {
055      // Then populate row 4 and 5 with the second derivatives.
056      // Here, we are multiplying by (4 - i) to manually take the derivative. The
057      // power of the term in index 0 is 4, index 1 is 3 and so on. To find the
058      // coefficient of the derivative, we can use the power rule and multiply
059      // the existing coefficient by its power.
060      m_coefficients.set(4, i, m_coefficients.get(2, i) * (4 - i));
061      m_coefficients.set(5, i, m_coefficients.get(3, i) * (4 - i));
062    }
063  }
064
065  /**
066   * Returns the coefficients matrix.
067   *
068   * @return The coefficients matrix.
069   */
070  @Override
071  protected SimpleMatrix getCoefficients() {
072    return m_coefficients;
073  }
074
075  /**
076   * Returns the hermite basis matrix for quintic hermite spline interpolation.
077   *
078   * @return The hermite basis matrix for quintic hermite spline interpolation.
079   */
080  private SimpleMatrix makeHermiteBasis() {
081    if (hermiteBasis == null) {
082      // Given P(i), P'(i), P"(i), P(i+1), P'(i+1), P"(i+1), the control vectors,
083      // we want to find the coefficients of the spline
084      // P(t) = a₅t⁵ + a₄t⁴ + a₃t³ + a₂t² + a₁t + a₀.
085      //
086      // P(i)    = P(0)  = a₀
087      // P'(i)   = P'(0) = a₁
088      // P''(i)  = P"(0) = 2a₂
089      // P(i+1)  = P(1)  = a₅ + a₄ + a₃ + a₂ + a₁ + a₀
090      // P'(i+1) = P'(1) = 5a₅ + 4a₄ + 3a₃ + 2a₂ + a₁
091      // P"(i+1) = P"(1) = 20a₅ + 12a₄ + 6a₃ + 2a₂
092      //
093      // [P(i)   ] = [ 0  0  0  0  0  1][a₅]
094      // [P'(i)  ] = [ 0  0  0  0  1  0][a₄]
095      // [P"(i)  ] = [ 0  0  0  2  0  0][a₃]
096      // [P(i+1) ] = [ 1  1  1  1  1  1][a₂]
097      // [P'(i+1)] = [ 5  4  3  2  1  0][a₁]
098      // [P"(i+1)] = [20 12  6  2  0  0][a₀]
099      //
100      // To solve for the coefficients, we can invert the 6x6 matrix and move it
101      // to the other side of the equation.
102      //
103      // [a₅] = [ -6.0  -3.0  -0.5   6.0  -3.0   0.5][P(i)   ]
104      // [a₄] = [ 15.0   8.0   1.5 -15.0   7.0  -1.0][P'(i)  ]
105      // [a₃] = [-10.0  -6.0  -1.5  10.0  -4.0   0.5][P"(i)  ]
106      // [a₂] = [  0.0   0.0   0.5   0.0   0.0   0.0][P(i+1) ]
107      // [a₁] = [  0.0   1.0   0.0   0.0   0.0   0.0][P'(i+1)]
108      // [a₀] = [  1.0   0.0   0.0   0.0   0.0   0.0][P"(i+1)]
109      hermiteBasis =
110          new SimpleMatrix(
111              6,
112              6,
113              true,
114              new double[] {
115                -06.0, -03.0, -00.5, +06.0, -03.0, +00.5, +15.0, +08.0, +01.5, -15.0, +07.0, -01.0,
116                -10.0, -06.0, -01.5, +10.0, -04.0, +00.5, +00.0, +00.0, +00.5, +00.0, +00.0, +00.0,
117                +00.0, +01.0, +00.0, +00.0, +00.0, +00.0, +01.0, +00.0, +00.0, +00.0, +00.0, +00.0
118              });
119    }
120    return hermiteBasis;
121  }
122
123  /**
124   * Returns the control vector for each dimension as a matrix from the user-provided arrays in the
125   * constructor.
126   *
127   * @param initialVector The control vector for the initial point.
128   * @param finalVector The control vector for the final point.
129   * @return The control vector matrix for a dimension.
130   */
131  private SimpleMatrix getControlVectorFromArrays(double[] initialVector, double[] finalVector) {
132    if (initialVector.length != 3 || finalVector.length != 3) {
133      throw new IllegalArgumentException("Size of vectors must be 3");
134    }
135    return new SimpleMatrix(
136        6,
137        1,
138        true,
139        new double[] {
140          initialVector[0], initialVector[1], initialVector[2],
141          finalVector[0], finalVector[1], finalVector[2]
142        });
143  }
144}