scuffle_context/
lib.rs

//! A crate that provides the ability to cancel futures using a Go-like context
//! approach, allowing for graceful shutdowns and cancellations.
3#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6//! ## Why do we need this?
7//!
//! It's often useful to wait for all futures to shut down, or to cancel them
//! when we no longer care about their results. This crate provides an interface
//! to cancel all futures associated with a context, or to wait for them to
//! finish before shutting down, allowing for graceful shutdowns and cancellations.
12//!
13//! ## Usage
14//!
15//! Here is an example of how to use the `Context` to cancel a spawned task.
16//!
17//! ```rust
18//! # use scuffle_context::{Context, ContextFutExt};
19//! # tokio_test::block_on(async {
20//! let (ctx, handler) = Context::new();
21//!
22//! tokio::spawn(async {
23//!     // Do some work
24//!     tokio::time::sleep(std::time::Duration::from_secs(10)).await;
25//! }.with_context(ctx));
26//!
27//! // Will stop the spawned task and cancel all associated futures.
28//! handler.cancel();
29//! # });
30//! ```
31//!
32//! ## License
33//!
34//! This project is licensed under the MIT or Apache-2.0 license.
35//! You can choose between one of them if you use this work.
36//!
37//! `SPDX-License-Identifier: MIT OR Apache-2.0`
38#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
39#![cfg_attr(docsrs, feature(doc_auto_cfg))]
40#![deny(missing_docs)]
41#![deny(unsafe_code)]
42#![deny(unreachable_pub)]
43#![deny(clippy::mod_module_files)]
44
45use std::sync::Arc;
46use std::sync::atomic::{AtomicBool, AtomicUsize};
47
48use tokio_util::sync::CancellationToken;
49
/// Extension traits for attaching a [`Context`] to other types, such as
/// `ContextFutExt` (whose `with_context` method is used in the crate-level
/// example). Everything in this module is re-exported from the crate root.
mod ext;

pub use ext::*;
54
55/// Create by calling [`ContextTrackerInner::child`].
56#[derive(Debug)]
57struct ContextTracker(Arc<ContextTrackerInner>);
58
59impl Drop for ContextTracker {
60    fn drop(&mut self) {
61        let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
62        // If this was the last active `ContextTracker` and the context has been
63        // stopped, then notify the waiters
64        if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
65            self.0.notify.notify_waiters();
66        }
67    }
68}
69
70#[derive(Debug)]
71struct ContextTrackerInner {
72    stopped: AtomicBool,
73    /// This count keeps track of the number of `ContextTrackers` that exist for
74    /// this `ContextTrackerInner`.
75    active_count: AtomicUsize,
76    notify: tokio::sync::Notify,
77}
78
79impl ContextTrackerInner {
80    fn new() -> Arc<Self> {
81        Arc::new(Self {
82            stopped: AtomicBool::new(false),
83            active_count: AtomicUsize::new(0),
84            notify: tokio::sync::Notify::new(),
85        })
86    }
87
88    /// Create a new `ContextTracker` from an `Arc<ContextTrackerInner>`.
89    fn child(self: &Arc<Self>) -> ContextTracker {
90        self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
91        ContextTracker(Arc::clone(self))
92    }
93
94    /// Mark this `ContextTrackerInner` as stopped.
95    fn stop(&self) {
96        self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
97    }
98
99    /// Wait for this `ContextTrackerInner` to be stopped and all associated
100    /// `ContextTracker`s to be dropped.
101    async fn wait(&self) {
102        let notify = self.notify.notified();
103
104        // If there are no active children, then the notify will never be called
105        if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
106            return;
107        }
108
109        notify.await;
110    }
111}
112
113/// A context for cancelling futures and waiting for shutdown.
114///
115/// A context can be created from a handler by calling [`Handler::context`] or
116/// from another context by calling [`Context::new_child`] so to have a
117/// hierarchy of contexts.
118///
119/// Contexts can then be attached to futures or streams in order to
120/// automatically cancel them when the context is done, when invoking
121/// [`Handler::cancel`].
122/// The [`Handler::shutdown`] method will block until all contexts have been
123/// dropped allowing for a graceful shutdown.
124#[derive(Debug)]
125pub struct Context {
126    token: CancellationToken,
127    tracker: ContextTracker,
128}
129
130impl Clone for Context {
131    fn clone(&self) -> Self {
132        Self {
133            token: self.token.clone(),
134            tracker: self.tracker.0.child(),
135        }
136    }
137}
138
139impl Context {
140    #[must_use]
141    /// Create a new context using the global handler.
142    /// Returns a child context and child handler of the global handler.
143    pub fn new() -> (Self, Handler) {
144        Handler::global().new_child()
145    }
146
147    #[must_use]
148    /// Create a new child context from this context.
149    /// Returns a new child context and child handler of this context.
150    ///
151    /// # Example
152    ///
153    /// ```rust
154    /// use scuffle_context::Context;
155    ///
156    /// let (parent, parent_handler) = Context::new();
157    /// let (child, child_handler) = parent.new_child();
158    /// ```
159    pub fn new_child(&self) -> (Self, Handler) {
160        let token = self.token.child_token();
161        let tracker = ContextTrackerInner::new();
162
163        (
164            Self {
165                tracker: tracker.child(),
166                token: token.clone(),
167            },
168            Handler {
169                token: Arc::new(TokenDropGuard(token)),
170                tracker,
171            },
172        )
173    }
174
175    #[must_use]
176    /// Returns the global context
177    pub fn global() -> Self {
178        Handler::global().context()
179    }
180
181    /// Wait for the context to be done (the handler to be shutdown).
182    pub async fn done(&self) {
183        self.token.cancelled().await;
184    }
185
186    /// The same as [`Context::done`] but takes ownership of the context.
187    pub async fn into_done(self) {
188        self.done().await;
189    }
190
191    /// Returns true if the context is done.
192    #[must_use]
193    pub fn is_done(&self) -> bool {
194        self.token.is_cancelled()
195    }
196}
197
198/// A wrapper type around [`CancellationToken`] that will cancel the token as
199/// soon as it is dropped.
200#[derive(Debug)]
201struct TokenDropGuard(CancellationToken);
202
203impl TokenDropGuard {
204    #[must_use]
205    fn child(&self) -> CancellationToken {
206        self.0.child_token()
207    }
208
209    fn cancel(&self) {
210        self.0.cancel();
211    }
212}
213
214impl Drop for TokenDropGuard {
215    fn drop(&mut self) {
216        self.cancel();
217    }
218}
219
220/// A handler is used to manage contexts and to cancel them.
221#[derive(Debug, Clone)]
222pub struct Handler {
223    token: Arc<TokenDropGuard>,
224    tracker: Arc<ContextTrackerInner>,
225}
226
227impl Default for Handler {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233impl Handler {
234    #[must_use]
235    /// Create a new handler.
236    pub fn new() -> Handler {
237        let token = CancellationToken::new();
238        let tracker = ContextTrackerInner::new();
239
240        Handler {
241            token: Arc::new(TokenDropGuard(token)),
242            tracker,
243        }
244    }
245
246    #[must_use]
247    /// Returns the global handler.
248    pub fn global() -> &'static Self {
249        static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
250
251        GLOBAL.get_or_init(Handler::new)
252    }
253
254    /// Shutdown the handler and wait for all contexts to be done.
255    pub async fn shutdown(&self) {
256        self.cancel();
257        self.done().await;
258    }
259
260    /// Waits for the handler to be done (waiting for all contexts to be done).
261    pub async fn done(&self) {
262        self.token.0.cancelled().await;
263        self.wait().await;
264    }
265
266    /// Waits for the handler to be done (waiting for all contexts to be done).
267    /// Returns once all contexts are done, even if the handler is not done and
268    /// contexts can be created after this call.
269    pub async fn wait(&self) {
270        self.tracker.wait().await;
271    }
272
273    #[must_use]
274    /// Create a new context from this handler.
275    pub fn context(&self) -> Context {
276        Context {
277            token: self.token.child(),
278            tracker: self.tracker.child(),
279        }
280    }
281
282    #[must_use]
283    /// Create a new child context from this handler
284    pub fn new_child(&self) -> (Context, Handler) {
285        self.context().new_child()
286    }
287
288    /// Cancel the handler.
289    pub fn cancel(&self) {
290        self.tracker.stop();
291        self.token.cancel();
292    }
293
294    /// Returns true if the handler is done.
295    pub fn is_done(&self) -> bool {
296        self.token.0.is_cancelled()
297    }
298}
299
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    // Unit tests for `Context` and `Handler`.
    //
    // NOTE(review): these tests share process-wide state.
    // - `global_handler` cancels the process-wide global handler.
    // - `new`, `cancel`, `cancel_child` and `shutdown` build on that handler via
    //   `Context::new()` and assert that their contexts start out not done.
    // The suite therefore presumably assumes each test runs in its own process
    // (e.g. under cargo-nextest). With the default multi-threaded `cargo test`
    // harness, the outcome can depend on test ordering. Confirm against the CI
    // setup.
    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    #[tokio::test]
    async fn new() {
        // Freshly created contexts and handlers start out not cancelled.
        let (ctx, handler) = Context::new();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let handler = Handler::default();
        assert!(!handler.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        // Cancelling a handler propagates to its context, to clones of that
        // context, and to child contexts/handlers created from it.
        let (ctx, handler) = Context::new();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    #[tokio::test]
    async fn cancel_child() {
        // Cancelling a child handler must not cancel its parent.
        let (ctx, handler) = Context::new();
        let (child_ctx, child_handler) = ctx.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());

        child_handler.cancel();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    #[tokio::test]
    async fn shutdown() {
        let (ctx, handler) = Context::new();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // `ctx` is still alive, so shutdown cannot finish and is expected to time
        // out. It still cancels the handler, which the asserts below check.
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_err()
        );
        assert!(handler.is_done());
        assert!(ctx.is_done());
        // The context is already cancelled, so this resolves at once and drops
        // `ctx`, releasing the last tracked context.
        assert!(
            ctx.into_done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );

        // With no tracked contexts left, every waiting method resolves promptly.
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .wait()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(handler.is_done());
    }

    #[tokio::test]
    async fn global_handler() {
        // Cancels the process-wide global handler; anything derived from it
        // afterwards is born cancelled.
        let handler = Handler::global();

        assert!(!handler.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}
421
/// Release history for this crate, generated by [scuffle_changelog].
#[cfg(feature = "docs")]
#[scuffle_changelog::changelog]
pub mod changelog {}