1#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
39#![cfg_attr(docsrs, feature(doc_auto_cfg))]
40#![deny(missing_docs)]
41#![deny(unsafe_code)]
42#![deny(unreachable_pub)]
43
44use std::sync::Arc;
45use std::sync::atomic::{AtomicBool, AtomicUsize};
46
47use tokio_util::sync::CancellationToken;
48
49mod ext;
51
52pub use ext::*;
53
54#[derive(Debug)]
56struct ContextTracker(Arc<ContextTrackerInner>);
57
58impl Drop for ContextTracker {
59 fn drop(&mut self) {
60 let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
61 if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
64 self.0.notify.notify_waiters();
65 }
66 }
67}
68
69#[derive(Debug)]
70struct ContextTrackerInner {
71 stopped: AtomicBool,
72 active_count: AtomicUsize,
75 notify: tokio::sync::Notify,
76}
77
78impl ContextTrackerInner {
79 fn new() -> Arc<Self> {
80 Arc::new(Self {
81 stopped: AtomicBool::new(false),
82 active_count: AtomicUsize::new(0),
83 notify: tokio::sync::Notify::new(),
84 })
85 }
86
87 fn child(self: &Arc<Self>) -> ContextTracker {
89 self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
90 ContextTracker(Arc::clone(self))
91 }
92
93 fn stop(&self) {
95 self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
96 }
97
98 async fn wait(&self) {
101 let notify = self.notify.notified();
102
103 if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
105 return;
106 }
107
108 notify.await;
109 }
110}
111
112#[derive(Debug)]
124pub struct Context {
125 token: CancellationToken,
126 tracker: ContextTracker,
127}
128
129impl Clone for Context {
130 fn clone(&self) -> Self {
131 Self {
132 token: self.token.clone(),
133 tracker: self.tracker.0.child(),
134 }
135 }
136}
137
138impl Context {
139 #[must_use]
140 pub fn new() -> (Self, Handler) {
143 Handler::global().new_child()
144 }
145
146 #[must_use]
147 pub fn new_child(&self) -> (Self, Handler) {
159 let token = self.token.child_token();
160 let tracker = ContextTrackerInner::new();
161
162 (
163 Self {
164 tracker: tracker.child(),
165 token: token.clone(),
166 },
167 Handler {
168 token: Arc::new(TokenDropGuard(token)),
169 tracker,
170 },
171 )
172 }
173
174 #[must_use]
175 pub fn global() -> Self {
177 Handler::global().context()
178 }
179
180 pub async fn done(&self) {
182 self.token.cancelled().await;
183 }
184
185 pub async fn into_done(self) {
187 self.done().await;
188 }
189
190 #[must_use]
192 pub fn is_done(&self) -> bool {
193 self.token.is_cancelled()
194 }
195}
196
197#[derive(Debug)]
200struct TokenDropGuard(CancellationToken);
201
202impl TokenDropGuard {
203 #[must_use]
204 fn child(&self) -> CancellationToken {
205 self.0.child_token()
206 }
207
208 fn cancel(&self) {
209 self.0.cancel();
210 }
211}
212
213impl Drop for TokenDropGuard {
214 fn drop(&mut self) {
215 self.cancel();
216 }
217}
218
219#[derive(Debug, Clone)]
221pub struct Handler {
222 token: Arc<TokenDropGuard>,
223 tracker: Arc<ContextTrackerInner>,
224}
225
226impl Default for Handler {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
232impl Handler {
233 #[must_use]
234 pub fn new() -> Handler {
236 let token = CancellationToken::new();
237 let tracker = ContextTrackerInner::new();
238
239 Handler {
240 token: Arc::new(TokenDropGuard(token)),
241 tracker,
242 }
243 }
244
245 #[must_use]
246 pub fn global() -> &'static Self {
248 static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
249
250 GLOBAL.get_or_init(Handler::new)
251 }
252
253 pub async fn shutdown(&self) {
255 self.cancel();
256 self.done().await;
257 }
258
259 pub async fn done(&self) {
261 self.token.0.cancelled().await;
262 self.wait().await;
263 }
264
265 pub async fn wait(&self) {
269 self.tracker.wait().await;
270 }
271
272 #[must_use]
273 pub fn context(&self) -> Context {
275 Context {
276 token: self.token.child(),
277 tracker: self.tracker.child(),
278 }
279 }
280
281 #[must_use]
282 pub fn new_child(&self) -> (Context, Handler) {
284 self.context().new_child()
285 }
286
287 pub fn cancel(&self) {
289 self.tracker.stop();
290 self.token.cancel();
291 }
292
293 pub fn is_done(&self) -> bool {
295 self.token.0.is_cancelled()
296 }
297}
298
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    // Within this module, only `global_handler` touches the process-wide handler.
    // Cancelling it is permanent, and `cargo test` runs tests as threads of one
    // process. Any other test rooted at `Handler::global()` (e.g. via `Context::new()`)
    // could observe that cancellation and fail spuriously. The remaining tests
    // therefore build their own root `Handler`, which must stay bound for the whole
    // test: dropping it cancels everything below it.

    #[tokio::test]
    async fn new() {
        let handler = Handler::default();
        assert!(!handler.is_done());

        let (ctx, child_handler) = handler.new_child();
        assert!(!child_handler.is_done());
        assert!(!ctx.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    #[tokio::test]
    async fn cancel_child() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());

        child_handler.cancel();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    #[tokio::test]
    async fn shutdown() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // `ctx` is still alive, so shutdown cannot complete yet.
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_err()
        );
        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(
            ctx.into_done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );

        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .wait()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(handler.is_done());
    }

    #[tokio::test]
    async fn global_handler() {
        // Everything that depends on the global handler being live runs here, before the
        // cancellation below, so ordering is deterministic.
        let (ctx, handler) = Context::new();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let global = Handler::global();
        assert!(!global.is_done());

        global.cancel();

        assert!(global.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());
        assert!(handler.is_done());
        assert!(ctx.is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}
420
/// Changelog for this crate, expanded by the `scuffle_changelog::changelog` attribute.
///
/// Only compiled when the `docs` feature is enabled.
#[cfg(feature = "docs")]
#[scuffle_changelog::changelog]
pub mod changelog {}