scuffle_signal/
lib.rs

1//! A crate designed to provide a more user friendly interface to
2//! `tokio::signal`.
3#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6//! ## Why do we need this?
7//!
8//! The `tokio::signal` module provides a way for us to wait for a signal to be
9//! received in a non-blocking way. This crate extends that with a more helpful
10//! interface allowing the ability to listen to multiple signals concurrently.
11//!
12//! ## Example
13//!
14//! ```rust
15//! # #[cfg(unix)]
16//! # {
17//! use scuffle_signal::SignalHandler;
18//! use tokio::signal::unix::SignalKind;
19//!
20//! # tokio_test::block_on(async {
21//! let mut handler = SignalHandler::new()
22//!     .with_signal(SignalKind::interrupt())
23//!     .with_signal(SignalKind::terminate());
24//!
25//! # // Safety: This is a test, and we control the process.
26//! # unsafe {
27//! #    libc::raise(SignalKind::interrupt().as_raw_value());
28//! # }
29//! // Wait for a signal to be received
30//! let signal = handler.await;
31//!
32//! // Handle the signal
33//! let interrupt = SignalKind::interrupt();
34//! let terminate = SignalKind::terminate();
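//! match signal {
//!     // A bare identifier in a pattern binds a new variable instead of
//!     // comparing against it, so the comparison has to live in a guard.
//!     kind if kind == interrupt => {
//!         // Handle SIGINT
//!         println!("received SIGINT");
//!     },
//!     kind if kind == terminate => {
//!         // Handle SIGTERM
//!         println!("received SIGTERM");
//!     },
//!     _ => unreachable!("no other signals were registered"),
//! }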
45//! # });
46//! # }
47//! ```
48//!
49//! ## License
50//!
51//! This project is licensed under the MIT or Apache-2.0 license.
52//! You can choose between one of them if you use this work.
53//!
54//! `SPDX-License-Identifier: MIT OR Apache-2.0`
55#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
56#![cfg_attr(docsrs, feature(doc_auto_cfg))]
57#![deny(missing_docs)]
58#![deny(unreachable_pub)]
59#![deny(clippy::mod_module_files)]
60#![deny(clippy::undocumented_unsafe_blocks)]
61#![deny(clippy::multiple_unsafe_ops_per_block)]
62
63use std::pin::Pin;
64use std::task::{Context, Poll};
65
66#[cfg(unix)]
67use tokio::signal::unix;
68#[cfg(unix)]
69pub use tokio::signal::unix::SignalKind as UnixSignalKind;
70
71#[cfg(feature = "bootstrap")]
72mod bootstrap;
73
74#[cfg(feature = "bootstrap")]
75pub use bootstrap::{SignalConfig, SignalSvc};
76
77/// The type of signal to listen for.
78#[derive(Debug, Clone, Copy, Eq)]
79pub enum SignalKind {
80    /// Represents the interrupt signal, which is `SIGINT` on Unix and `Ctrl-C` on Windows.
81    Interrupt,
82    /// Represents the terminate signal, which is `SIGTERM` on Unix and `Ctrl-Close` on Windows.
83    Terminate,
84    /// Represents a Windows-specific signal kind, as defined in `WindowsSignalKind`.
85    #[cfg(windows)]
86    Windows(WindowsSignalKind),
87    /// Represents a Unix-specific signal kind, wrapping `tokio::signal::unix::SignalKind`.
88    #[cfg(unix)]
89    Unix(UnixSignalKind),
90}
91
92impl PartialEq for SignalKind {
93    fn eq(&self, other: &Self) -> bool {
94        #[cfg(unix)]
95        const INTERRUPT: UnixSignalKind = UnixSignalKind::interrupt();
96        #[cfg(unix)]
97        const TERMINATE: UnixSignalKind = UnixSignalKind::terminate();
98
99        match (self, other) {
100            #[cfg(windows)]
101            (
102                Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
103                Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
104            ) => true,
105            #[cfg(windows)]
106            (
107                Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
108                Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
109            ) => true,
110            #[cfg(windows)]
111            (Self::Windows(a), Self::Windows(b)) => a == b,
112            #[cfg(unix)]
113            (Self::Interrupt | Self::Unix(INTERRUPT), Self::Interrupt | Self::Unix(INTERRUPT)) => true,
114            #[cfg(unix)]
115            (Self::Terminate | Self::Unix(TERMINATE), Self::Terminate | Self::Unix(TERMINATE)) => true,
116            #[cfg(unix)]
117            (Self::Unix(a), Self::Unix(b)) => a == b,
118            _ => false,
119        }
120    }
121}
122
123#[cfg(unix)]
124impl From<UnixSignalKind> for SignalKind {
125    fn from(value: UnixSignalKind) -> Self {
126        match value {
127            kind if kind == UnixSignalKind::interrupt() => Self::Interrupt,
128            kind if kind == UnixSignalKind::terminate() => Self::Terminate,
129            kind => Self::Unix(kind),
130        }
131    }
132}
133
134#[cfg(unix)]
135impl PartialEq<UnixSignalKind> for SignalKind {
136    fn eq(&self, other: &UnixSignalKind) -> bool {
137        match self {
138            Self::Interrupt => other == &UnixSignalKind::interrupt(),
139            Self::Terminate => other == &UnixSignalKind::terminate(),
140            Self::Unix(kind) => kind == other,
141        }
142    }
143}
144
/// The console control events that can be listened for on Windows.
#[cfg(windows)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowsSignalKind {
    /// The `Ctrl-Break` signal.
    CtrlBreak,
    /// The `Ctrl-C` signal.
    CtrlC,
    /// The `Ctrl-Close` signal.
    CtrlClose,
    /// The `Ctrl-Logoff` signal.
    CtrlLogoff,
    /// The `Ctrl-Shutdown` signal.
    CtrlShutdown,
}
160
#[cfg(windows)]
impl From<WindowsSignalKind> for SignalKind {
    /// Converts a Windows signal kind, mapping `CtrlC` and `CtrlClose` to the
    /// portable `Interrupt` / `Terminate` variants and wrapping the rest in
    /// `Windows`.
    fn from(value: WindowsSignalKind) -> Self {
        match value {
            WindowsSignalKind::CtrlC => Self::Interrupt,
            WindowsSignalKind::CtrlClose => Self::Terminate,
            // Listed explicitly (no wildcard) so that a new variant forces a
            // decision here.
            passthrough @ (WindowsSignalKind::CtrlBreak
            | WindowsSignalKind::CtrlLogoff
            | WindowsSignalKind::CtrlShutdown) => Self::Windows(passthrough),
        }
    }
}
173
#[cfg(windows)]
impl PartialEq<WindowsSignalKind> for SignalKind {
    /// Compares against a Windows signal kind by first resolving `self` to the
    /// Windows kind it stands for.
    fn eq(&self, other: &WindowsSignalKind) -> bool {
        let native = match *self {
            Self::Interrupt => WindowsSignalKind::CtrlC,
            Self::Terminate => WindowsSignalKind::CtrlClose,
            Self::Windows(kind) => kind,
        };

        native == *other
    }
}
184
/// A live Windows signal listener.
///
/// tokio exposes one listener type per console control event, so this enum
/// unifies them behind a single type that can be stored in a collection.
#[cfg(windows)]
#[derive(Debug)]
enum WindowsSignalValue {
    /// Listener for `Ctrl-Break`.
    CtrlBreak(tokio::signal::windows::CtrlBreak),
    /// Listener for `Ctrl-C`.
    CtrlC(tokio::signal::windows::CtrlC),
    /// Listener for `Ctrl-Close`.
    CtrlClose(tokio::signal::windows::CtrlClose),
    /// Listener for `Ctrl-Logoff`.
    CtrlLogoff(tokio::signal::windows::CtrlLogoff),
    /// Listener for `Ctrl-Shutdown`.
    CtrlShutdown(tokio::signal::windows::CtrlShutdown),
    /// Test-only listener: the kind being waited for, plus a broadcast stream
    /// on which mocked signals arrive.
    #[cfg(test)]
    Mock(SignalKind, Pin<Box<tokio_stream::wrappers::BroadcastStream<SignalKind>>>),
}
196
#[cfg(windows)]
impl WindowsSignalValue {
    /// Polls the wrapped listener for the next signal.
    ///
    /// Real listeners simply forward to their own `poll_recv`. The test-only
    /// mock reports a signal only when the broadcast value matches the kind
    /// it was created for.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        #[cfg(test)]
        use futures::Stream;

        match self {
            Self::CtrlBreak(listener) => listener.poll_recv(cx),
            Self::CtrlC(listener) => listener.poll_recv(cx),
            Self::CtrlClose(listener) => listener.poll_recv(cx),
            Self::CtrlLogoff(listener) => listener.poll_recv(cx),
            Self::CtrlShutdown(listener) => listener.poll_recv(cx),
            #[cfg(test)]
            Self::Mock(wanted, stream) => match stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(got))) if got == *wanted => Poll::Ready(Some(())),
                // Either a different signal was broadcast or nothing is
                // available yet: ask to be polled again right away and keep
                // waiting.
                Poll::Ready(Some(Ok(_))) | Poll::Pending => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                // A closed stream or a receive error is not expected in tests.
                Poll::Ready(v) => unreachable!("receiver should always have a value: {:?}", v),
            },
        }
    }
}
225
226#[cfg(unix)]
227type Signal = unix::Signal;
228
229#[cfg(windows)]
230type Signal = WindowsSignalValue;
231
232impl SignalKind {
233    #[cfg(unix)]
234    fn listen(&self) -> Result<Signal, std::io::Error> {
235        match self {
236            Self::Interrupt => tokio::signal::unix::signal(UnixSignalKind::interrupt()),
237            Self::Terminate => tokio::signal::unix::signal(UnixSignalKind::terminate()),
238            Self::Unix(kind) => tokio::signal::unix::signal(*kind),
239        }
240    }
241
242    #[cfg(windows)]
243    fn listen(&self) -> Result<Signal, std::io::Error> {
244        #[cfg(test)]
245        if cfg!(test) {
246            return Ok(WindowsSignalValue::Mock(
247                *self,
248                Box::pin(tokio_stream::wrappers::BroadcastStream::new(tests::SignalMocker::subscribe())),
249            ));
250        }
251
252        match self {
253            // https://learn.microsoft.com/en-us/windows/console/ctrl-c-and-ctrl-break-signals
254            Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC) => {
255                Ok(WindowsSignalValue::CtrlC(tokio::signal::windows::ctrl_c()?))
256            }
257            // https://learn.microsoft.com/en-us/windows/console/ctrl-close-signal
258            Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose) => {
259                Ok(WindowsSignalValue::CtrlClose(tokio::signal::windows::ctrl_close()?))
260            }
261            Self::Windows(WindowsSignalKind::CtrlBreak) => {
262                Ok(WindowsSignalValue::CtrlBreak(tokio::signal::windows::ctrl_break()?))
263            }
264            Self::Windows(WindowsSignalKind::CtrlLogoff) => {
265                Ok(WindowsSignalValue::CtrlLogoff(tokio::signal::windows::ctrl_logoff()?))
266            }
267            Self::Windows(WindowsSignalKind::CtrlShutdown) => {
268                Ok(WindowsSignalValue::CtrlShutdown(tokio::signal::windows::ctrl_shutdown()?))
269            }
270        }
271    }
272}
273
274/// A handler for listening to multiple signals, and providing a future for
275/// receiving them.
276///
277/// This is useful for applications that need to listen for multiple signals,
278/// and want to react to them in a non-blocking way. Typically you would need to
279/// use a tokio::select{} to listen for multiple signals, but this provides a
280/// more ergonomic interface for doing so.
281///
282/// After a signal is received you can poll the handler again to wait for
283/// another signal. Dropping the handle will cancel the signal subscription
284///
285/// # Example
286///
287/// ```rust
288/// # #[cfg(unix)]
289/// # {
290/// use scuffle_signal::SignalHandler;
291/// use tokio::signal::unix::SignalKind;
292///
293/// # tokio_test::block_on(async {
294/// let mut handler = SignalHandler::new()
295///     .with_signal(SignalKind::interrupt())
296///     .with_signal(SignalKind::terminate());
297///
298/// # // Safety: This is a test, and we control the process.
299/// # unsafe {
300/// #    libc::raise(SignalKind::interrupt().as_raw_value());
301/// # }
302/// // Wait for a signal to be received
303/// let signal = handler.await;
304///
305/// // Handle the signal
306/// let interrupt = SignalKind::interrupt();
307/// let terminate = SignalKind::terminate();
308/// match signal {
309///     interrupt => {
310///         // Handle SIGINT
311///         println!("received SIGINT");
312///     },
313///     terminate => {
314///         // Handle SIGTERM
315///         println!("received SIGTERM");
316///     },
317/// }
318/// # });
319/// # }
320/// ```
321#[derive(Debug)]
322#[must_use = "signal handlers must be used to wait for signals"]
323pub struct SignalHandler {
324    signals: Vec<(SignalKind, Signal)>,
325}
326
327impl Default for SignalHandler {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
333impl SignalHandler {
334    /// Create a new `SignalHandler` with no signals.
335    pub const fn new() -> Self {
336        Self { signals: Vec::new() }
337    }
338
339    /// Create a new `SignalHandler` with the given signals.
340    pub fn with_signals<T: Into<SignalKind>>(signals: impl IntoIterator<Item = T>) -> Self {
341        let mut handler = Self::new();
342
343        for signal in signals {
344            handler = handler.with_signal(signal.into());
345        }
346
347        handler
348    }
349
350    /// Add a signal to the handler.
351    ///
352    /// If the signal is already in the handler, it will not be added again.
353    pub fn with_signal(mut self, kind: impl Into<SignalKind>) -> Self {
354        self.add_signal(kind);
355        self
356    }
357
358    /// Add a signal to the handler.
359    ///
360    /// If the signal is already in the handler, it will not be added again.
361    pub fn add_signal(&mut self, kind: impl Into<SignalKind>) -> &mut Self {
362        let kind = kind.into();
363        if self.signals.iter().any(|(k, _)| k == &kind) {
364            return self;
365        }
366
367        let signal = kind.listen().expect("failed to create signal");
368
369        self.signals.push((kind, signal));
370
371        self
372    }
373
374    /// Wait for a signal to be received.
375    /// This is equivilant to calling (&mut handler).await, but is more
376    /// ergonomic if you want to not take ownership of the handler.
377    pub async fn recv(&mut self) -> SignalKind {
378        self.await
379    }
380
381    /// Poll for a signal to be received.
382    /// Does not require pinning the handler.
383    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
384        for (kind, signal) in self.signals.iter_mut() {
385            if signal.poll_recv(cx).is_ready() {
386                return Poll::Ready(*kind);
387            }
388        }
389
390        Poll::Pending
391    }
392}
393
394impl std::future::Future for SignalHandler {
395    type Output = SignalKind;
396
397    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
398        self.poll_recv(cx)
399    }
400}
401
/// Changelogs generated by [scuffle_changelog]
///
/// Only compiled when the `docs` feature is enabled; the module body is left
/// empty and the attribute macro is applied to it.
#[cfg(feature = "docs")]
#[scuffle_changelog::changelog]
pub mod changelog {}
406
407#[cfg(test)]
408#[cfg_attr(coverage_nightly, coverage(off))]
409mod tests {
410    use std::time::Duration;
411
412    use scuffle_future_ext::FutureExt;
413
414    use crate::{SignalHandler, SignalKind};
415
416    #[cfg(windows)]
417    pub(crate) struct SignalMocker(tokio::sync::broadcast::Sender<SignalKind>);
418
419    #[cfg(windows)]
420    impl SignalMocker {
421        fn new() -> Self {
422            println!("new");
423            let (sender, _) = tokio::sync::broadcast::channel(100);
424            Self(sender)
425        }
426
427        fn raise(kind: SignalKind) {
428            println!("raising");
429            SIGNAL_MOCKER.with(|local| local.0.send(kind).unwrap());
430        }
431
432        pub(crate) fn subscribe() -> tokio::sync::broadcast::Receiver<SignalKind> {
433            println!("subscribing");
434            SIGNAL_MOCKER.with(|local| local.0.subscribe())
435        }
436    }
437
438    #[cfg(windows)]
439    thread_local! {
440        static SIGNAL_MOCKER: SignalMocker = SignalMocker::new();
441    }
442
443    #[cfg(windows)]
444    pub(crate) async fn raise_signal(kind: SignalKind) {
445        SignalMocker::raise(kind);
446    }
447
448    #[cfg(unix)]
449    pub(crate) async fn raise_signal(kind: SignalKind) {
450        // Safety: This is a test, and we control the process.
451        unsafe {
452            libc::raise(match kind {
453                SignalKind::Interrupt => libc::SIGINT,
454                SignalKind::Terminate => libc::SIGTERM,
455                SignalKind::Unix(kind) => kind.as_raw_value(),
456            });
457        }
458    }
459
460    #[cfg(windows)]
461    #[tokio::test]
462    async fn signal_handler() {
463        use crate::WindowsSignalKind;
464
465        let mut handler = SignalHandler::with_signals([WindowsSignalKind::CtrlC, WindowsSignalKind::CtrlBreak]);
466
467        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
468
469        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;
470
471        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();
472
473        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");
474
475        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;
476        assert!(recv.is_err(), "expected timeout");
477
478        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;
479
480        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();
481
482        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
483    }
484
485    #[cfg(windows)]
486    #[tokio::test]
487    async fn add_signal() {
488        use crate::WindowsSignalKind;
489
490        let mut handler = SignalHandler::new();
491
492        handler
493            .add_signal(WindowsSignalKind::CtrlC)
494            .add_signal(WindowsSignalKind::CtrlBreak)
495            .add_signal(WindowsSignalKind::CtrlC);
496
497        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
498
499        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;
500
501        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();
502
503        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");
504
505        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;
506
507        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();
508
509        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
510    }
511
512    #[cfg(all(not(valgrind), unix))] // test is time-sensitive
513    #[tokio::test]
514    async fn signal_handler() {
515        use crate::UnixSignalKind;
516
517        let mut handler = SignalHandler::with_signals([UnixSignalKind::user_defined1()])
518            .with_signal(UnixSignalKind::user_defined2())
519            .with_signal(UnixSignalKind::user_defined1());
520
521        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;
522
523        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();
524
525        assert_eq!(recv, SignalKind::Unix(UnixSignalKind::user_defined1()), "expected SIGUSR1");
526
527        // We already received the signal, so polling again should return Poll::Pending
528        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;
529
530        assert!(recv.is_err(), "expected timeout");
531
532        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;
533
534        // We should be able to receive the signal again
535        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();
536
537        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
538    }
539
540    #[cfg(all(not(valgrind), unix))] // test is time-sensitive
541    #[tokio::test]
542    async fn add_signal() {
543        use crate::UnixSignalKind;
544
545        let mut handler = SignalHandler::new();
546
547        handler
548            .add_signal(UnixSignalKind::user_defined1())
549            .add_signal(UnixSignalKind::user_defined2())
550            .add_signal(UnixSignalKind::user_defined2());
551
552        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;
553
554        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();
555
556        assert_eq!(recv, UnixSignalKind::user_defined1(), "expected SIGUSR1");
557
558        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;
559
560        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();
561
562        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
563    }
564
565    #[cfg(not(valgrind))] // test is time-sensitive
566    #[tokio::test]
567    async fn no_signals() {
568        let mut handler = SignalHandler::default();
569
570        // Expected to timeout
571        assert!(handler.recv().with_timeout(Duration::from_millis(500)).await.is_err());
572    }
573
574    #[cfg(windows)]
575    #[test]
576    fn signal_kind_eq() {
577        use crate::WindowsSignalKind;
578
579        assert_eq!(SignalKind::Interrupt, SignalKind::Windows(WindowsSignalKind::CtrlC));
580        assert_eq!(SignalKind::Terminate, SignalKind::Windows(WindowsSignalKind::CtrlClose));
581        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlC), SignalKind::Interrupt);
582        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlClose), SignalKind::Terminate);
583        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
584        assert_eq!(
585            SignalKind::Windows(WindowsSignalKind::CtrlBreak),
586            SignalKind::Windows(WindowsSignalKind::CtrlBreak)
587        );
588    }
589
590    #[cfg(unix)]
591    #[test]
592    fn signal_kind_eq() {
593        use crate::UnixSignalKind;
594
595        assert_eq!(SignalKind::Interrupt, SignalKind::Unix(UnixSignalKind::interrupt()));
596        assert_eq!(SignalKind::Terminate, SignalKind::Unix(UnixSignalKind::terminate()));
597        assert_eq!(SignalKind::Unix(UnixSignalKind::interrupt()), SignalKind::Interrupt);
598        assert_eq!(SignalKind::Unix(UnixSignalKind::terminate()), SignalKind::Terminate);
599        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
600        assert_eq!(
601            SignalKind::Unix(UnixSignalKind::user_defined1()),
602            SignalKind::Unix(UnixSignalKind::user_defined1())
603        );
604    }
605}