diff --git a/actix-server/src/builder.rs b/actix-server/src/builder.rs index 64a45df9..6449cd0f 100644 --- a/actix-server/src/builder.rs +++ b/actix-server/src/builder.rs @@ -8,9 +8,8 @@ use actix_rt::time::{delay_until, Instant}; use actix_rt::{spawn, System}; use futures_channel::mpsc::{unbounded, UnboundedReceiver}; use futures_channel::oneshot; -use futures_util::future::ready; use futures_util::stream::FuturesUnordered; -use futures_util::{future::Future, ready, stream::Stream, FutureExt, StreamExt}; +use futures_util::{future::Future, ready, stream::Stream, StreamExt}; use log::{error, info}; use socket2::{Domain, Protocol, Socket, Type}; @@ -37,6 +36,7 @@ pub struct ServerBuilder { no_signals: bool, cmd: UnboundedReceiver, server: Server, + on_stop: Pin>>, notify: Vec>, } @@ -66,6 +66,7 @@ impl ServerBuilder { cmd: rx, notify: Vec::new(), server, + on_stop: Box::pin(async {}), } } @@ -296,6 +297,16 @@ impl ServerBuilder { } } + pub fn on_stop(mut self, future: F) -> Self + where + F: Fn() -> Fut + 'static, + Fut: Future, + { + self.on_stop = Box::pin(async move { future().await }); + + self + } + fn start_worker(&self, idx: usize, notify: AcceptNotify) -> WorkerClient { let avail = WorkerAvailability::new(notify); let services: Vec> = @@ -358,54 +369,51 @@ impl ServerBuilder { self.accept.send(Command::Stop); let notify = std::mem::take(&mut self.notify); + let mut on_stop = Box::pin(async {}) as _; + std::mem::swap(&mut self.on_stop, &mut on_stop); + // stop workers if !self.workers.is_empty() && graceful { - spawn( - self.workers - .iter() - .map(move |worker| worker.1.stop(graceful)) - .collect::>() - .collect::>() - .then(move |_| { - if let Some(tx) = completion { - let _ = tx.send(()); - } - for tx in notify { - let _ = tx.send(()); - } - if exit { - spawn( - async { - delay_until( - Instant::now() + Duration::from_millis(300), - ) - .await; - System::current().stop(); - } - .boxed(), - ); - } - ready(()) - }), - ) - } else { - // we need to stop system if server was spawned - if self.exit { - spawn( - delay_until(Instant::now() + Duration::from_millis(300)).then( - |_| { - System::current().stop(); - ready(()) - }, - ), - ); - } + let stop_workers = self + .workers + .iter() + .map(move |worker| worker.1.stop(graceful)) + .collect::>() + .collect::>(); + + spawn(async move { + on_stop.await; + stop_workers.await; + if let Some(tx) = completion { + let _ = tx.send(()); + } + for tx in notify { + let _ = tx.send(()); + } + if exit { + spawn(async { + delay_until(Instant::now() + Duration::from_millis(300)).await; + System::current().stop(); + }); + } + }); + // we need to stop system if server was spawned + } else if self.exit { + spawn(async move { + on_stop.await; + delay_until(Instant::now() + Duration::from_millis(300)).await; + System::current().stop(); + }); if let Some(tx) = completion { let _ = tx.send(()); } for tx in notify { let _ = tx.send(()); } + } else { + spawn(async move { + on_stop.await; + }); } } ServerCommand::WorkerFaulted(idx) => { diff --git a/actix-server/tests/test_server.rs b/actix-server/tests/test_server.rs index ce309c94..6391e64a 100644 --- a/actix-server/tests/test_server.rs +++ b/actix-server/tests/test_server.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{mpsc, Arc}; use std::{net, thread, time}; @@ -160,7 +160,7 @@ fn test_configure() { rt.service("addr1", fn_service(|_| ok::<_, ()>(()))); rt.service("addr3", fn_service(|_| ok::<_, ()>(()))); rt.on_start(lazy(move |_| { - let _ = num.fetch_add(1, Relaxed); + let _ = num.fetch_add(1, Ordering::Relaxed); })) }) }) @@ -176,7 +176,60 @@ fn test_configure() { assert!(net::TcpStream::connect(addr1).is_ok()); assert!(net::TcpStream::connect(addr2).is_ok()); assert!(net::TcpStream::connect(addr3).is_ok()); - assert_eq!(num.load(Relaxed), 1); + assert_eq!(num.load(Ordering::Relaxed), 1); + sys.stop(); + let _ = h.join(); +} + +#[test] +#[cfg(unix)] +fn test_on_stop() { + use actix_codec::{BytesCodec, Framed}; + use actix_rt::net::TcpStream; + use bytes::Bytes; + use futures_util::sink::SinkExt; + + let bool = std::sync::Arc::new(AtomicBool::new(false)); + + let addr = unused_addr(); + let (tx, rx) = mpsc::channel(); + + let h = thread::spawn({ + let bool = bool.clone(); + move || { + let sys = actix_rt::System::new("test"); + let srv: Server = Server::build() + .backlog(100) + .disable_signals() + .on_stop(move || { + let bool = bool.clone(); + async move { + bool.store(true, Ordering::SeqCst); + } + }) + .bind("test", addr, move || { + fn_service(|io: TcpStream| async move { + let mut f = Framed::new(io, BytesCodec); + f.send(Bytes::from_static(b"test")).await.unwrap(); + Ok::<_, ()>(()) + }) + }) + .unwrap() + .start(); + + let _ = tx.send((srv, actix_rt::System::current())); + let _ = sys.run(); + } + }); + + let (srv, sys) = rx.recv().unwrap(); + + let _ = srv.stop(true); + + thread::sleep(time::Duration::from_millis(100)); + + assert!(bool.load(Ordering::SeqCst)); + sys.stop(); let _ = h.join(); }