diff --git a/actix-rt/src/arbiter.rs b/actix-rt/src/arbiter.rs index 1b94df2e..172d4413 100644 --- a/actix-rt/src/arbiter.rs +++ b/actix-rt/src/arbiter.rs @@ -9,14 +9,18 @@ use std::{fmt, thread}; use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures_channel::oneshot::{channel, Canceled, Sender}; -use tokio::stream::Stream; -use tokio::task::LocalSet; +use futures_util::stream::FuturesUnordered; +use tokio::stream::{Stream, StreamExt}; +use tokio::task::{JoinHandle, LocalSet}; use crate::runtime::Runtime; use crate::system::System; thread_local!( static ADDR: RefCell> = RefCell::new(None); + /// stores join handle for spawned async tasks. + static HANDLE: RefCell>> = + RefCell::new(FuturesUnordered::new()); static STORAGE: RefCell>> = RefCell::new(HashMap::new()); ); @@ -146,7 +150,11 @@ impl Arbiter { where F: Future + 'static, { - tokio::task::spawn_local(future); + HANDLE.with(|handle| { + let handle = handle.borrow(); + handle.push(tokio::task::spawn_local(future)); + }); + let _ = tokio::task::spawn_local(CleanupPending); } /// Executes a future on the current thread. This does not create a new Arbiter @@ -266,9 +274,27 @@ impl Arbiter { /// Returns a future that will be completed once all currently spawned futures /// have completed. - #[deprecated(note = "local_join has been removed")] - pub async fn local_join() { - unimplemented!() + pub fn local_join() -> impl Future { + let handle = HANDLE.with(|fut| std::mem::take(&mut *fut.borrow_mut())); + async move { + handle.collect::>().await; + } + } +} + +/// Future used for cleaning-up already finished `JoinHandle`s +/// from the `PENDING` list so the vector doesn't grow indefinitely +struct CleanupPending; + +impl Future for CleanupPending { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + HANDLE.with(move |handle| { + let _ = Pin::new(&mut *handle.borrow_mut()).poll_next(cx); + }); + + Poll::Ready(()) } } @@ -299,7 +325,11 @@ impl Future for ArbiterController { Poll::Ready(Some(item)) => match item { ArbiterCommand::Stop => return Poll::Ready(()), ArbiterCommand::Execute(fut) => { - tokio::task::spawn_local(fut); + HANDLE.with(|handle| { + let mut handle = handle.borrow_mut(); + handle.push(tokio::task::spawn_local(fut)); + let _ = Pin::new(&mut *handle).poll_next(cx); + }); } ArbiterCommand::ExecuteFn(f) => { f.call_box(); diff --git a/actix-rt/tests/integration_tests.rs b/actix-rt/tests/integration_tests.rs index e3296e89..af5d0224 100644 --- a/actix-rt/tests/integration_tests.rs +++ b/actix-rt/tests/integration_tests.rs @@ -61,3 +61,40 @@ fn join_another_arbiter() { "Premature stop of arbiter should conclude regardless of it's current state" ); } + +#[test] +fn join_current_arbiter() { + let time = Duration::from_secs(2); + + let instant = Instant::now(); + actix_rt::System::new("test_join_current_arbiter").block_on(async move { + actix_rt::spawn(async move { + tokio::time::delay_for(time).await; + actix_rt::Arbiter::current().stop(); + }); + actix_rt::Arbiter::local_join().await; + }); + assert!( + instant.elapsed() >= time, + "Join on current arbiter should wait for all spawned futures" + ); + + let large_timer = Duration::from_secs(20); + let instant = Instant::now(); + actix_rt::System::new("test_join_current_arbiter").block_on(async move { + actix_rt::spawn(async move { + tokio::time::delay_for(time).await; + actix_rt::Arbiter::current().stop(); + }); + let f = actix_rt::Arbiter::local_join(); + actix_rt::spawn(async move { + tokio::time::delay_for(large_timer).await; + actix_rt::Arbiter::current().stop(); + }); + f.await; + }); + assert!( + instant.elapsed() < large_timer, + "local_join should await only for the already spawned futures" + ); +}