1#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
39#![cfg_attr(docsrs, feature(doc_auto_cfg))]
40#![deny(missing_docs)]
41#![deny(unsafe_code)]
42#![deny(unreachable_pub)]
43#![deny(clippy::mod_module_files)]
44
45use std::sync::Arc;
46use std::sync::atomic::{AtomicBool, AtomicUsize};
47
48use tokio_util::sync::CancellationToken;
49
50mod ext;
52
53pub use ext::*;
54
55#[derive(Debug)]
57struct ContextTracker(Arc<ContextTrackerInner>);
58
59impl Drop for ContextTracker {
60 fn drop(&mut self) {
61 let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
62 if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
65 self.0.notify.notify_waiters();
66 }
67 }
68}
69
70#[derive(Debug)]
71struct ContextTrackerInner {
72 stopped: AtomicBool,
73 active_count: AtomicUsize,
76 notify: tokio::sync::Notify,
77}
78
79impl ContextTrackerInner {
80 fn new() -> Arc<Self> {
81 Arc::new(Self {
82 stopped: AtomicBool::new(false),
83 active_count: AtomicUsize::new(0),
84 notify: tokio::sync::Notify::new(),
85 })
86 }
87
88 fn child(self: &Arc<Self>) -> ContextTracker {
90 self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
91 ContextTracker(Arc::clone(self))
92 }
93
94 fn stop(&self) {
96 self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
97 }
98
99 async fn wait(&self) {
102 let notify = self.notify.notified();
103
104 if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
106 return;
107 }
108
109 notify.await;
110 }
111}
112
113#[derive(Debug)]
125pub struct Context {
126 token: CancellationToken,
127 tracker: ContextTracker,
128}
129
130impl Clone for Context {
131 fn clone(&self) -> Self {
132 Self {
133 token: self.token.clone(),
134 tracker: self.tracker.0.child(),
135 }
136 }
137}
138
139impl Context {
140 #[must_use]
141 pub fn new() -> (Self, Handler) {
144 Handler::global().new_child()
145 }
146
147 #[must_use]
148 pub fn new_child(&self) -> (Self, Handler) {
160 let token = self.token.child_token();
161 let tracker = ContextTrackerInner::new();
162
163 (
164 Self {
165 tracker: tracker.child(),
166 token: token.clone(),
167 },
168 Handler {
169 token: Arc::new(TokenDropGuard(token)),
170 tracker,
171 },
172 )
173 }
174
175 #[must_use]
176 pub fn global() -> Self {
178 Handler::global().context()
179 }
180
181 pub async fn done(&self) {
183 self.token.cancelled().await;
184 }
185
186 pub async fn into_done(self) {
188 self.done().await;
189 }
190
191 #[must_use]
193 pub fn is_done(&self) -> bool {
194 self.token.is_cancelled()
195 }
196}
197
198#[derive(Debug)]
201struct TokenDropGuard(CancellationToken);
202
203impl TokenDropGuard {
204 #[must_use]
205 fn child(&self) -> CancellationToken {
206 self.0.child_token()
207 }
208
209 fn cancel(&self) {
210 self.0.cancel();
211 }
212}
213
214impl Drop for TokenDropGuard {
215 fn drop(&mut self) {
216 self.cancel();
217 }
218}
219
220#[derive(Debug, Clone)]
222pub struct Handler {
223 token: Arc<TokenDropGuard>,
224 tracker: Arc<ContextTrackerInner>,
225}
226
227impl Default for Handler {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233impl Handler {
234 #[must_use]
235 pub fn new() -> Handler {
237 let token = CancellationToken::new();
238 let tracker = ContextTrackerInner::new();
239
240 Handler {
241 token: Arc::new(TokenDropGuard(token)),
242 tracker,
243 }
244 }
245
246 #[must_use]
247 pub fn global() -> &'static Self {
249 static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
250
251 GLOBAL.get_or_init(Handler::new)
252 }
253
254 pub async fn shutdown(&self) {
256 self.cancel();
257 self.done().await;
258 }
259
260 pub async fn done(&self) {
262 self.token.0.cancelled().await;
263 self.wait().await;
264 }
265
266 pub async fn wait(&self) {
270 self.tracker.wait().await;
271 }
272
273 #[must_use]
274 pub fn context(&self) -> Context {
276 Context {
277 token: self.token.child(),
278 tracker: self.tracker.child(),
279 }
280 }
281
282 #[must_use]
283 pub fn new_child(&self) -> (Context, Handler) {
285 self.context().new_child()
286 }
287
288 pub fn cancel(&self) {
290 self.tracker.stop();
291 self.token.cancel();
292 }
293
294 pub fn is_done(&self) -> bool {
296 self.token.0.is_cancelled()
297 }
298}
299
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    /// Builds a fresh root handler and a child context/handler pair below it.
    ///
    /// Tests in this module use this instead of `Context::new()`. That helper is
    /// rooted at the process-wide `Handler::global()`, which `global_handler`
    /// cancels. Because `cargo test` runs tests concurrently in one process, those
    /// tests could otherwise start from an already-cancelled context.
    ///
    /// The root handler is returned so the caller keeps it alive; dropping it would
    /// cancel the pair.
    fn isolated() -> (Handler, Context, Handler) {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        (root, ctx, handler)
    }

    #[tokio::test]
    async fn new() {
        let (_root, ctx, handler) = isolated();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let handler = Handler::default();
        assert!(!handler.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        let (_root, ctx, handler) = isolated();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    #[tokio::test]
    async fn cancel_child() {
        let (root, ctx, handler) = isolated();
        let (child_ctx, child_handler) = ctx.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());

        child_handler.cancel();

        assert!(!root.is_done());
        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    #[tokio::test]
    async fn shutdown() {
        let (_root, ctx, handler) = isolated();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // `ctx` is still alive and tracked, so shutdown cannot complete yet.
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_err()
        );
        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(
            ctx.into_done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );

        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .wait()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(handler.is_done());
    }

    #[tokio::test]
    async fn global_handler() {
        // This is the only test in this module that touches the global handler,
        // because it cancels it for the rest of the process. `Context::new()` is
        // rooted at the global handler, so it is exercised here, before that
        // cancellation.
        let (ctx, handler) = Context::new();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let global = Handler::global();

        assert!(!global.is_done());

        global.cancel();

        assert!(global.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());
        assert!(ctx.is_done());
        assert!(handler.is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}
421
#[cfg(feature = "docs")]
/// Full release history of this crate; this is the `changelog` link target in the crate-level docs.
///
/// The module body is empty here. Presumably the `#[scuffle_changelog::changelog]`
/// attribute macro fills it in; confirm against that crate.
#[scuffle_changelog::changelog]
pub mod changelog {}