scuffle_context/
lib.rs

//! A crate designed to provide the ability to cancel futures using a
//! Go-like context approach, allowing for graceful shutdowns and cancellations.
3#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6//! ## Why do we need this?
7//!
//! It's often useful to wait for all futures to shut down, or to cancel them
//! when we no longer care about their results. This crate provides an
//! interface to cancel all futures associated with a context, or to wait for
//! them to finish before shutting down, allowing for graceful shutdowns and
//! cancellations.
12//!
13//! ## Usage
14//!
15//! Here is an example of how to use the `Context` to cancel a spawned task.
16//!
17//! ```rust
18//! # use scuffle_context::{Context, ContextFutExt};
19//! # tokio_test::block_on(async {
20//! let (ctx, handler) = Context::new();
21//!
22//! tokio::spawn(async {
23//!     // Do some work
24//!     tokio::time::sleep(std::time::Duration::from_secs(10)).await;
25//! }.with_context(ctx));
26//!
27//! // Will stop the spawned task and cancel all associated futures.
28//! handler.cancel();
29//! # });
30//! ```
31//!
32//! ## License
33//!
34//! This project is licensed under the MIT or Apache-2.0 license.
35//! You can choose between one of them if you use this work.
36//!
37//! `SPDX-License-Identifier: MIT OR Apache-2.0`
38#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
39#![cfg_attr(docsrs, feature(doc_auto_cfg))]
40#![deny(missing_docs)]
41#![deny(unsafe_code)]
42#![deny(unreachable_pub)]
43
44use std::sync::Arc;
45use std::sync::atomic::{AtomicBool, AtomicUsize};
46
47use tokio_util::sync::CancellationToken;
48
/// Extension traits for attaching a [`Context`] to futures and streams.
50mod ext;
51
52pub use ext::*;
53
54/// Create by calling [`ContextTrackerInner::child`].
55#[derive(Debug)]
56struct ContextTracker(Arc<ContextTrackerInner>);
57
58impl Drop for ContextTracker {
59    fn drop(&mut self) {
60        let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
61        // If this was the last active `ContextTracker` and the context has been
62        // stopped, then notify the waiters
63        if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
64            self.0.notify.notify_waiters();
65        }
66    }
67}
68
69#[derive(Debug)]
70struct ContextTrackerInner {
71    stopped: AtomicBool,
72    /// This count keeps track of the number of `ContextTrackers` that exist for
73    /// this `ContextTrackerInner`.
74    active_count: AtomicUsize,
75    notify: tokio::sync::Notify,
76}
77
78impl ContextTrackerInner {
79    fn new() -> Arc<Self> {
80        Arc::new(Self {
81            stopped: AtomicBool::new(false),
82            active_count: AtomicUsize::new(0),
83            notify: tokio::sync::Notify::new(),
84        })
85    }
86
87    /// Create a new `ContextTracker` from an `Arc<ContextTrackerInner>`.
88    fn child(self: &Arc<Self>) -> ContextTracker {
89        self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
90        ContextTracker(Arc::clone(self))
91    }
92
93    /// Mark this `ContextTrackerInner` as stopped.
94    fn stop(&self) {
95        self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
96    }
97
98    /// Wait for this `ContextTrackerInner` to be stopped and all associated
99    /// `ContextTracker`s to be dropped.
100    async fn wait(&self) {
101        let notify = self.notify.notified();
102
103        // If there are no active children, then the notify will never be called
104        if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
105            return;
106        }
107
108        notify.await;
109    }
110}
111
112/// A context for cancelling futures and waiting for shutdown.
113///
114/// A context can be created from a handler by calling [`Handler::context`] or
115/// from another context by calling [`Context::new_child`] so to have a
116/// hierarchy of contexts.
117///
118/// Contexts can then be attached to futures or streams in order to
119/// automatically cancel them when the context is done, when invoking
120/// [`Handler::cancel`].
121/// The [`Handler::shutdown`] method will block until all contexts have been
122/// dropped allowing for a graceful shutdown.
123#[derive(Debug)]
124pub struct Context {
125    token: CancellationToken,
126    tracker: ContextTracker,
127}
128
129impl Clone for Context {
130    fn clone(&self) -> Self {
131        Self {
132            token: self.token.clone(),
133            tracker: self.tracker.0.child(),
134        }
135    }
136}
137
138impl Context {
139    #[must_use]
140    /// Create a new context using the global handler.
141    /// Returns a child context and child handler of the global handler.
142    pub fn new() -> (Self, Handler) {
143        Handler::global().new_child()
144    }
145
146    #[must_use]
147    /// Create a new child context from this context.
148    /// Returns a new child context and child handler of this context.
149    ///
150    /// # Example
151    ///
152    /// ```rust
153    /// use scuffle_context::Context;
154    ///
155    /// let (parent, parent_handler) = Context::new();
156    /// let (child, child_handler) = parent.new_child();
157    /// ```
158    pub fn new_child(&self) -> (Self, Handler) {
159        let token = self.token.child_token();
160        let tracker = ContextTrackerInner::new();
161
162        (
163            Self {
164                tracker: tracker.child(),
165                token: token.clone(),
166            },
167            Handler {
168                token: Arc::new(TokenDropGuard(token)),
169                tracker,
170            },
171        )
172    }
173
174    #[must_use]
175    /// Returns the global context
176    pub fn global() -> Self {
177        Handler::global().context()
178    }
179
180    /// Wait for the context to be done (the handler to be shutdown).
181    pub async fn done(&self) {
182        self.token.cancelled().await;
183    }
184
185    /// The same as [`Context::done`] but takes ownership of the context.
186    pub async fn into_done(self) {
187        self.done().await;
188    }
189
190    /// Returns true if the context is done.
191    #[must_use]
192    pub fn is_done(&self) -> bool {
193        self.token.is_cancelled()
194    }
195}
196
197/// A wrapper type around [`CancellationToken`] that will cancel the token as
198/// soon as it is dropped.
199#[derive(Debug)]
200struct TokenDropGuard(CancellationToken);
201
202impl TokenDropGuard {
203    #[must_use]
204    fn child(&self) -> CancellationToken {
205        self.0.child_token()
206    }
207
208    fn cancel(&self) {
209        self.0.cancel();
210    }
211}
212
213impl Drop for TokenDropGuard {
214    fn drop(&mut self) {
215        self.cancel();
216    }
217}
218
219/// A handler is used to manage contexts and to cancel them.
220#[derive(Debug, Clone)]
221pub struct Handler {
222    token: Arc<TokenDropGuard>,
223    tracker: Arc<ContextTrackerInner>,
224}
225
226impl Default for Handler {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
232impl Handler {
233    #[must_use]
234    /// Create a new handler.
235    pub fn new() -> Handler {
236        let token = CancellationToken::new();
237        let tracker = ContextTrackerInner::new();
238
239        Handler {
240            token: Arc::new(TokenDropGuard(token)),
241            tracker,
242        }
243    }
244
245    #[must_use]
246    /// Returns the global handler.
247    pub fn global() -> &'static Self {
248        static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
249
250        GLOBAL.get_or_init(Handler::new)
251    }
252
253    /// Shutdown the handler and wait for all contexts to be done.
254    pub async fn shutdown(&self) {
255        self.cancel();
256        self.done().await;
257    }
258
259    /// Waits for the handler to be done (waiting for all contexts to be done).
260    pub async fn done(&self) {
261        self.token.0.cancelled().await;
262        self.wait().await;
263    }
264
265    /// Waits for the handler to be done (waiting for all contexts to be done).
266    /// Returns once all contexts are done, even if the handler is not done and
267    /// contexts can be created after this call.
268    pub async fn wait(&self) {
269        self.tracker.wait().await;
270    }
271
272    #[must_use]
273    /// Create a new context from this handler.
274    pub fn context(&self) -> Context {
275        Context {
276            token: self.token.child(),
277            tracker: self.tracker.child(),
278        }
279    }
280
281    #[must_use]
282    /// Create a new child context from this handler
283    pub fn new_child(&self) -> (Context, Handler) {
284        self.context().new_child()
285    }
286
287    /// Cancel the handler.
288    pub fn cancel(&self) {
289        self.tracker.stop();
290        self.token.cancel();
291    }
292
293    /// Returns true if the handler is done.
294    pub fn is_done(&self) -> bool {
295        self.token.0.is_cancelled()
296    }
297}
298
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    //! Only `global_handler` may derive from `Handler::global()`: it cancels
    //! the global handler, and libtest runs every test of this crate in one
    //! process. Any other test using the global handler (e.g. via
    //! `Context::new()`) could observe it already cancelled, which makes it
    //! flaky in parallel and failing when run serially after
    //! `global_handler`. The other tests therefore use a private `root`
    //! handler. `root` is bound to a named variable so it stays alive until
    //! the end of the test; dropping it would cancel its token and every
    //! context beneath it.

    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    #[tokio::test]
    async fn new() {
        let handler = Handler::new();
        assert!(!handler.is_done());
        let ctx = handler.context();
        assert!(!ctx.is_done());

        let handler = Handler::default();
        assert!(!handler.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    #[tokio::test]
    async fn cancel_child() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());

        child_handler.cancel();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    #[tokio::test]
    async fn shutdown() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // This is expected to timeout
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_err()
        );
        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(
            ctx.into_done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );

        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .wait()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(handler.is_done());
    }

    #[tokio::test]
    async fn global_handler() {
        // `Context::new` derives from the global handler, so it is exercised
        // here, before the global handler is cancelled below.
        let (ctx, handler) = Context::new();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let global = Handler::global();

        assert!(!global.is_done());

        global.cancel();

        assert!(global.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());
        // Cancelling the global handler propagates to contexts derived from it.
        assert!(handler.is_done());
        assert!(ctx.is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}
420
/// Release history of this crate, generated by [scuffle_changelog].
#[cfg(feature = "docs")]
#[scuffle_changelog::changelog]
pub mod changelog {}