From da22edef3655fd1a6114509611e400b4ff4af37d Mon Sep 17 00:00:00 2001 From: Mikail Bagishov Date: Fri, 1 May 2020 21:11:02 +0300 Subject: [PATCH] Expose on_connect in actix-web --- actix-http/src/builder.rs | 14 ++++++++ src/server.rs | 67 +++++++++++++++++++++++++++++++++++---- 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index 271abd43f..75efc68b2 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -184,6 +184,20 @@ where self } + /// Similar to `on_connect`, but takes optional callback. + /// If `f` is None, does nothing. + pub fn on_connect_optional(self, f: Option) -> Self + where + F: Fn(&T) -> I + 'static, + I: Clone + 'static, + { + if let Some(f) = f { + self.on_connect(f) + } else { + self + } + } + /// Finish service configuration and create *http service* for HTTP/1 protocol. pub fn h1(self, service: F) -> H1Service where diff --git a/src/server.rs b/src/server.rs index 2b86f7416..ac850f5c5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -49,7 +49,7 @@ struct Config { /// .await /// } /// ``` -pub struct HttpServer +pub struct HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory, @@ -64,10 +64,10 @@ where backlog: i32, sockets: Vec, builder: ServerBuilder, - _t: PhantomData<(S, B)>, + on_connect_fn: Option C + Send + Sync>>, + _t: PhantomData<(S, B, C)>, } - -impl HttpServer +impl HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory, @@ -91,10 +91,48 @@ where backlog: 1024, sockets: Vec::new(), builder: ServerBuilder::default(), + on_connect_fn: None, _t: PhantomData, } } + /// Sets function that will be called once for each connection. + /// It will receive &Any, which contains underlying connection type. + /// For example: + /// - `actix_tls::openssl::SslStream` when using openssl. + /// - `actix_tls::rustls::TlsStream` when using rustls. + /// - `tokio::net::TcpStream` when no encryption is used. + pub fn on_connect( + self, + f: Arc C + Send + Sync>, + ) -> HttpServer + where + C: Clone + 'static, + { + HttpServer { + factory: self.factory, + config: self.config, + backlog: self.backlog, + sockets: self.sockets, + builder: self.builder, + on_connect_fn: Some(f), + _t: PhantomData, + } + } +} + +impl HttpServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + C: Clone + 'static, +{ /// Set number of workers to start. /// /// By default http server uses number of available logical cpu as threads @@ -240,6 +278,7 @@ where addr, scheme: "http", }); + let on_connect_fn = self.on_connect_fn.clone(); self.builder = self.builder.listen( format!("actix-web-service-{}", addr), @@ -256,6 +295,9 @@ where .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) .local_addr(addr) + .on_connect_optional(on_connect_fn.clone().map(|handler| { + move |arg: &_| (&*handler)(arg as &dyn std::any::Any) + })) .finish(map_config(factory(), move |_| cfg.clone())) .tcp() }, @@ -289,6 +331,8 @@ where scheme: "https", }); + let on_connect_fn = self.on_connect_fn.clone(); + self.builder = self.builder.listen( format!("actix-web-service-{}", addr), lst, @@ -303,6 +347,9 @@ where .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) .client_disconnect(c.client_shutdown) + .on_connect_optional(on_connect_fn.clone().map(|handler| { + move |arg: &_| (&*handler)(arg as &dyn std::any::Any) + })) .finish(map_config(factory(), move |_| cfg.clone())) .openssl(acceptor.clone()) }, @@ -336,6 +383,8 @@ where scheme: "https", }); + let on_connect_fn = self.on_connect_fn.clone(); + self.builder = self.builder.listen( format!("actix-web-service-{}", addr), lst, @@ -350,6 +399,9 @@ where .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) .client_disconnect(c.client_shutdown) + .on_connect_optional(on_connect_fn.clone().map(|handler| { + move |arg: &_| (&*handler)(arg as &dyn std::any::Any) + })) .finish(map_config(factory(), move |_| cfg.clone())) .rustls(config.clone()) }, @@ -460,7 +512,7 @@ where }); let addr = format!("actix-web-service-{:?}", lst.local_addr()?); - + let on_connect_fn = self.on_connect_fn.clone(); self.builder = self.builder.listen_uds(addr, lst, move || { let c = cfg.lock().unwrap(); let config = AppConfig::new( @@ -472,6 +524,9 @@ where HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) + .on_connect_optional(on_connect_fn.clone().map(|handler| { + move |arg: &_| (&*handler)(arg as &dyn std::any::Any) + })) .finish(map_config(factory(), move |_| config.clone())), ) })?; @@ -520,7 +575,7 @@ where } } -impl HttpServer +impl HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory,