diff --git a/actix-server/src/builder.rs b/actix-server/src/builder.rs index a509a21b..fed9a92a 100644 --- a/actix-server/src/builder.rs +++ b/actix-server/src/builder.rs @@ -7,6 +7,7 @@ use std::{ }; use actix_rt::{self as rt, net::TcpStream, time::sleep, System}; +use futures_core::future::BoxFuture; use log::{error, info}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; use tokio::sync::oneshot; @@ -24,7 +25,6 @@ use crate::worker::{ WorkerHandleServer, }; use crate::{join_all, Token}; -use futures_core::future::LocalBoxFuture; /// Server builder pub struct ServerBuilder { @@ -39,7 +39,7 @@ pub struct ServerBuilder { no_signals: bool, cmd: UnboundedReceiver, server: Server, - on_stop: Box LocalBoxFuture<'static, ()>>, + on_stop: Box BoxFuture<'static, ()>>, notify: Vec>, worker_config: ServerWorkerConfig, } @@ -329,7 +329,7 @@ impl ServerBuilder { pub fn on_stop(mut self, func: F) -> Self where F: Fn() -> Fut + 'static, - Fut: Future + 'static, + Fut: Future + Send + 'static, { self.on_stop = Box::new(move || { let fut = func(); @@ -338,8 +338,12 @@ impl ServerBuilder { self } - fn start_worker(&self, idx: usize, waker: WakerQueue) -> WorkerHandle { - let avail = WorkerAvailability::new(waker); + fn start_worker( + &self, + idx: usize, + waker: WakerQueue, + ) -> (WorkerHandleAccept, WorkerHandleServer) { + let avail = WorkerAvailability::new(idx, waker); let services = self.services.iter().map(|v| v.clone_factory()).collect(); ServerWorker::start(idx, services, avail, self.worker_config) @@ -401,7 +405,7 @@ impl ServerBuilder { // take the on_stop future. let mut on_stop = - Box::new(|| Box::pin(async {}) as LocalBoxFuture<'static, ()>) as _; + Box::new(|| Box::pin(async {}) as BoxFuture<'static, ()>) as _; std::mem::swap(&mut self.on_stop, &mut on_stop); // stop workers @@ -412,8 +416,8 @@ impl ServerBuilder { .collect(); rt::spawn(async move { - on_stop().await; - + on_stop().await; + if graceful { let _ = join_all(stop).await; } diff --git a/actix-server/tests/test_server.rs b/actix-server/tests/test_server.rs index b48c49be..09dbf676 100644 --- a/actix-server/tests/test_server.rs +++ b/actix-server/tests/test_server.rs @@ -577,3 +577,45 @@ async fn worker_restart() { let _ = server.stop(false); let _ = h.join().unwrap(); } + +#[test] +fn on_stop() { + let bool = std::sync::Arc::new(AtomicBool::new(false)); + + let addr = unused_addr(); + let (tx, rx) = mpsc::channel(); + + let bool_clone = bool.clone(); + let h = thread::spawn(move || { + let sys = actix_rt::System::new(); + let lst = net::TcpListener::bind(addr).unwrap(); + sys.block_on(async { + let server = Server::build() + .disable_signals() + .on_stop(move || { + bool.store(true, Ordering::SeqCst); + async {} + }) + .listen("test", lst, move || { + fn_service(|_| async { Ok::<_, ()>(()) }) + }) + .unwrap() + .run(); + let _ = tx.send(server.clone()); + + server.await + }) + }); + let sys = rx.recv().unwrap(); + + thread::sleep(Duration::from_millis(500)); + + assert!(!bool_clone.load(Ordering::SeqCst)); + + assert!(net::TcpStream::connect(addr).is_ok()); + + let _ = sys.stop(true); + let _ = h.join(); + + assert!(bool_clone.load(Ordering::SeqCst)); +}