WPILibC++ 2023.4.3-108-ge5452e3
CallbackManager.h
Go to the documentation of this file.
1// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#ifndef WPIUTIL_WPI_CALLBACKMANAGER_H_
6#define WPIUTIL_WPI_CALLBACKMANAGER_H_
7
8#include <atomic>
9#include <climits>
10#include <functional>
11#include <memory>
12#include <queue>
13#include <utility>
14#include <vector>
15
16#include "wpi/SafeThread.h"
17#include "wpi/UidVector.h"
19#include "wpi/mutex.h"
20#include "wpi/raw_ostream.h"
21
22namespace wpi {
23
24template <typename Callback>
26 public:
28 explicit CallbackListenerData(Callback callback_) : callback(callback_) {}
29 explicit CallbackListenerData(unsigned int poller_uid_)
30 : poller_uid(poller_uid_) {}
31
32 explicit operator bool() const { return callback || poller_uid != UINT_MAX; }
33
34 Callback callback;
35 unsigned int poller_uid = UINT_MAX;
36};
37
38// CRTP callback manager thread
39// @tparam Derived derived class
40// @tparam NotifierData data buffered for each callback
41// @tparam ListenerData data stored for each listener
42// Derived must define the following functions:
43// bool Matches(const ListenerData& listener, const NotifierData& data);
44// void SetListener(NotifierData* data, unsigned int listener_uid);
45// void DoCallback(Callback callback, const NotifierData& data);
46template <typename Derived, typename TUserInfo,
47 typename TListenerData =
48 CallbackListenerData<std::function<void(const TUserInfo& info)>>,
49 typename TNotifierData = TUserInfo>
51 public:
52 using UserInfo = TUserInfo;
53 using NotifierData = TNotifierData;
54 using ListenerData = TListenerData;
55
56 CallbackThread(std::function<void()> on_start, std::function<void()> on_exit)
57 : m_on_start(std::move(on_start)), m_on_exit(std::move(on_exit)) {}
58
59 ~CallbackThread() override {
60 // Wake up any blocked pollers
61 for (size_t i = 0; i < m_pollers.size(); ++i) {
62 if (auto poller = m_pollers[i]) {
63 poller->Terminate();
64 }
65 }
66 }
67
68 void Main() override;
69
71
72 std::queue<std::pair<unsigned int, NotifierData>> m_queue;
74
75 struct Poller {
76 void Terminate() {
77 {
78 std::scoped_lock lock(poll_mutex);
79 terminating = true;
80 }
81 poll_cond.notify_all();
82 }
83 std::queue<NotifierData> poll_queue;
86 bool terminating = false;
87 bool canceling = false;
88 };
90
91 std::function<void()> m_on_start;
92 std::function<void()> m_on_exit;
93
94 // Must be called with m_mutex held
95 template <typename... Args>
96 void SendPoller(unsigned int poller_uid, Args&&... args) {
97 if (poller_uid > m_pollers.size()) {
98 return;
99 }
100 auto poller = m_pollers[poller_uid];
101 if (!poller) {
102 return;
103 }
104 {
105 std::scoped_lock lock(poller->poll_mutex);
106 poller->poll_queue.emplace(std::forward<Args>(args)...);
107 }
108 poller->poll_cond.notify_one();
109 }
110};
111
112template <typename Derived, typename TUserInfo, typename TListenerData,
113 typename TNotifierData>
115 if (m_on_start) {
116 m_on_start();
117 }
118
119 std::unique_lock lock(m_mutex);
120 while (m_active) {
121 while (m_queue.empty()) {
122 m_cond.wait(lock);
123 if (!m_active) {
124 goto done;
125 }
126 }
127
128 while (!m_queue.empty()) {
129 if (!m_active) {
130 goto done;
131 }
132 auto item = std::move(m_queue.front());
133
134 if (item.first != UINT_MAX) {
135 if (item.first < m_listeners.size()) {
136 auto& listener = m_listeners[item.first];
137 if (listener &&
138 static_cast<Derived*>(this)->Matches(listener, item.second)) {
139 static_cast<Derived*>(this)->SetListener(&item.second, item.first);
140 if (listener.callback) {
141 lock.unlock();
142 static_cast<Derived*>(this)->DoCallback(listener.callback,
143 item.second);
144 lock.lock();
145 } else if (listener.poller_uid != UINT_MAX) {
146 SendPoller(listener.poller_uid, std::move(item.second));
147 }
148 }
149 }
150 } else {
151 // Use index because iterator might get invalidated.
152 for (size_t i = 0; i < m_listeners.size(); ++i) {
153 auto& listener = m_listeners[i];
154 if (!listener) {
155 continue;
156 }
157
158 if (!static_cast<Derived*>(this)->Matches(listener, item.second)) {
159 continue;
160 }
161 static_cast<Derived*>(this)->SetListener(&item.second,
162 static_cast<unsigned>(i));
163 if (listener.callback) {
164 lock.unlock();
165 static_cast<Derived*>(this)->DoCallback(listener.callback,
166 item.second);
167 lock.lock();
168 } else if (listener.poller_uid != UINT_MAX) {
169 SendPoller(listener.poller_uid, item.second);
170 }
171 }
172 }
173 m_queue.pop();
174 }
175
176 m_queue_empty.notify_all();
177 }
178
179done:
180 if (m_on_exit) {
181 m_on_exit();
182 }
183}
184
185// CRTP callback manager
186// @tparam Derived derived class
187// @tparam Thread custom thread (must be derived from impl::CallbackThread)
188//
189// Derived must define the following functions:
190// void Start();
191template <typename Derived, typename Thread>
193 friend class RpcServerTest;
194
195 public:
196 void SetOnStart(std::function<void()> on_start) {
197 m_on_start = std::move(on_start);
198 }
199
200 void SetOnExit(std::function<void()> on_exit) {
201 m_on_exit = std::move(on_exit);
202 }
203
204 void Stop() { m_owner.Stop(); }
205
206 void Remove(unsigned int listener_uid) {
207 auto thr = m_owner.GetThread();
208 if (!thr) {
209 return;
210 }
211 thr->m_listeners.erase(listener_uid);
212 }
213
214 unsigned int CreatePoller() {
215 static_cast<Derived*>(this)->Start();
216 auto thr = m_owner.GetThread();
217 return thr->m_pollers.emplace_back(
218 std::make_shared<typename Thread::Poller>());
219 }
220
221 void RemovePoller(unsigned int poller_uid) {
222 auto thr = m_owner.GetThread();
223 if (!thr) {
224 return;
225 }
226
227 // Remove any listeners that are associated with this poller
228 for (size_t i = 0; i < thr->m_listeners.size(); ++i) {
229 if (thr->m_listeners[i].poller_uid == poller_uid) {
230 thr->m_listeners.erase(i);
231 }
232 }
233
234 // Wake up any blocked pollers
235 if (poller_uid >= thr->m_pollers.size()) {
236 return;
237 }
238 auto poller = thr->m_pollers[poller_uid];
239 if (!poller) {
240 return;
241 }
242 poller->Terminate();
243 thr->m_pollers.erase(poller_uid);
244 }
245
246 bool WaitForQueue(double timeout) {
247 auto thr = m_owner.GetThread();
248 if (!thr) {
249 return true;
250 }
251
252 auto& lock = thr.GetLock();
253 auto timeout_time = std::chrono::steady_clock::now() +
254 std::chrono::duration<double>(timeout);
255 while (!thr->m_queue.empty()) {
256 if (!thr->m_active) {
257 return true;
258 }
259 if (timeout == 0) {
260 return false;
261 }
262 if (timeout < 0) {
263 thr->m_queue_empty.wait(lock);
264 } else {
265 auto cond_timed_out = thr->m_queue_empty.wait_until(lock, timeout_time);
266 if (cond_timed_out == std::cv_status::timeout) {
267 return false;
268 }
269 }
270 }
271
272 return true;
273 }
274
275 std::vector<typename Thread::UserInfo> Poll(unsigned int poller_uid) {
276 bool timed_out = false;
277 return Poll(poller_uid, -1, &timed_out);
278 }
279
280 std::vector<typename Thread::UserInfo> Poll(unsigned int poller_uid,
281 double timeout, bool* timed_out) {
282 std::vector<typename Thread::UserInfo> infos;
283 std::shared_ptr<typename Thread::Poller> poller;
284 {
285 auto thr = m_owner.GetThread();
286 if (!thr) {
287 return infos;
288 }
289 if (poller_uid > thr->m_pollers.size()) {
290 return infos;
291 }
292 poller = thr->m_pollers[poller_uid];
293 if (!poller) {
294 return infos;
295 }
296 }
297
298 std::unique_lock lock(poller->poll_mutex);
299 auto timeout_time = std::chrono::steady_clock::now() +
300 std::chrono::duration<double>(timeout);
301 *timed_out = false;
302 while (poller->poll_queue.empty()) {
303 if (poller->terminating) {
304 return infos;
305 }
306 if (poller->canceling) {
307 // Note: this only works if there's a single thread calling this
308 // function for any particular poller, but that's the intended use.
309 poller->canceling = false;
310 return infos;
311 }
312 if (timeout == 0) {
313 *timed_out = true;
314 return infos;
315 }
316 if (timeout < 0) {
317 poller->poll_cond.wait(lock);
318 } else {
319 auto cond_timed_out = poller->poll_cond.wait_until(lock, timeout_time);
320 if (cond_timed_out == std::cv_status::timeout) {
321 *timed_out = true;
322 return infos;
323 }
324 }
325 }
326
327 while (!poller->poll_queue.empty()) {
328 infos.emplace_back(std::move(poller->poll_queue.front()));
329 poller->poll_queue.pop();
330 }
331 return infos;
332 }
333
334 void CancelPoll(unsigned int poller_uid) {
335 std::shared_ptr<typename Thread::Poller> poller;
336 {
337 auto thr = m_owner.GetThread();
338 if (!thr) {
339 return;
340 }
341 if (poller_uid > thr->m_pollers.size()) {
342 return;
343 }
344 poller = thr->m_pollers[poller_uid];
345 if (!poller) {
346 return;
347 }
348 }
349
350 {
351 std::scoped_lock lock(poller->poll_mutex);
352 poller->canceling = true;
353 }
354 poller->poll_cond.notify_one();
355 }
356
357 protected:
358 template <typename... Args>
359 void DoStart(Args&&... args) {
360 m_owner.Start(m_on_start, m_on_exit, std::forward<Args>(args)...);
361 }
362
363 template <typename... Args>
364 unsigned int DoAdd(Args&&... args) {
365 static_cast<Derived*>(this)->Start();
366 auto thr = m_owner.GetThread();
367 return thr->m_listeners.emplace_back(std::forward<Args>(args)...);
368 }
369
370 template <typename... Args>
371 void Send(unsigned int only_listener, Args&&... args) {
372 auto thr = m_owner.GetThread();
373 if (!thr || thr->m_listeners.empty()) {
374 return;
375 }
376 thr->m_queue.emplace(std::piecewise_construct,
377 std::make_tuple(only_listener),
378 std::forward_as_tuple(std::forward<Args>(args)...));
379 thr->m_cond.notify_one();
380 }
381
383 return m_owner.GetThread();
384 }
385
386 private:
388
389 std::function<void()> m_on_start;
390 std::function<void()> m_on_exit;
391};
392
393} // namespace wpi
394
395#endif // WPIUTIL_WPI_CALLBACKMANAGER_H_
Definition: CallbackManager.h:25
CallbackListenerData(unsigned int poller_uid_)
Definition: CallbackManager.h:29
CallbackListenerData(Callback callback_)
Definition: CallbackManager.h:28
unsigned int poller_uid
Definition: CallbackManager.h:35
Callback callback
Definition: CallbackManager.h:34
Definition: CallbackManager.h:192
wpi::SafeThreadOwner< Thread >::Proxy GetThread() const
Definition: CallbackManager.h:382
void SetOnExit(std::function< void()> on_exit)
Definition: CallbackManager.h:200
void Send(unsigned int only_listener, Args &&... args)
Definition: CallbackManager.h:371
void SetOnStart(std::function< void()> on_start)
Definition: CallbackManager.h:196
bool WaitForQueue(double timeout)
Definition: CallbackManager.h:246
void RemovePoller(unsigned int poller_uid)
Definition: CallbackManager.h:221
void CancelPoll(unsigned int poller_uid)
Definition: CallbackManager.h:334
unsigned int DoAdd(Args &&... args)
Definition: CallbackManager.h:364
friend class RpcServerTest
Definition: CallbackManager.h:193
void Remove(unsigned int listener_uid)
Definition: CallbackManager.h:206
void Stop()
Definition: CallbackManager.h:204
std::vector< typename Thread::UserInfo > Poll(unsigned int poller_uid)
Definition: CallbackManager.h:275
std::vector< typename Thread::UserInfo > Poll(unsigned int poller_uid, double timeout, bool *timed_out)
Definition: CallbackManager.h:280
void DoStart(Args &&... args)
Definition: CallbackManager.h:359
unsigned int CreatePoller()
Definition: CallbackManager.h:214
Definition: CallbackManager.h:50
std::queue< std::pair< unsigned int, NotifierData > > m_queue
Definition: CallbackManager.h:72
TUserInfo UserInfo
Definition: CallbackManager.h:52
std::function< void()> m_on_start
Definition: CallbackManager.h:91
TListenerData ListenerData
Definition: CallbackManager.h:54
CallbackThread(std::function< void()> on_start, std::function< void()> on_exit)
Definition: CallbackManager.h:56
TNotifierData NotifierData
Definition: CallbackManager.h:53
void SendPoller(unsigned int poller_uid, Args &&... args)
Definition: CallbackManager.h:96
~CallbackThread() override
Definition: CallbackManager.h:59
std::function< void()> m_on_exit
Definition: CallbackManager.h:92
void Main() override
Definition: CallbackManager.h:114
wpi::condition_variable m_queue_empty
Definition: CallbackManager.h:73
wpi::UidVector< std::shared_ptr< Poller >, 64 > m_pollers
Definition: CallbackManager.h:89
wpi::UidVector< ListenerData, 64 > m_listeners
Definition: CallbackManager.h:70
Definition: SafeThread.h:33
void Start(Args &&... args)
Definition: SafeThread.h:127
Proxy GetThread() const
Definition: SafeThread.h:133
typename detail::SafeThreadProxy< T > Proxy
Definition: SafeThread.h:132
@ done
Definition: format.h:2566
Definition: BFloat16.h:88
Definition: AprilTagFieldLayout.h:18
::std::mutex mutex
Definition: mutex.h:17
::std::condition_variable condition_variable
Definition: condition_variable.h:16
Definition: CallbackManager.h:75
wpi::mutex poll_mutex
Definition: CallbackManager.h:84
wpi::condition_variable poll_cond
Definition: CallbackManager.h:85
void Terminate()
Definition: CallbackManager.h:76
std::queue< NotifierData > poll_queue
Definition: CallbackManager.h:83
bool canceling
Definition: CallbackManager.h:87
bool terminating
Definition: CallbackManager.h:86