001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.numbers.N1; 011import java.util.ArrayList; 012import java.util.List; 013import java.util.Map; 014import java.util.function.BiConsumer; 015 016public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> { 017 private static final int kMaxPastObserverStates = 300; 018 019 private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots; 020 021 KalmanFilterLatencyCompensator() { 022 m_pastObserverSnapshots = new ArrayList<>(); 023 } 024 025 /** Clears the observer snapshot buffer. */ 026 public void reset() { 027 m_pastObserverSnapshots.clear(); 028 } 029 030 /** 031 * Add past observer states to the observer snapshots list. 032 * 033 * @param observer The observer. 034 * @param u The input at the timestamp. 035 * @param localY The local output at the timestamp 036 * @param timestampSeconds The timestamp of the state. 037 */ 038 public void addObserverState( 039 KalmanTypeFilter<S, I, O> observer, 040 Matrix<I, N1> u, 041 Matrix<O, N1> localY, 042 double timestampSeconds) { 043 m_pastObserverSnapshots.add( 044 Map.entry(timestampSeconds, new ObserverSnapshot(observer, u, localY))); 045 046 if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) { 047 m_pastObserverSnapshots.remove(0); 048 } 049 } 050 051 /** 052 * Add past global measurements (such as from vision)to the estimator. 053 * 054 * @param <R> The rows in the global measurement vector. 055 * @param rows The rows in the global measurement vector. 056 * @param observer The observer to apply the past global measurement. 057 * @param nominalDtSeconds The nominal timestep. 058 * @param y The measurement. 059 * @param globalMeasurementCorrect The function take calls correct() on the observer. 060 * @param timestampSeconds The timestamp of the measurement. 061 */ 062 public <R extends Num> void applyPastGlobalMeasurement( 063 Nat<R> rows, 064 KalmanTypeFilter<S, I, O> observer, 065 double nominalDtSeconds, 066 Matrix<R, N1> y, 067 BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect, 068 double timestampSeconds) { 069 if (m_pastObserverSnapshots.isEmpty()) { 070 // State map was empty, which means that we got a past measurement right at startup. The only 071 // thing we can really do is ignore the measurement. 072 return; 073 } 074 075 // Use a less verbose name for timestamp 076 double timestamp = timestampSeconds; 077 078 int maxIdx = m_pastObserverSnapshots.size() - 1; 079 int low = 0; 080 int high = maxIdx; 081 082 // Perform a binary search to find the index of first snapshot whose 083 // timestamp is greater than or equal to the global measurement timestamp 084 while (low != high) { 085 int mid = (low + high) / 2; 086 if (m_pastObserverSnapshots.get(mid).getKey() < timestamp) { 087 // This index and everything under it are less than the requested timestamp. Therefore, we 088 // can discard them. 089 low = mid + 1; 090 } else { 091 // t is at least as large as the element at this index. This means that anything after it 092 // cannot be what we are looking for. 093 high = mid; 094 } 095 } 096 097 int indexOfClosestEntry; 098 099 if (low == 0) { 100 // If the global measurement is older than any snapshot, throw out the 101 // measurement because there's no state estimate into which to incorporate 102 // the measurement 103 if (timestamp < m_pastObserverSnapshots.get(low).getKey()) { 104 return; 105 } 106 107 // If the first snapshot has same timestamp as the global measurement, use 108 // that snapshot 109 indexOfClosestEntry = 0; 110 } else if (low == maxIdx && m_pastObserverSnapshots.get(low).getKey() < timestamp) { 111 // If all snapshots are older than the global measurement, use the newest 112 // snapshot 113 indexOfClosestEntry = maxIdx; 114 } else { 115 // Index of snapshot taken after the global measurement 116 int nextIdx = low; 117 118 // Index of snapshot taken before the global measurement. Since we already 119 // handled the case where the index points to the first snapshot, this 120 // computation is guaranteed to be non-negative. 121 int prevIdx = nextIdx - 1; 122 123 // Find the snapshot closest in time to global measurement 124 double prevTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(prevIdx).getKey()); 125 double nextTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(nextIdx).getKey()); 126 indexOfClosestEntry = prevTimeDiff <= nextTimeDiff ? prevIdx : nextIdx; 127 } 128 129 double lastTimestamp = 130 m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds; 131 132 // We will now go back in time to the state of the system at the time when 133 // the measurement was captured. We will reset the observer to that state, 134 // and apply correction based on the measurement. Then, we will go back 135 // through all observer states until the present and apply past inputs to 136 // get the present estimated state. 137 for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) { 138 var key = m_pastObserverSnapshots.get(i).getKey(); 139 var snapshot = m_pastObserverSnapshots.get(i).getValue(); 140 141 if (i == indexOfClosestEntry) { 142 observer.setP(snapshot.errorCovariances); 143 observer.setXhat(snapshot.xHat); 144 } 145 146 observer.predict(snapshot.inputs, key - lastTimestamp); 147 observer.correct(snapshot.inputs, snapshot.localMeasurements); 148 149 if (i == indexOfClosestEntry) { 150 // Note that the measurement is at a timestep close but probably not exactly equal to the 151 // timestep for which we called predict. 152 // This makes the assumption that the dt is small enough that the difference between the 153 // measurement time and the time that the inputs were captured at is very small. 154 globalMeasurementCorrect.accept(snapshot.inputs, y); 155 } 156 lastTimestamp = key; 157 158 m_pastObserverSnapshots.set( 159 i, 160 Map.entry( 161 key, new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements))); 162 } 163 } 164 165 /** This class contains all the information about our observer at a given time. */ 166 public class ObserverSnapshot { 167 public final Matrix<S, N1> xHat; 168 public final Matrix<S, S> errorCovariances; 169 public final Matrix<I, N1> inputs; 170 public final Matrix<O, N1> localMeasurements; 171 172 private ObserverSnapshot( 173 KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY) { 174 this.xHat = observer.getXhat(); 175 this.errorCovariances = observer.getP(); 176 177 inputs = u; 178 localMeasurements = localY; 179 } 180 } 181}