scuffle_context/
ext.rs

1use std::future::{Future, IntoFuture};
2use std::pin::Pin;
3use std::task::Poll;
4
5use futures_lite::Stream;
6use tokio_util::sync::{WaitForCancellationFuture, WaitForCancellationFutureOwned};
7
8use crate::{Context, ContextTracker};
9
10/// A reference to a context which implements [`Future`] and can be polled.
11/// Can either be owned or borrowed.
12///
13/// Create by using the [`From`] implementations.
14pub struct ContextRef<'a> {
15    inner: ContextRefInner<'a>,
16}
17
18impl From<Context> for ContextRef<'_> {
19    fn from(ctx: Context) -> Self {
20        ContextRef {
21            inner: ContextRefInner::Owned {
22                fut: ctx.token.cancelled_owned(),
23                tracker: ctx.tracker,
24            },
25        }
26    }
27}
28
29impl<'a> From<&'a Context> for ContextRef<'a> {
30    fn from(ctx: &'a Context) -> Self {
31        ContextRef {
32            inner: ContextRefInner::Ref {
33                fut: ctx.token.cancelled(),
34            },
35        }
36    }
37}
38
39pin_project_lite::pin_project! {
40    #[project = ContextRefInnerProj]
41    enum ContextRefInner<'a> {
42        Owned {
43            #[pin] fut: WaitForCancellationFutureOwned,
44            tracker: ContextTracker,
45        },
46        Ref {
47            #[pin] fut: WaitForCancellationFuture<'a>,
48        },
49    }
50}
51
52impl std::future::Future for ContextRefInner<'_> {
53    type Output = ();
54
55    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
56        match self.project() {
57            ContextRefInnerProj::Owned { fut, .. } => fut.poll(cx),
58            ContextRefInnerProj::Ref { fut } => fut.poll(cx),
59        }
60    }
61}
62
63pin_project_lite::pin_project! {
64    /// A future with a context attached to it.
65    ///
66    /// This future will be cancelled when the context is done.
67    pub struct FutureWithContext<'a, F> {
68        #[pin]
69        future: F,
70        #[pin]
71        ctx: ContextRefInner<'a>,
72        _marker: std::marker::PhantomData<&'a ()>,
73    }
74}
75
76impl<F: Future> Future for FutureWithContext<'_, F> {
77    type Output = Option<F::Output>;
78
79    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
80        let this = self.project();
81
82        match (this.ctx.poll(cx), this.future.poll(cx)) {
83            (_, Poll::Ready(v)) => std::task::Poll::Ready(Some(v)),
84            (Poll::Ready(_), Poll::Pending) => std::task::Poll::Ready(None),
85            (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
86        }
87    }
88}
89
90/// Extends a future with useful functions.
91pub trait ContextFutExt<Fut> {
92    /// Wraps a future with a context and cancels the future when the context is
93    /// done.
94    ///
95    /// # Example
96    ///
97    /// ```rust
98    /// # use scuffle_context::{Context, ContextFutExt};
99    /// # tokio_test::block_on(async {
100    /// let (ctx, handler) = Context::new();
101    ///
102    /// tokio::spawn(async {
103    ///    // Do some work
104    ///    tokio::time::sleep(std::time::Duration::from_secs(10)).await;
105    /// }.with_context(ctx));
106    ///
107    /// // Will stop the spawned task and cancel all associated futures.
108    /// handler.cancel();
109    /// # });
110    /// ```
111    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, Fut>
112    where
113        Self: Sized;
114}
115
116impl<F: IntoFuture> ContextFutExt<F::IntoFuture> for F {
117    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, F::IntoFuture>
118    where
119        F: IntoFuture,
120    {
121        FutureWithContext {
122            future: self.into_future(),
123            ctx: ctx.into().inner,
124            _marker: std::marker::PhantomData,
125        }
126    }
127}
128
129pin_project_lite::pin_project! {
130    /// A stream with a context attached to it.
131    ///
132    /// This stream will be cancelled when the context is done.
133    pub struct StreamWithContext<'a, F> {
134        #[pin]
135        stream: F,
136        #[pin]
137        ctx: ContextRefInner<'a>,
138        _marker: std::marker::PhantomData<&'a ()>,
139    }
140}
141
142impl<F: Stream> Stream for StreamWithContext<'_, F> {
143    type Item = F::Item;
144
145    fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
146        let this = self.project();
147
148        match (this.ctx.poll(cx), this.stream.poll_next(cx)) {
149            (Poll::Ready(_), _) => std::task::Poll::Ready(None),
150            (Poll::Pending, Poll::Ready(v)) => std::task::Poll::Ready(v),
151            (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
152        }
153    }
154
155    fn size_hint(&self) -> (usize, Option<usize>) {
156        self.stream.size_hint()
157    }
158}
159
160/// Extends a stream with useful functions.
161pub trait ContextStreamExt<Stream> {
162    /// Wraps a stream with a context and stops the stream when the context is
163    /// done.
164    ///
165    /// # Example
166    ///
167    /// ```rust
168    /// # use scuffle_context::{Context, ContextStreamExt};
169    /// # use futures_lite as futures;
170    /// # use futures_lite::StreamExt;
171    /// # tokio_test::block_on(async {
172    /// let (ctx, handler) = Context::new();
173    ///
174    /// tokio::spawn(async {
175    ///     futures::stream::iter(1..=10).then(|d| async move {
176    ///         // Do some work
177    ///         tokio::time::sleep(std::time::Duration::from_secs(d)).await;
178    ///     }).with_context(ctx);
179    /// });
180    ///
181    /// // Will stop the spawned task and cancel all associated streams.
182    /// handler.cancel();
183    /// # });
184    /// ```
185    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, Stream>
186    where
187        Self: Sized;
188}
189
190impl<F: Stream> ContextStreamExt<F> for F {
191    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, F> {
192        StreamWithContext {
193            stream: self,
194            ctx: ctx.into().inner,
195            _marker: std::marker::PhantomData,
196        }
197    }
198}
199
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use std::time::Duration;

    use futures_lite::{Stream, StreamExt};
    use scuffle_future_ext::FutureExt;

    use super::{Context, ContextFutExt, ContextStreamExt};

    /// Simulated work that is far longer than any test should take, so a
    /// missing cancellation would be noticed.
    const LONG_WORK: Duration = Duration::from_secs(10);

    #[tokio::test]
    async fn future() {
        let (ctx, handler) = Context::new();

        let work = async {
            tokio::time::sleep(LONG_WORK).await;
        };
        let task = tokio::spawn(work.with_context(ctx));

        // Give the spawned task a moment so the wrapped future is polled at least once.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Cancelling the context should end the wrapped future early.
        handler.shutdown().await;

        task.await.unwrap();
    }

    #[tokio::test]
    async fn future_result() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async { 1 }.with_context(ctx));

        handler.shutdown().await;

        // A future that has produced its value wins over cancellation.
        assert_eq!(task.await.unwrap(), Some(1));
    }

    #[tokio::test]
    async fn future_ctx_by_ref() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async move {
            let work = async {
                tokio::time::sleep(LONG_WORK).await;
            };
            work.with_context(&ctx).await;

            drop(ctx);
        });

        // Cancelling the context should end the borrowed-context future early.
        handler.shutdown().await;

        task.await.unwrap();
    }

    #[tokio::test]
    async fn stream() {
        let (ctx, handler) = Context::new();

        {
            let mut stream = pin!(futures_lite::stream::iter(0..10).with_context(ctx));

            assert_eq!(stream.size_hint(), (10, Some(10)));

            // Items flow through untouched while the context is live.
            for expected in 0..4 {
                assert_eq!(stream.next().await, Some(expected));
            }

            // Once cancelled, the wrapped stream ends.
            handler.cancel();

            assert_eq!(stream.next().await, None);
        }

        handler.shutdown().await;
    }

    #[tokio::test]
    async fn pending_stream() {
        let (ctx, handler) = Context::new();

        {
            let mut stream = pin!(futures_lite::stream::pending::<()>().with_context(ctx));

            // Neither the stream nor the context becomes ready, so this must time out.
            let outcome = stream.next().with_timeout(Duration::from_millis(200)).await;
            assert!(outcome.is_err());
        }

        handler.shutdown().await;
    }
}