1#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
56#![cfg_attr(docsrs, feature(doc_auto_cfg))]
57#![deny(missing_docs)]
58#![deny(unreachable_pub)]
59#![deny(clippy::undocumented_unsafe_blocks)]
60#![deny(clippy::multiple_unsafe_ops_per_block)]
61
62use std::pin::Pin;
63use std::task::{Context, Poll};
64
65#[cfg(unix)]
66use tokio::signal::unix;
67#[cfg(unix)]
68pub use tokio::signal::unix::SignalKind as UnixSignalKind;
69
70#[cfg(feature = "bootstrap")]
71mod bootstrap;
72
73#[cfg(feature = "bootstrap")]
74pub use bootstrap::{SignalConfig, SignalSvc};
75
76#[derive(Debug, Clone, Copy, Eq)]
78pub enum SignalKind {
79 Interrupt,
81 Terminate,
83 #[cfg(windows)]
85 Windows(WindowsSignalKind),
86 #[cfg(unix)]
88 Unix(UnixSignalKind),
89}
90
91impl PartialEq for SignalKind {
92 fn eq(&self, other: &Self) -> bool {
93 #[cfg(unix)]
94 const INTERRUPT: UnixSignalKind = UnixSignalKind::interrupt();
95 #[cfg(unix)]
96 const TERMINATE: UnixSignalKind = UnixSignalKind::terminate();
97
98 match (self, other) {
99 #[cfg(windows)]
100 (
101 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
102 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
103 ) => true,
104 #[cfg(windows)]
105 (
106 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
107 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
108 ) => true,
109 #[cfg(windows)]
110 (Self::Windows(a), Self::Windows(b)) => a == b,
111 #[cfg(unix)]
112 (Self::Interrupt | Self::Unix(INTERRUPT), Self::Interrupt | Self::Unix(INTERRUPT)) => true,
113 #[cfg(unix)]
114 (Self::Terminate | Self::Unix(TERMINATE), Self::Terminate | Self::Unix(TERMINATE)) => true,
115 #[cfg(unix)]
116 (Self::Unix(a), Self::Unix(b)) => a == b,
117 _ => false,
118 }
119 }
120}
121
122#[cfg(unix)]
123impl From<UnixSignalKind> for SignalKind {
124 fn from(value: UnixSignalKind) -> Self {
125 match value {
126 kind if kind == UnixSignalKind::interrupt() => Self::Interrupt,
127 kind if kind == UnixSignalKind::terminate() => Self::Terminate,
128 kind => Self::Unix(kind),
129 }
130 }
131}
132
133#[cfg(unix)]
134impl PartialEq<UnixSignalKind> for SignalKind {
135 fn eq(&self, other: &UnixSignalKind) -> bool {
136 match self {
137 Self::Interrupt => other == &UnixSignalKind::interrupt(),
138 Self::Terminate => other == &UnixSignalKind::terminate(),
139 Self::Unix(kind) => kind == other,
140 }
141 }
142}
143
/// The Windows console events that can be listened for.
///
/// Each variant corresponds to the `tokio::signal::windows` listener of the
/// same name.
#[cfg(windows)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowsSignalKind {
    /// The `Ctrl-Break` event (`tokio::signal::windows::ctrl_break`).
    CtrlBreak,
    /// The `Ctrl-C` event (`tokio::signal::windows::ctrl_c`); equivalent to
    /// `SignalKind::Interrupt`.
    CtrlC,
    /// The `Ctrl-Close` event (`tokio::signal::windows::ctrl_close`);
    /// equivalent to `SignalKind::Terminate`.
    CtrlClose,
    /// The `Ctrl-Logoff` event (`tokio::signal::windows::ctrl_logoff`).
    CtrlLogoff,
    /// The `Ctrl-Shutdown` event (`tokio::signal::windows::ctrl_shutdown`).
    CtrlShutdown,
}
159
/// Converts a Windows console event, folding `CtrlC` and `CtrlClose` into the
/// platform-independent `Interrupt` and `Terminate` variants.
#[cfg(windows)]
impl From<WindowsSignalKind> for SignalKind {
    fn from(value: WindowsSignalKind) -> Self {
        match value {
            WindowsSignalKind::CtrlC => Self::Interrupt,
            WindowsSignalKind::CtrlClose => Self::Terminate,
            // The remaining events have no platform-independent name.
            WindowsSignalKind::CtrlBreak | WindowsSignalKind::CtrlLogoff | WindowsSignalKind::CtrlShutdown => {
                Self::Windows(value)
            }
        }
    }
}
172
/// Compares a [`SignalKind`] directly against a Windows console event.
///
/// `Interrupt` and `Terminate` are resolved to `CtrlC` / `CtrlClose` first.
#[cfg(windows)]
impl PartialEq<WindowsSignalKind> for SignalKind {
    fn eq(&self, other: &WindowsSignalKind) -> bool {
        let resolved = match *self {
            Self::Interrupt => WindowsSignalKind::CtrlC,
            Self::Terminate => WindowsSignalKind::CtrlClose,
            Self::Windows(raw) => raw,
        };

        resolved == *other
    }
}
183
/// Listener state for one registered signal on Windows.
///
/// Each non-test variant wraps the tokio listener for the console event of the
/// same name, so that all of them can be stored and polled through one type.
#[cfg(windows)]
#[derive(Debug)]
enum WindowsSignalValue {
    /// Listener obtained from `tokio::signal::windows::ctrl_break`.
    CtrlBreak(tokio::signal::windows::CtrlBreak),
    /// Listener obtained from `tokio::signal::windows::ctrl_c`.
    CtrlC(tokio::signal::windows::CtrlC),
    /// Listener obtained from `tokio::signal::windows::ctrl_close`.
    CtrlClose(tokio::signal::windows::CtrlClose),
    /// Listener obtained from `tokio::signal::windows::ctrl_logoff`.
    CtrlLogoff(tokio::signal::windows::CtrlLogoff),
    /// Listener obtained from `tokio::signal::windows::ctrl_shutdown`.
    CtrlShutdown(tokio::signal::windows::CtrlShutdown),
    /// Test-only stand-in: the kind this listener waits for, plus the
    /// broadcast stream on which mocked signals arrive.
    #[cfg(test)]
    Mock(SignalKind, Pin<Box<tokio_stream::wrappers::BroadcastStream<SignalKind>>>),
}
195
#[cfg(windows)]
impl WindowsSignalValue {
    /// Polls the wrapped listener for the next event.
    ///
    /// Real listeners forward to their own `poll_recv`. The test-only mock is
    /// ready once the broadcast stream yields a kind equal to the one it was
    /// created for; for any other kind, or when the stream has nothing yet, it
    /// wakes the task again and reports `Poll::Pending`.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        #[cfg(test)]
        use futures::Stream;

        match self {
            Self::CtrlBreak(listener) => listener.poll_recv(cx),
            Self::CtrlC(listener) => listener.poll_recv(cx),
            Self::CtrlClose(listener) => listener.poll_recv(cx),
            Self::CtrlLogoff(listener) => listener.poll_recv(cx),
            Self::CtrlShutdown(listener) => listener.poll_recv(cx),
            #[cfg(test)]
            Self::Mock(expected, stream) => {
                let polled = stream.as_mut().poll_next(cx);

                match polled {
                    Poll::Ready(Some(Ok(received))) if received == *expected => Poll::Ready(Some(())),
                    // Either a signal meant for another listener or nothing
                    // yet: ask to be polled again right away and keep waiting.
                    Poll::Ready(Some(Ok(_))) | Poll::Pending => {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                    // What is left is a closed or erroring stream.
                    Poll::Ready(v) => unreachable!("receiver should always have a value: {:?}", v),
                }
            }
        }
    }
}
224
225#[cfg(unix)]
226type Signal = unix::Signal;
227
228#[cfg(windows)]
229type Signal = WindowsSignalValue;
230
231impl SignalKind {
232 #[cfg(unix)]
233 fn listen(&self) -> Result<Signal, std::io::Error> {
234 match self {
235 Self::Interrupt => tokio::signal::unix::signal(UnixSignalKind::interrupt()),
236 Self::Terminate => tokio::signal::unix::signal(UnixSignalKind::terminate()),
237 Self::Unix(kind) => tokio::signal::unix::signal(*kind),
238 }
239 }
240
241 #[cfg(windows)]
242 fn listen(&self) -> Result<Signal, std::io::Error> {
243 #[cfg(test)]
244 if cfg!(test) {
245 return Ok(WindowsSignalValue::Mock(
246 *self,
247 Box::pin(tokio_stream::wrappers::BroadcastStream::new(test::SignalMocker::subscribe())),
248 ));
249 }
250
251 match self {
252 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC) => {
254 Ok(WindowsSignalValue::CtrlC(tokio::signal::windows::ctrl_c()?))
255 }
256 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose) => {
258 Ok(WindowsSignalValue::CtrlClose(tokio::signal::windows::ctrl_close()?))
259 }
260 Self::Windows(WindowsSignalKind::CtrlBreak) => {
261 Ok(WindowsSignalValue::CtrlBreak(tokio::signal::windows::ctrl_break()?))
262 }
263 Self::Windows(WindowsSignalKind::CtrlLogoff) => {
264 Ok(WindowsSignalValue::CtrlLogoff(tokio::signal::windows::ctrl_logoff()?))
265 }
266 Self::Windows(WindowsSignalKind::CtrlShutdown) => {
267 Ok(WindowsSignalValue::CtrlShutdown(tokio::signal::windows::ctrl_shutdown()?))
268 }
269 }
270 }
271}
272
273#[derive(Debug)]
321#[must_use = "signal handlers must be used to wait for signals"]
322pub struct SignalHandler {
323 signals: Vec<(SignalKind, Signal)>,
324}
325
326impl Default for SignalHandler {
327 fn default() -> Self {
328 Self::new()
329 }
330}
331
332impl SignalHandler {
333 pub const fn new() -> Self {
335 Self { signals: Vec::new() }
336 }
337
338 pub fn with_signals<T: Into<SignalKind>>(signals: impl IntoIterator<Item = T>) -> Self {
340 let mut handler = Self::new();
341
342 for signal in signals {
343 handler = handler.with_signal(signal.into());
344 }
345
346 handler
347 }
348
349 pub fn with_signal(mut self, kind: impl Into<SignalKind>) -> Self {
353 self.add_signal(kind);
354 self
355 }
356
357 pub fn add_signal(&mut self, kind: impl Into<SignalKind>) -> &mut Self {
361 let kind = kind.into();
362 if self.signals.iter().any(|(k, _)| k == &kind) {
363 return self;
364 }
365
366 let signal = kind.listen().expect("failed to create signal");
367
368 self.signals.push((kind, signal));
369
370 self
371 }
372
373 pub async fn recv(&mut self) -> SignalKind {
377 self.await
378 }
379
380 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
383 for (kind, signal) in self.signals.iter_mut() {
384 if signal.poll_recv(cx).is_ready() {
385 return Poll::Ready(*kind);
386 }
387 }
388
389 Poll::Pending
390 }
391}
392
393impl std::future::Future for SignalHandler {
394 type Output = SignalKind;
395
396 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397 self.poll_recv(cx)
398 }
399}
400
// Only compiled with the `docs` feature. NOTE(review): the module body is
// presumably filled in by the `scuffle_changelog::changelog` attribute - confirm.
#[cfg(feature = "docs")]
/// Release history of this crate.
#[scuffle_changelog::changelog]
pub mod changelog {}
405
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::time::Duration;

    use scuffle_future_ext::FutureExt;

    use crate::{SignalHandler, SignalKind};

    /// Test double for Windows console events: mocked signals are broadcast
    /// to every listener that subscribed on the current thread.
    #[cfg(windows)]
    pub(crate) struct SignalMocker(tokio::sync::broadcast::Sender<SignalKind>);

    #[cfg(windows)]
    impl SignalMocker {
        /// Creates the broadcast channel behind the mock (capacity 100).
        fn new() -> Self {
            let (sender, _) = tokio::sync::broadcast::channel(100);
            Self(sender)
        }

        /// Sends `kind` to all current subscribers; panics if there are none.
        fn raise(kind: SignalKind) {
            SIGNAL_MOCKER.with(|local| local.0.send(kind).unwrap());
        }

        /// Returns a receiver observing every signal raised from now on.
        pub(crate) fn subscribe() -> tokio::sync::broadcast::Receiver<SignalKind> {
            SIGNAL_MOCKER.with(|local| local.0.subscribe())
        }
    }

    #[cfg(windows)]
    thread_local! {
        // One mock per thread, so raised signals reach only handlers created
        // on the same thread.
        static SIGNAL_MOCKER: SignalMocker = SignalMocker::new();
    }

    /// Delivers `kind` to the handlers under test through the mock.
    #[cfg(windows)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        SignalMocker::raise(kind);
    }

    /// Raises the real Unix signal for `kind` via `libc::raise`.
    #[cfg(unix)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        let signo = match kind {
            SignalKind::Interrupt => libc::SIGINT,
            SignalKind::Terminate => libc::SIGTERM,
            SignalKind::Unix(kind) => kind.as_raw_value(),
        };

        // SAFETY: `raise` takes a plain integer and touches no memory owned by
        // this program. The tests in this module register a listener for the
        // signal before raising it.
        let status = unsafe { libc::raise(signo) };

        assert_eq!(status, 0, "libc::raise failed for signal {signo}");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn signal_handler() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::with_signals([WindowsSignalKind::CtrlC, WindowsSignalKind::CtrlBreak]);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        // Nothing else was raised, so the handler must stay pending.
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;
        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn add_signal() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::new();

        // The second `CtrlC` is a duplicate and must be ignored.
        handler
            .add_signal(WindowsSignalKind::CtrlC)
            .add_signal(WindowsSignalKind::CtrlBreak)
            .add_signal(WindowsSignalKind::CtrlC);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(all(not(valgrind), unix))]
    #[tokio::test]
    async fn signal_handler() {
        use crate::UnixSignalKind;

        // The trailing SIGUSR1 is a duplicate and must be ignored.
        let mut handler = SignalHandler::with_signals([UnixSignalKind::user_defined1()])
            .with_signal(UnixSignalKind::user_defined2())
            .with_signal(UnixSignalKind::user_defined1());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, SignalKind::Unix(UnixSignalKind::user_defined1()), "expected SIGUSR1");

        // Nothing else was raised, so the handler must stay pending.
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;

        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(all(not(valgrind), unix))]
    #[tokio::test]
    async fn add_signal() {
        use crate::UnixSignalKind;

        let mut handler = SignalHandler::new();

        // The second SIGUSR2 is a duplicate and must be ignored.
        handler
            .add_signal(UnixSignalKind::user_defined1())
            .add_signal(UnixSignalKind::user_defined2())
            .add_signal(UnixSignalKind::user_defined2());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined1(), "expected SIGUSR1");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn no_signals() {
        let mut handler = SignalHandler::default();

        // A handler without registered signals never resolves.
        assert!(handler.recv().with_timeout(Duration::from_millis(500)).await.is_err());
    }

    #[cfg(windows)]
    #[test]
    fn signal_kind_eq() {
        use crate::WindowsSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Windows(WindowsSignalKind::CtrlC));
        assert_eq!(SignalKind::Terminate, SignalKind::Windows(WindowsSignalKind::CtrlClose));
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlC), SignalKind::Interrupt);
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlClose), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Windows(WindowsSignalKind::CtrlBreak),
            SignalKind::Windows(WindowsSignalKind::CtrlBreak)
        );
    }

    #[cfg(unix)]
    #[test]
    fn signal_kind_eq() {
        use crate::UnixSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Unix(UnixSignalKind::interrupt()));
        assert_eq!(SignalKind::Terminate, SignalKind::Unix(UnixSignalKind::terminate()));
        assert_eq!(SignalKind::Unix(UnixSignalKind::interrupt()), SignalKind::Interrupt);
        assert_eq!(SignalKind::Unix(UnixSignalKind::terminate()), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Unix(UnixSignalKind::user_defined1()),
            SignalKind::Unix(UnixSignalKind::user_defined1())
        );
    }
}