scuffle_signal/
bootstrap.rs

1use std::sync::Arc;
2
3use scuffle_bootstrap::global::Global;
4use scuffle_bootstrap::service::Service;
5use scuffle_context::ContextFutExt;
6
/// Bootstrap [`Service`] that waits for process signals.
///
/// Once one of the configured signals arrives, the service drops the context it
/// was started with and then waits for the shutdown to complete.
#[derive(Clone, Copy, Debug, Default)]
pub struct SignalSvc;
10
11/// Configuration for the signal service.
12pub trait SignalConfig: Global {
13    /// The signals to listen for.
14    ///
15    /// By default, listens for `SIGTERM` and `SIGINT`.
16    fn signals(&self) -> Vec<crate::SignalKind> {
17        vec![crate::SignalKind::Terminate, crate::SignalKind::Interrupt]
18    }
19
20    /// The timeout before forcing a shutdown.
21    fn timeout(&self) -> Option<std::time::Duration> {
22        Some(std::time::Duration::from_secs(30))
23    }
24
25    /// Called when the service is shutting down.
26    fn on_shutdown(self: &Arc<Self>) -> impl std::future::Future<Output = anyhow::Result<()>> + Send {
27        std::future::ready(Ok(()))
28    }
29
30    /// Called when the service is force shutting down.
31    fn on_force_shutdown(
32        &self,
33        signal: Option<crate::SignalKind>,
34    ) -> impl std::future::Future<Output = anyhow::Result<()>> + Send {
35        let err = if let Some(signal) = signal {
36            anyhow::anyhow!("received signal, shutting down immediately: {:?}", signal)
37        } else {
38            anyhow::anyhow!("timeout reached, shutting down immediately")
39        };
40
41        std::future::ready(Err(err))
42    }
43
44    /// Waits for the global shutdown to complete after a signal cancels the local context.
45    /// Defaults to the global context’s shutdown ([`scuffle_context::Handler::global().shutdown()`]).
46    /// Override to use a custom context or condition for shutdown completion.
47    fn block_global_shutdown(&self) -> impl std::future::Future<Output = ()> + Send {
48        scuffle_context::Handler::global().shutdown()
49    }
50}
51
52impl<Global: SignalConfig> Service<Global> for SignalSvc {
53    fn enabled(&self, global: &Arc<Global>) -> impl std::future::Future<Output = anyhow::Result<bool>> + Send {
54        std::future::ready(Ok(!global.signals().is_empty()))
55    }
56
57    async fn run(self, global: Arc<Global>, ctx: scuffle_context::Context) -> anyhow::Result<()> {
58        let timeout = global.timeout();
59
60        let signals = global.signals();
61        anyhow::ensure!(!signals.is_empty(), "no signals to listen for");
62
63        let mut handler = crate::SignalHandler::with_signals(signals);
64
65        // Wait for a signal, or for the context to be done.
66        handler.recv().with_context(&ctx).await;
67        global.on_shutdown().await?;
68        drop(ctx);
69
70        tokio::select! {
71            signal = handler.recv() => {
72                global.on_force_shutdown(Some(signal)).await?;
73            },
74            _ = global.block_global_shutdown() => {}
75            Some(()) = async {
76                if let Some(timeout) = timeout {
77                    tokio::time::sleep(timeout).await;
78                    Some(())
79                } else {
80                    None
81                }
82            } => {
83                global.on_force_shutdown(None).await?;
84            },
85        };
86
87        Ok(())
88    }
89}
90
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod test {
    use std::sync::Arc;

    use scuffle_bootstrap::{GlobalWithoutConfig, Service};
    use scuffle_future_ext::FutureExt;

    use super::SignalConfig;
    use crate::tests::raise_signal;
    use crate::{SignalKind, SignalSvc};

    /// Asserts that `handler.shutdown()` completes within one second.
    async fn assert_shutdown_completes(handler: scuffle_context::Handler) {
        assert!(
            handler
                .shutdown()
                .with_timeout(tokio::time::Duration::from_millis(1000))
                .await
                .is_ok()
        );
    }

    /// Runs [`SignalSvc`] with `Global`, raises `SIGINT` twice and expects the
    /// second signal to force the shutdown with the default error.
    async fn force_shutdown_two_signals<Global: GlobalWithoutConfig + SignalConfig>() {
        let (ctx, handler) = scuffle_context::Context::new();

        // Hold a global context for the duration of the test.
        let _global_ctx = scuffle_context::Context::global();

        let svc = SignalSvc;
        let global = <Global as GlobalWithoutConfig>::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        // Wait for the service to start
        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

        // The first signal starts the shutdown, the second one forces it.
        raise_signal(SignalKind::Interrupt).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        raise_signal(SignalKind::Interrupt).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        match result.with_timeout(tokio::time::Duration::from_millis(1000)).await {
            Ok(Ok(Err(e))) => {
                assert_eq!(e.to_string(), "received signal, shutting down immediately: Interrupt");
            }
            r => panic!("unexpected result: {r:?}"),
        }

        assert_shutdown_completes(handler).await;
    }

    /// Default signals and timeout; the global shutdown never completes.
    struct TestGlobal;

    impl GlobalWithoutConfig for TestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for TestGlobal {
        async fn block_global_shutdown(&self) {
            std::future::pending().await
        }
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn default_bootstrap_service() {
        force_shutdown_two_signals::<TestGlobal>().await;
    }

    /// No timeout; the global shutdown completes once the wrapped `Notify` is
    /// notified by the test.
    struct NoTimeoutTestGlobal(tokio::sync::Notify);

    impl GlobalWithoutConfig for NoTimeoutTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self(tokio::sync::Notify::new()))))
        }
    }

    impl SignalConfig for NoTimeoutTestGlobal {
        fn timeout(&self) -> Option<std::time::Duration> {
            None
        }

        // Stands in for the global shutdown: it only completes once the test
        // notifies, so the test controls when the service may finish.
        async fn block_global_shutdown(&self) {
            self.0.notified().await;
        }
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn bootstrap_service_no_timeout() {
        let (ctx, handler) = scuffle_context::Context::new();
        let svc = SignalSvc;
        let global = <NoTimeoutTestGlobal as GlobalWithoutConfig>::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let mut result = tokio::spawn(svc.run(global.clone(), ctx));

        // Wait for the service to start
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        raise_signal(SignalKind::Interrupt).await;
        // no timeout so it should block indefinitely
        assert!(
            (&mut result)
                .with_timeout(tokio::time::Duration::from_millis(100))
                .await
                .is_err()
        );

        // Let the simulated global shutdown complete; the service must now finish.
        global.0.notify_one();

        assert!(result.with_timeout(tokio::time::Duration::from_millis(100)).await.is_ok());

        assert_shutdown_completes(handler).await;
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn bootstrap_service_force_shutdown() {
        force_shutdown_two_signals::<NoTimeoutTestGlobal>().await;
    }

    /// No signals configured, which disables the service.
    struct NoSignalsTestGlobal;

    impl GlobalWithoutConfig for NoSignalsTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for NoSignalsTestGlobal {
        fn signals(&self) -> Vec<crate::SignalKind> {
            vec![]
        }

        fn timeout(&self) -> Option<std::time::Duration> {
            None
        }

        async fn block_global_shutdown(&self) {
            std::future::pending().await
        }
    }

    #[tokio::test]
    async fn bootstrap_service_no_signals() {
        let (ctx, handler) = scuffle_context::Context::new();
        let svc = SignalSvc;
        let global = <NoSignalsTestGlobal as GlobalWithoutConfig>::init().await.unwrap();

        assert!(!svc.enabled(&global).await.unwrap());
        // Running the service anyway must fail instead of waiting forever.
        let result = svc.run(global, ctx).await.unwrap_err();

        assert_eq!(result.to_string(), "no signals to listen for");

        assert_shutdown_completes(handler).await;
    }

    /// Default signals with a 50 ms timeout; the global shutdown never completes.
    struct SmallTimeoutTestGlobal;

    impl GlobalWithoutConfig for SmallTimeoutTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for SmallTimeoutTestGlobal {
        fn timeout(&self) -> Option<std::time::Duration> {
            Some(std::time::Duration::from_millis(50))
        }

        async fn block_global_shutdown(&self) {
            std::future::pending().await
        }
    }

    #[tokio::test]
    #[cfg(not(valgrind))]
    async fn bootstrap_service_timeout_force_shutdown() {
        let (ctx, handler) = scuffle_context::Context::new();

        // Block the global context
        let _global_ctx = scuffle_context::Context::global();

        let svc = SignalSvc;
        let global = <SmallTimeoutTestGlobal as GlobalWithoutConfig>::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        // Wait for the service to start
        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;

        raise_signal(SignalKind::Interrupt).await;

        match result.with_timeout(tokio::time::Duration::from_millis(1000)).await {
            Ok(Ok(Err(e))) => {
                assert_eq!(e.to_string(), "timeout reached, shutting down immediately");
            }
            // Include the actual value so a failure is diagnosable from the log.
            r => panic!("unexpected result: {r:?}"),
        }

        assert_shutdown_completes(handler).await;
    }
}
311        );
312    }
313}