scuffle_signal/
lib.rs

1//! A crate designed to provide a more user friendly interface to
2//! `tokio::signal`.
3#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6//! ## Why do we need this?
7//!
8//! The `tokio::signal` module provides a way for us to wait for a signal to be
9//! received in a non-blocking way. This crate extends that with a more helpful
10//! interface allowing the ability to listen to multiple signals concurrently.
11//!
12//! ## Example
13//!
14//! ```rust
15//! # #[cfg(unix)]
16//! # {
17//! use scuffle_signal::SignalHandler;
18//! use tokio::signal::unix::SignalKind;
19//!
20//! # tokio_test::block_on(async {
21//! let mut handler = SignalHandler::new()
22//!     .with_signal(SignalKind::interrupt())
23//!     .with_signal(SignalKind::terminate());
24//!
25//! # // Safety: This is a test, and we control the process.
26//! # unsafe {
27//! #    libc::raise(SignalKind::interrupt().as_raw_value());
28//! # }
29//! // Wait for a signal to be received
30//! let signal = handler.await;
31//!
32//! // Handle the signal
33//! let interrupt = SignalKind::interrupt();
34//! let terminate = SignalKind::terminate();
//! if signal == interrupt {
//!     // Handle SIGINT
//!     println!("received SIGINT");
//! } else if signal == terminate {
//!     // Handle SIGTERM
//!     println!("received SIGTERM");
//! }
45//! # });
46//! # }
47//! ```
48//!
49//! ## License
50//!
51//! This project is licensed under the MIT or Apache-2.0 license.
52//! You can choose between one of them if you use this work.
53//!
54//! `SPDX-License-Identifier: MIT OR Apache-2.0`
55#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
56#![cfg_attr(docsrs, feature(doc_auto_cfg))]
57#![deny(missing_docs)]
58#![deny(unreachable_pub)]
59#![deny(clippy::undocumented_unsafe_blocks)]
60#![deny(clippy::multiple_unsafe_ops_per_block)]
61
62use std::pin::Pin;
63use std::task::{Context, Poll};
64
65#[cfg(unix)]
66use tokio::signal::unix;
67#[cfg(unix)]
68pub use tokio::signal::unix::SignalKind as UnixSignalKind;
69
70#[cfg(feature = "bootstrap")]
71mod bootstrap;
72
73#[cfg(feature = "bootstrap")]
74pub use bootstrap::{SignalConfig, SignalSvc};
75
76/// The type of signal to listen for.
77#[derive(Debug, Clone, Copy, Eq)]
78pub enum SignalKind {
79    /// Represents the interrupt signal, which is `SIGINT` on Unix and `Ctrl-C` on Windows.
80    Interrupt,
81    /// Represents the terminate signal, which is `SIGTERM` on Unix and `Ctrl-Close` on Windows.
82    Terminate,
83    /// Represents a Windows-specific signal kind, as defined in `WindowsSignalKind`.
84    #[cfg(windows)]
85    Windows(WindowsSignalKind),
86    /// Represents a Unix-specific signal kind, wrapping `tokio::signal::unix::SignalKind`.
87    #[cfg(unix)]
88    Unix(UnixSignalKind),
89}
90
91impl PartialEq for SignalKind {
92    fn eq(&self, other: &Self) -> bool {
93        #[cfg(unix)]
94        const INTERRUPT: UnixSignalKind = UnixSignalKind::interrupt();
95        #[cfg(unix)]
96        const TERMINATE: UnixSignalKind = UnixSignalKind::terminate();
97
98        match (self, other) {
99            #[cfg(windows)]
100            (
101                Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
102                Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
103            ) => true,
104            #[cfg(windows)]
105            (
106                Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
107                Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
108            ) => true,
109            #[cfg(windows)]
110            (Self::Windows(a), Self::Windows(b)) => a == b,
111            #[cfg(unix)]
112            (Self::Interrupt | Self::Unix(INTERRUPT), Self::Interrupt | Self::Unix(INTERRUPT)) => true,
113            #[cfg(unix)]
114            (Self::Terminate | Self::Unix(TERMINATE), Self::Terminate | Self::Unix(TERMINATE)) => true,
115            #[cfg(unix)]
116            (Self::Unix(a), Self::Unix(b)) => a == b,
117            _ => false,
118        }
119    }
120}
121
122#[cfg(unix)]
123impl From<UnixSignalKind> for SignalKind {
124    fn from(value: UnixSignalKind) -> Self {
125        match value {
126            kind if kind == UnixSignalKind::interrupt() => Self::Interrupt,
127            kind if kind == UnixSignalKind::terminate() => Self::Terminate,
128            kind => Self::Unix(kind),
129        }
130    }
131}
132
133#[cfg(unix)]
134impl PartialEq<UnixSignalKind> for SignalKind {
135    fn eq(&self, other: &UnixSignalKind) -> bool {
136        match self {
137            Self::Interrupt => other == &UnixSignalKind::interrupt(),
138            Self::Terminate => other == &UnixSignalKind::terminate(),
139            Self::Unix(kind) => kind == other,
140        }
141    }
142}
143
/// Console signals that exist only on Windows.
#[cfg(windows)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowsSignalKind {
    /// The `Ctrl-Break` console signal.
    CtrlBreak,
    /// The `Ctrl-C` console signal.
    CtrlC,
    /// The `Ctrl-Close` console signal.
    CtrlClose,
    /// The `Ctrl-Logoff` console signal.
    CtrlLogoff,
    /// The `Ctrl-Shutdown` console signal.
    CtrlShutdown,
}
159
#[cfg(windows)]
impl From<WindowsSignalKind> for SignalKind {
    /// Folds `Ctrl-C` and `Ctrl-Close` into the portable aliases.
    ///
    /// The remaining Windows kinds are kept as `SignalKind::Windows`.
    fn from(value: WindowsSignalKind) -> Self {
        match value {
            WindowsSignalKind::CtrlC => Self::Interrupt,
            WindowsSignalKind::CtrlClose => Self::Terminate,
            WindowsSignalKind::CtrlBreak | WindowsSignalKind::CtrlLogoff | WindowsSignalKind::CtrlShutdown => {
                Self::Windows(value)
            }
        }
    }
}
172
#[cfg(windows)]
impl PartialEq<WindowsSignalKind> for SignalKind {
    /// Compares against a Windows signal kind.
    ///
    /// The portable aliases are first resolved to `CtrlC`/`CtrlClose`.
    fn eq(&self, other: &WindowsSignalKind) -> bool {
        let expected = match *self {
            Self::Interrupt => WindowsSignalKind::CtrlC,
            Self::Terminate => WindowsSignalKind::CtrlClose,
            Self::Windows(kind) => kind,
        };
        expected == *other
    }
}
183
/// A live Windows signal listener: one variant per `tokio::signal::windows`
/// stream type.
#[cfg(windows)]
#[derive(Debug)]
enum WindowsSignalValue {
    /// Listener for `Ctrl-Break`.
    CtrlBreak(tokio::signal::windows::CtrlBreak),
    /// Listener for `Ctrl-C`.
    CtrlC(tokio::signal::windows::CtrlC),
    /// Listener for `Ctrl-Close`.
    CtrlClose(tokio::signal::windows::CtrlClose),
    /// Listener for `Ctrl-Logoff`.
    CtrlLogoff(tokio::signal::windows::CtrlLogoff),
    /// Listener for `Ctrl-Shutdown`.
    CtrlShutdown(tokio::signal::windows::CtrlShutdown),
    /// Test-only stand-in: fires when the broadcast stream delivers a value
    /// equal to the stored kind.
    #[cfg(test)]
    Mock(SignalKind, Pin<Box<tokio_stream::wrappers::BroadcastStream<SignalKind>>>),
}
195
#[cfg(windows)]
impl WindowsSignalValue {
    /// Polls the wrapped listener once.
    ///
    /// Real console listeners forward to their own `poll_recv`. The test-only
    /// `Mock` variant resolves only when the broadcast delivers its own kind.
    /// For any other kind, or when no message has arrived yet, it re-wakes the
    /// task and returns `Pending`, so the stream is polled again.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        #[cfg(test)]
        use futures::Stream;

        match self {
            Self::CtrlC(listener) => listener.poll_recv(cx),
            Self::CtrlBreak(listener) => listener.poll_recv(cx),
            Self::CtrlClose(listener) => listener.poll_recv(cx),
            Self::CtrlLogoff(listener) => listener.poll_recv(cx),
            Self::CtrlShutdown(listener) => listener.poll_recv(cx),
            #[cfg(test)]
            Self::Mock(expected, stream) => {
                let polled = stream.as_mut().poll_next(cx);
                match polled {
                    Poll::Ready(Some(Ok(got))) if got == *expected => Poll::Ready(Some(())),
                    // A different kind was broadcast, or nothing yet: ask to be polled again.
                    Poll::Ready(Some(Ok(_))) | Poll::Pending => {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                    Poll::Ready(other) => unreachable!("receiver should always have a value: {:?}", other),
                }
            }
        }
    }
}
224
225#[cfg(unix)]
226type Signal = unix::Signal;
227
228#[cfg(windows)]
229type Signal = WindowsSignalValue;
230
231impl SignalKind {
232    #[cfg(unix)]
233    fn listen(&self) -> Result<Signal, std::io::Error> {
234        match self {
235            Self::Interrupt => tokio::signal::unix::signal(UnixSignalKind::interrupt()),
236            Self::Terminate => tokio::signal::unix::signal(UnixSignalKind::terminate()),
237            Self::Unix(kind) => tokio::signal::unix::signal(*kind),
238        }
239    }
240
241    #[cfg(windows)]
242    fn listen(&self) -> Result<Signal, std::io::Error> {
243        #[cfg(test)]
244        if cfg!(test) {
245            return Ok(WindowsSignalValue::Mock(
246                *self,
247                Box::pin(tokio_stream::wrappers::BroadcastStream::new(test::SignalMocker::subscribe())),
248            ));
249        }
250
251        match self {
252            // https://learn.microsoft.com/en-us/windows/console/ctrl-c-and-ctrl-break-signals
253            Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC) => {
254                Ok(WindowsSignalValue::CtrlC(tokio::signal::windows::ctrl_c()?))
255            }
256            // https://learn.microsoft.com/en-us/windows/console/ctrl-close-signal
257            Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose) => {
258                Ok(WindowsSignalValue::CtrlClose(tokio::signal::windows::ctrl_close()?))
259            }
260            Self::Windows(WindowsSignalKind::CtrlBreak) => {
261                Ok(WindowsSignalValue::CtrlBreak(tokio::signal::windows::ctrl_break()?))
262            }
263            Self::Windows(WindowsSignalKind::CtrlLogoff) => {
264                Ok(WindowsSignalValue::CtrlLogoff(tokio::signal::windows::ctrl_logoff()?))
265            }
266            Self::Windows(WindowsSignalKind::CtrlShutdown) => {
267                Ok(WindowsSignalValue::CtrlShutdown(tokio::signal::windows::ctrl_shutdown()?))
268            }
269        }
270    }
271}
272
273/// A handler for listening to multiple signals, and providing a future for
274/// receiving them.
275///
276/// This is useful for applications that need to listen for multiple signals,
277/// and want to react to them in a non-blocking way. Typically you would need to
278/// use a tokio::select{} to listen for multiple signals, but this provides a
279/// more ergonomic interface for doing so.
280///
281/// After a signal is received you can poll the handler again to wait for
282/// another signal. Dropping the handle will cancel the signal subscription
283///
284/// # Example
285///
286/// ```rust
287/// # #[cfg(unix)]
288/// # {
289/// use scuffle_signal::SignalHandler;
290/// use tokio::signal::unix::SignalKind;
291///
292/// # tokio_test::block_on(async {
293/// let mut handler = SignalHandler::new()
294///     .with_signal(SignalKind::interrupt())
295///     .with_signal(SignalKind::terminate());
296///
297/// # // Safety: This is a test, and we control the process.
298/// # unsafe {
299/// #    libc::raise(SignalKind::interrupt().as_raw_value());
300/// # }
301/// // Wait for a signal to be received
302/// let signal = handler.await;
303///
304/// // Handle the signal
305/// let interrupt = SignalKind::interrupt();
306/// let terminate = SignalKind::terminate();
307/// match signal {
308///     interrupt => {
309///         // Handle SIGINT
310///         println!("received SIGINT");
311///     },
312///     terminate => {
313///         // Handle SIGTERM
314///         println!("received SIGTERM");
315///     },
316/// }
317/// # });
318/// # }
319/// ```
320#[derive(Debug)]
321#[must_use = "signal handlers must be used to wait for signals"]
322pub struct SignalHandler {
323    signals: Vec<(SignalKind, Signal)>,
324}
325
326impl Default for SignalHandler {
327    fn default() -> Self {
328        Self::new()
329    }
330}
331
332impl SignalHandler {
333    /// Create a new `SignalHandler` with no signals.
334    pub const fn new() -> Self {
335        Self { signals: Vec::new() }
336    }
337
338    /// Create a new `SignalHandler` with the given signals.
339    pub fn with_signals<T: Into<SignalKind>>(signals: impl IntoIterator<Item = T>) -> Self {
340        let mut handler = Self::new();
341
342        for signal in signals {
343            handler = handler.with_signal(signal.into());
344        }
345
346        handler
347    }
348
349    /// Add a signal to the handler.
350    ///
351    /// If the signal is already in the handler, it will not be added again.
352    pub fn with_signal(mut self, kind: impl Into<SignalKind>) -> Self {
353        self.add_signal(kind);
354        self
355    }
356
357    /// Add a signal to the handler.
358    ///
359    /// If the signal is already in the handler, it will not be added again.
360    pub fn add_signal(&mut self, kind: impl Into<SignalKind>) -> &mut Self {
361        let kind = kind.into();
362        if self.signals.iter().any(|(k, _)| k == &kind) {
363            return self;
364        }
365
366        let signal = kind.listen().expect("failed to create signal");
367
368        self.signals.push((kind, signal));
369
370        self
371    }
372
373    /// Wait for a signal to be received.
374    /// This is equivilant to calling (&mut handler).await, but is more
375    /// ergonomic if you want to not take ownership of the handler.
376    pub async fn recv(&mut self) -> SignalKind {
377        self.await
378    }
379
380    /// Poll for a signal to be received.
381    /// Does not require pinning the handler.
382    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
383        for (kind, signal) in self.signals.iter_mut() {
384            if signal.poll_recv(cx).is_ready() {
385                return Poll::Ready(*kind);
386            }
387        }
388
389        Poll::Pending
390    }
391}
392
393impl std::future::Future for SignalHandler {
394    type Output = SignalKind;
395
396    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397        self.poll_recv(cx)
398    }
399}
400
/// Release history for this crate, generated by [scuffle_changelog].
#[cfg(feature = "docs")]
#[scuffle_changelog::changelog]
pub mod changelog {}
405
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::time::Duration;

    use scuffle_future_ext::FutureExt;

    use crate::{SignalHandler, SignalKind};

    /// In-process stand-in for Windows console signals.
    ///
    /// Mocked listeners subscribe to this broadcast channel, and `raise`
    /// publishes to it.
    #[cfg(windows)]
    pub(crate) struct SignalMocker(tokio::sync::broadcast::Sender<SignalKind>);

    #[cfg(windows)]
    impl SignalMocker {
        fn new() -> Self {
            let (sender, _) = tokio::sync::broadcast::channel(100);
            Self(sender)
        }

        /// Broadcasts `kind` to every mocked listener on this thread.
        fn raise(kind: SignalKind) {
            SIGNAL_MOCKER.with(|local| local.0.send(kind).unwrap());
        }

        /// Returns a receiver for this thread's mock signal channel.
        pub(crate) fn subscribe() -> tokio::sync::broadcast::Receiver<SignalKind> {
            SIGNAL_MOCKER.with(|local| local.0.subscribe())
        }
    }

    // One mock channel per thread.
    #[cfg(windows)]
    thread_local! {
        static SIGNAL_MOCKER: SignalMocker = SignalMocker::new();
    }

    #[cfg(windows)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        SignalMocker::raise(kind);
    }

    #[cfg(unix)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        // Safety: This is a test, and we control the process.
        unsafe {
            libc::raise(match kind {
                SignalKind::Interrupt => libc::SIGINT,
                SignalKind::Terminate => libc::SIGTERM,
                SignalKind::Unix(kind) => kind.as_raw_value(),
            });
        }
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn signal_handler() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::with_signals([WindowsSignalKind::CtrlC, WindowsSignalKind::CtrlBreak]);

        tokio::time::sleep(Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;
        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn add_signal() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::new();

        handler
            .add_signal(WindowsSignalKind::CtrlC)
            .add_signal(WindowsSignalKind::CtrlBreak)
            .add_signal(WindowsSignalKind::CtrlC);

        tokio::time::sleep(Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(all(not(valgrind), unix))] // test is time-sensitive
    #[tokio::test]
    async fn signal_handler() {
        use crate::UnixSignalKind;

        let mut handler = SignalHandler::with_signals([UnixSignalKind::user_defined1()])
            .with_signal(UnixSignalKind::user_defined2())
            .with_signal(UnixSignalKind::user_defined1());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, SignalKind::Unix(UnixSignalKind::user_defined1()), "expected SIGUSR1");

        // We already received the signal, so polling again should return Poll::Pending
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;

        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        // We should be able to receive the signal again
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(all(not(valgrind), unix))] // test is time-sensitive
    #[tokio::test]
    async fn add_signal() {
        use crate::UnixSignalKind;

        let mut handler = SignalHandler::new();

        handler
            .add_signal(UnixSignalKind::user_defined1())
            .add_signal(UnixSignalKind::user_defined2())
            .add_signal(UnixSignalKind::user_defined2());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined1(), "expected SIGUSR1");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn no_signals() {
        let mut handler = SignalHandler::default();

        // Expected to timeout
        assert!(handler.recv().with_timeout(Duration::from_millis(500)).await.is_err());
    }

    #[cfg(windows)]
    #[test]
    fn signal_kind_eq() {
        use crate::WindowsSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Windows(WindowsSignalKind::CtrlC));
        assert_eq!(SignalKind::Terminate, SignalKind::Windows(WindowsSignalKind::CtrlClose));
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlC), SignalKind::Interrupt);
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlClose), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Windows(WindowsSignalKind::CtrlBreak),
            SignalKind::Windows(WindowsSignalKind::CtrlBreak)
        );
    }

    #[cfg(unix)]
    #[test]
    fn signal_kind_eq() {
        use crate::UnixSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Unix(UnixSignalKind::interrupt()));
        assert_eq!(SignalKind::Terminate, SignalKind::Unix(UnixSignalKind::terminate()));
        assert_eq!(SignalKind::Unix(UnixSignalKind::interrupt()), SignalKind::Interrupt);
        assert_eq!(SignalKind::Unix(UnixSignalKind::terminate()), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Unix(UnixSignalKind::user_defined1()),
            SignalKind::Unix(UnixSignalKind::user_defined1())
        );
    }
}