From 90fc3575422847c9358980a75c2011b51205fdf6 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Sun, 25 Oct 2020 01:11:30 +0000 Subject: [PATCH] change on connect API --- actix-http/src/builder.rs | 48 ++++++++----- actix-http/src/extensions.rs | 29 +++++++- actix-http/src/h1/dispatcher.rs | 18 ++++- actix-http/src/h1/service.rs | 35 +++++++-- actix-http/src/h2/service.rs | 42 +++++++---- actix-http/src/helpers.rs | 1 + actix-http/src/lib.rs | 4 +- actix-http/src/service.rs | 52 ++++++++++---- examples/on_connect.rs | 43 ++++++----- src/server.rs | 123 +++++++++++++++++++++----------- 10 files changed, 278 insertions(+), 117 deletions(-) diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index 75efc68b2..807d1f7c1 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -14,10 +14,11 @@ use crate::helpers::{Data, DataFactory}; use crate::request::Request; use crate::response::Response; use crate::service::HttpService; +use crate::{ConnectCallback, Extensions}; -/// A http service builder +/// A HTTP service builder /// -/// This type can be used to construct an instance of `http service` through a +/// This type can be used to construct an instance of [`HttpService`] through a /// builder-like pattern. pub struct HttpServiceBuilder> { keep_alive: KeepAlive, @@ -27,7 +28,9 @@ pub struct HttpServiceBuilder> { local_addr: Option, expect: X, upgrade: Option, + // DEPRECATED: in favor of on_connect_ext on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, S)>, } @@ -49,6 +52,7 @@ where expect: ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -138,6 +142,7 @@ where expect: expect.into_factory(), upgrade: self.upgrade, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -167,14 +172,15 @@ where expect: self.expect, upgrade: Some(upgrade.into_factory()), on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } /// Set on-connect callback. /// - /// It get called once per connection and result of the call - /// get stored to the request's extensions. + /// Called once per connection. Return value of the call is stored in request extensions. + #[deprecated = "Prefer the `on_connect_ext` style callback."] pub fn on_connect(mut self, f: F) -> Self where F: Fn(&T) -> I + 'static, @@ -184,21 +190,20 @@ where self } - /// Similar to `on_connect`, but takes optional callback. - /// If `f` is None, does nothing. - pub fn on_connect_optional(self, f: Option) -> Self - where - F: Fn(&T) -> I + 'static, - I: Clone + 'static, + /// Sets the callback to be run on connection establishment. + /// + /// Has mutable access to a data container that will be merged into request extensions. + /// This enables transport layer data (like client certificates) to be accessed in middleware + /// and handlers. + pub fn on_connect_ext(mut self, f: F) -> Self + where + F: Fn(&T, &mut Extensions) + 'static, { - if let Some(f) = f { - self.on_connect(f) - } else { - self - } + self.on_connect_ext = Some(Rc::new(f)); + self } - /// Finish service configuration and create *http service* for HTTP/1 protocol. + /// Finish service configuration and create a HTTP Service for HTTP/1 protocol. pub fn h1(self, service: F) -> H1Service where B: MessageBody, @@ -214,13 +219,15 @@ where self.secure, self.local_addr, ); + H1Service::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) + .on_connect_ext(self.on_connect_ext) } - /// Finish service configuration and create *http service* for HTTP/2 protocol. + /// Finish service configuration and create a HTTP service for HTTP/2 protocol. pub fn h2(self, service: F) -> H2Service where B: MessageBody + 'static, @@ -237,7 +244,10 @@ where self.secure, self.local_addr, ); - H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) + + H2Service::with_config(cfg, service.into_factory()) + .on_connect(self.on_connect) + .on_connect_ext(self.on_connect_ext) } /// Finish service configuration and create `HttpService` instance. @@ -257,9 +267,11 @@ where self.secure, self.local_addr, ); + HttpService::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) + .on_connect_ext(self.on_connect_ext) } } diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs index 09f1b711f..bcf70b409 100644 --- a/actix-http/src/extensions.rs +++ b/actix-http/src/extensions.rs @@ -1,5 +1,5 @@ use std::any::{Any, TypeId}; -use std::fmt; +use std::{fmt, mem}; use fxhash::FxHashMap; @@ -66,6 +66,11 @@ impl Extensions { pub fn extend(&mut self, other: Extensions) { self.map.extend(other.map); } + + /// Sets (or overrides) items from `other` into this map. + pub(crate) fn drain_from(&mut self, other: &mut Self) { + self.map.extend(mem::take(&mut other.map)); + } } impl fmt::Debug for Extensions { @@ -213,4 +218,26 @@ mod tests { assert_eq!(extensions.get(), Some(&20u8)); assert_eq!(extensions.get_mut(), Some(&mut 20u8)); } + + fn test_drain_from() { + let mut ext = Extensions::new(); + ext.insert(2isize); + + let mut more_ext = Extensions::new(); + + more_ext.insert(5isize); + more_ext.insert(5usize); + + assert_eq!(ext.get::(), Some(&2isize)); + assert_eq!(ext.get::(), None); + assert_eq!(more_ext.get::(), Some(&5isize)); + assert_eq!(more_ext.get::(), Some(&5usize)); + + ext.drain_from(&mut more_ext); + + assert_eq!(ext.get::(), Some(&5isize)); + assert_eq!(ext.get::(), Some(&5usize)); + assert_eq!(more_ext.get::(), None); + assert_eq!(more_ext.get::(), None); + } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 7c4de9707..ace4144e3 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -12,7 +12,6 @@ use bytes::{Buf, BytesMut}; use log::{error, trace}; use pin_project::pin_project; -use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; @@ -21,6 +20,10 @@ use crate::helpers::DataFactory; use crate::httpmessage::HttpMessage; use crate::request::Request; use crate::response::Response; +use crate::{ + body::{Body, BodySize, MessageBody, ResponseBody}, + Extensions, +}; use super::codec::Codec; use super::payload::{Payload, PayloadSender, PayloadStatus}; @@ -88,6 +91,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + on_connect_data: Extensions, flags: Flags, peer_addr: Option, error: Option, @@ -167,7 +171,7 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - /// Create http/1 dispatcher. + /// Create HTTP/1 dispatcher. pub(crate) fn new( stream: T, config: ServiceConfig, @@ -175,6 +179,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + on_connect_data: Extensions, peer_addr: Option, ) -> Self { Dispatcher::with_timeout( @@ -187,6 +192,7 @@ where expect, upgrade, on_connect, + on_connect_data, peer_addr, ) } @@ -202,6 +208,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + on_connect_data: Extensions, peer_addr: Option, ) -> Self { let keepalive = config.keep_alive_enabled(); @@ -234,6 +241,7 @@ where expect, upgrade, on_connect, + on_connect_data, flags, peer_addr, ka_expire, @@ -526,11 +534,15 @@ where let pl = this.codec.message_type(); req.head_mut().peer_addr = *this.peer_addr; + // DEPRECATED // set on_connect data if let Some(ref on_connect) = this.on_connect { on_connect.set(&mut req.extensions_mut()); } + // merge on_connect_ext data into request extensions + req.extensions_mut().drain_from(this.on_connect_data); + if pl == MessageType::Stream && this.upgrade.is_some() { this.messages.push_back(DispatcherMessage::Upgrade(req)); break; @@ -927,8 +939,10 @@ mod tests { CloneableService::new(ExpectHandler), None, None, + Extensions::new(), None, ); + match Pin::new(&mut h1).poll(cx) { Poll::Pending => panic!(), Poll::Ready(res) => assert!(res.is_err()), diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 6aafd4089..5008791c0 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -18,6 +18,7 @@ use crate::error::{DispatchError, Error, ParseError}; use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; +use crate::{ConnectCallback, Extensions}; use super::codec::Codec; use super::dispatcher::Dispatcher; @@ -30,6 +31,7 @@ pub struct H1Service> { expect: X, upgrade: Option, on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -52,6 +54,7 @@ where expect: ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -213,6 +216,7 @@ where srv: self.srv, upgrade: self.upgrade, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -229,6 +233,7 @@ where srv: self.srv, expect: self.expect, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -241,6 +246,12 @@ where self.on_connect = f; self } + + /// Set on connect callback. + pub(crate) fn on_connect_ext(mut self, f: Option>>) -> Self { + self.on_connect_ext = f; + self + } } impl ServiceFactory for H1Service @@ -274,6 +285,7 @@ where expect: None, upgrade: None, on_connect: self.on_connect.clone(), + on_connect_ext: self.on_connect_ext.clone(), cfg: Some(self.cfg.clone()), _t: PhantomData, } @@ -303,6 +315,7 @@ where expect: Option, upgrade: Option, on_connect: Option Box>>, + on_connect_ext: Option>>, cfg: Option, _t: PhantomData<(T, B)>, } @@ -352,23 +365,26 @@ where Poll::Ready(result.map(|service| { let this = self.as_mut().project(); + H1ServiceHandler::new( this.cfg.take().unwrap(), service, this.expect.take().unwrap(), this.upgrade.take(), this.on_connect.clone(), + this.on_connect_ext.clone(), ) })) } } -/// `Service` implementation for HTTP1 transport +/// `Service` implementation for HTTP/1 transport pub struct H1ServiceHandler { srv: CloneableService, expect: CloneableService, upgrade: Option>, on_connect: Option Box>>, + on_connect_ext: Option>>, cfg: ServiceConfig, _t: PhantomData<(T, B)>, } @@ -390,6 +406,7 @@ where expect: X, upgrade: Option, on_connect: Option Box>>, + on_connect_ext: Option>>, ) -> H1ServiceHandler { H1ServiceHandler { srv: CloneableService::new(srv), @@ -397,6 +414,7 @@ where upgrade: upgrade.map(CloneableService::new), cfg, on_connect, + on_connect_ext, _t: PhantomData, } } @@ -462,11 +480,13 @@ where } fn call(&mut self, (io, addr): Self::Request) -> Self::Future { - let on_connect = if let Some(ref on_connect) = self.on_connect { - Some(on_connect(&io)) - } else { - None - }; + let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io)); + + let mut connect_extensions = Extensions::new(); + if let Some(ref handler) = self.on_connect_ext { + // run on_connect_ext callback, populating connect extensions + handler(&io, &mut connect_extensions); + } Dispatcher::new( io, @@ -474,7 +494,8 @@ where self.srv.clone(), self.expect.clone(), self.upgrade.clone(), - on_connect, + deprecated_on_connect, + connect_extensions, addr, ) } diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index 6b5620e02..428f6c4a4 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{net, rc}; +use std::{net, rc::Rc}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::TcpStream; @@ -23,6 +23,7 @@ use crate::error::{DispatchError, Error}; use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; +use crate::{ConnectCallback, Extensions}; use super::dispatcher::Dispatcher; @@ -30,7 +31,8 @@ use super::dispatcher::Dispatcher; pub struct H2Service { srv: S, cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -50,19 +52,27 @@ where H2Service { cfg, on_connect: None, + on_connect_ext: None, srv: service.into_factory(), _t: PhantomData, } } /// Set on connect callback. + pub(crate) fn on_connect( mut self, - f: Option Box>>, + f: Option Box>>, ) -> Self { self.on_connect = f; self } + + /// Set on connect callback. + pub(crate) fn on_connect_ext(mut self, f: Option>>) -> Self { + self.on_connect_ext = f; + self + } } impl H2Service @@ -203,6 +213,7 @@ where fut: self.srv.new_service(()), cfg: Some(self.cfg.clone()), on_connect: self.on_connect.clone(), + on_connect_ext: self.on_connect_ext.clone(), _t: PhantomData, } } @@ -214,7 +225,8 @@ pub struct H2ServiceResponse { #[pin] fut: S::Future, cfg: Option, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -237,6 +249,7 @@ where H2ServiceHandler::new( this.cfg.take().unwrap(), this.on_connect.clone(), + this.on_connect_ext.clone(), service, ) })) @@ -247,7 +260,8 @@ where pub struct H2ServiceHandler { srv: CloneableService, cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -261,12 +275,14 @@ where { fn new( cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, srv: S, ) -> H2ServiceHandler { H2ServiceHandler { cfg, on_connect, + on_connect_ext, srv: CloneableService::new(srv), _t: PhantomData, } @@ -296,18 +312,20 @@ where } fn call(&mut self, (io, addr): Self::Request) -> Self::Future { - let on_connect = if let Some(ref on_connect) = self.on_connect { - Some(on_connect(&io)) - } else { - None - }; + let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io)); + + let mut connect_extensions = Extensions::new(); + if let Some(ref handler) = self.on_connect_ext { + // run on_connect_ext callback, populating connect extensions + handler(&io, &mut connect_extensions); + } H2ServiceHandlerResponse { state: State::Handshake( Some(self.srv.clone()), Some(self.cfg.clone()), addr, - on_connect, + deprecated_on_connect, server::handshake(io), ), } diff --git a/actix-http/src/helpers.rs b/actix-http/src/helpers.rs index bbf358b66..ac0e0f118 100644 --- a/actix-http/src/helpers.rs +++ b/actix-http/src/helpers.rs @@ -50,6 +50,7 @@ impl<'a> io::Write for Writer<'a> { self.0.extend_from_slice(buf); Ok(buf.len()) } + fn flush(&mut self) -> io::Result<()> { Ok(()) } diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index fab91be2b..e57a3727e 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -1,4 +1,4 @@ -//! Basic http primitives for actix-net framework. +//! Basic HTTP primitives for the Actix ecosystem. #![deny(rust_2018_idioms)] #![allow( @@ -78,3 +78,5 @@ pub enum Protocol { Http1, Http2, } + +type ConnectCallback = dyn Fn(&IO, &mut Extensions); diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index 9ee579702..cf7b67dc3 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{fmt, net, rc}; +use std::{fmt, net, rc::Rc}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::net::TcpStream; @@ -20,15 +20,17 @@ use crate::error::{DispatchError, Error}; use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; -use crate::{h1, h2::Dispatcher, Protocol}; +use crate::{h1, h2::Dispatcher, ConnectCallback, Protocol}; -/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation +/// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol. pub struct HttpService> { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, - on_connect: Option Box>>, + // DEPRECATED: in favor of on_connect_ext + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -66,6 +68,7 @@ where expect: h1::ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -81,6 +84,7 @@ where expect: h1::ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -113,6 +117,7 @@ where srv: self.srv, upgrade: self.upgrade, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -138,6 +143,7 @@ where srv: self.srv, expect: self.expect, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -145,11 +151,17 @@ where /// Set on connect callback. pub(crate) fn on_connect( mut self, - f: Option Box>>, + f: Option Box>>, ) -> Self { self.on_connect = f; self } + + /// Set connect callback with mutable access to request data container. + pub(crate) fn on_connect_ext(mut self, f: Option>>) -> Self { + self.on_connect_ext = f; + self + } } impl HttpService @@ -355,6 +367,7 @@ where expect: None, upgrade: None, on_connect: self.on_connect.clone(), + on_connect_ext: self.on_connect_ext.clone(), cfg: self.cfg.clone(), _t: PhantomData, } @@ -378,7 +391,8 @@ pub struct HttpServiceResponse< fut_upg: Option, expect: Option, upgrade: Option, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, cfg: ServiceConfig, _t: PhantomData<(T, B)>, } @@ -429,6 +443,7 @@ where .fut .poll(cx) .map_err(|e| log::error!("Init http service error: {:?}", e))); + Poll::Ready(result.map(|service| { let this = self.as_mut().project(); HttpServiceHandler::new( @@ -437,6 +452,7 @@ where this.expect.take().unwrap(), this.upgrade.take(), this.on_connect.clone(), + this.on_connect_ext.clone(), ) })) } @@ -448,7 +464,8 @@ pub struct HttpServiceHandler { expect: CloneableService, upgrade: Option>, cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B, X)>, } @@ -469,11 +486,13 @@ where srv: S, expect: X, upgrade: Option, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, ) -> HttpServiceHandler { HttpServiceHandler { cfg, on_connect, + on_connect_ext, srv: CloneableService::new(srv), expect: CloneableService::new(expect), upgrade: upgrade.map(CloneableService::new), @@ -543,11 +562,12 @@ where } fn call(&mut self, (io, proto, peer_addr): Self::Request) -> Self::Future { - let on_connect = if let Some(ref on_connect) = self.on_connect { - Some(on_connect(&io)) - } else { - None - }; + let mut connect_extensions = crate::Extensions::new(); + + let legacy_on_connect = self.on_connect.as_ref().map(|handler| handler(&io)); + self.on_connect_ext + .as_ref() + .map(|handler| handler(&io, &mut connect_extensions)); match proto { Protocol::Http2 => HttpServiceHandlerResponse { @@ -555,10 +575,11 @@ where server::handshake(io), self.cfg.clone(), self.srv.clone(), - on_connect, + legacy_on_connect, peer_addr, ))), }, + Protocol::Http1 => HttpServiceHandlerResponse { state: State::H1(h1::Dispatcher::new( io, @@ -566,7 +587,8 @@ where self.srv.clone(), self.expect.clone(), self.upgrade.clone(), - on_connect, + legacy_on_connect, + connect_extensions, peer_addr, )), }, diff --git a/examples/on_connect.rs b/examples/on_connect.rs index 0772515da..20c2585de 100644 --- a/examples/on_connect.rs +++ b/examples/on_connect.rs @@ -1,31 +1,38 @@ //! This example shows how to use `actix_web::HttpServer::on_connect` -#[derive(Clone)] +use std::any::Any; + +use actix_rt::net; +use actix_web::{dev::Extensions, web, App, HttpRequest, HttpServer}; + +#[derive(Debug, Clone)] struct ConnectionInfo(String); -async fn route_whoami(req: actix_web::HttpRequest) -> String { - let extensions = req.extensions(); - let conn_info = extensions.get::().unwrap(); - format!("Here is some info about you: {}", conn_info.0) +async fn route_whoami(conn_info: web::ReqData) -> String { + format!("Here is some info about you:\n{}", &conn_info.0) } -fn on_connect(connection: &dyn std::any::Any) -> ConnectionInfo { - let sock = connection.downcast_ref::().unwrap(); - let msg = format!("local_addr={:?}\npeer_addr={:?}", sock.local_addr(),sock.peer_addr()); - ConnectionInfo(msg) +fn on_connect(connection: &dyn Any, data: &mut Extensions) { + let sock = connection.downcast_ref::().unwrap(); + + let msg = format!( + "local_addr={:?}; peer_addr={:?}", + sock.local_addr(), + sock.peer_addr() + ); + + data.insert(ConnectionInfo(msg)); } -#[actix_rt::main] +#[actix_web::main] async fn main() -> std::io::Result<()> { std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); env_logger::init(); - actix_web::HttpServer::new(|| { - actix_web::App::new().route("/", actix_web::web::get().to(route_whoami)) - }) - .on_connect(std::sync::Arc::new(on_connect)) - .bind("127.0.0.1:8080")? - .workers(1) - .run() - .await + HttpServer::new(|| App::new().route("/", web::get().to(route_whoami))) + .on_connect(on_connect) + .bind(("127.0.0.1", 8080))? + .workers(1) + .run() + .await } diff --git a/src/server.rs b/src/server.rs index 3f659803a..78e016908 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,14 @@ -use std::marker::PhantomData; -use std::sync::{Arc, Mutex}; -use std::{fmt, io, net}; +use std::{ + any::Any, + fmt, io, + marker::PhantomData, + net, + sync::{Arc, Mutex}, +}; -use actix_http::{body::MessageBody, Error, HttpService, KeepAlive, Request, Response}; +use actix_http::{ + body::MessageBody, Error, Extensions, HttpService, KeepAlive, Request, Response, +}; use actix_server::{Server, ServerBuilder}; use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory}; @@ -49,7 +55,7 @@ struct Config { /// .await /// } /// ``` -pub struct HttpServer +pub struct HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory, @@ -64,10 +70,11 @@ where backlog: i32, sockets: Vec, builder: ServerBuilder, - on_connect_fn: Option C + Send + Sync>>, - _t: PhantomData<(S, B, C)>, + on_connect_fn: Option>, + _t: PhantomData<(S, B)>, } -impl HttpServer + +impl HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory, @@ -103,12 +110,9 @@ where /// - `actix_tls::rustls::TlsStream` when using rustls. /// - `tokio::net::TcpStream` when no encryption is used. /// See `on_connect` example for additional details. - pub fn on_connect( - self, - f: Arc C + Send + Sync>, - ) -> HttpServer + pub fn on_connect(self, f: CB) -> HttpServer where - C: Clone + 'static, + CB: Fn(&dyn Any, &mut Extensions) + Send + Sync + 'static, { HttpServer { factory: self.factory, @@ -116,13 +120,13 @@ where backlog: self.backlog, sockets: self.sockets, builder: self.builder, - on_connect_fn: Some(f), + on_connect_fn: Some(Arc::new(f)), _t: PhantomData, } } } -impl HttpServer +impl HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory, @@ -132,7 +136,6 @@ where S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, - C: Clone + 'static, { /// Set number of workers to start. /// @@ -292,14 +295,20 @@ where c.host.clone().unwrap_or_else(|| format!("{}", addr)), ); - HttpService::build() + let svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .local_addr(addr) - .on_connect_optional(on_connect_fn.clone().map(|handler| { - move |arg: &_| (&*handler)(arg as &dyn std::any::Any) - })) - .finish(map_config(factory(), move |_| cfg.clone())) + .local_addr(addr); + + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| cfg.clone())) .tcp() }, )?; @@ -344,14 +353,23 @@ where addr, c.host.clone().unwrap_or_else(|| format!("{}", addr)), ); - HttpService::build() + + let svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) - .on_connect_optional(on_connect_fn.clone().map(|handler| { - move |arg: &_| (&*handler)(arg as &dyn std::any::Any) - })) - .finish(map_config(factory(), move |_| cfg.clone())) + .client_disconnect(c.client_shutdown); + + let svc = if let Some(handler) = + on_connect_fn.map(|handler| Arc::clone(&handler)) + { + svc.on_connect_ext(move |io: &_, ext: _| { + (handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| cfg.clone())) .openssl(acceptor.clone()) }, )?; @@ -396,14 +414,23 @@ where addr, c.host.clone().unwrap_or_else(|| format!("{}", addr)), ); - HttpService::build() + + let svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) - .on_connect_optional(on_connect_fn.clone().map(|handler| { - move |arg: &_| (&*handler)(arg as &dyn std::any::Any) - })) - .finish(map_config(factory(), move |_| cfg.clone())) + .client_disconnect(c.client_shutdown); + + let svc = if let Some(handler) = + on_connect_fn.map(|handler| Rc::clone(&handler)) + { + svc.on_connect_ext(move |io: &_, ext: _| { + (handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| cfg.clone())) .rustls(config.clone()) }, )?; @@ -494,7 +521,7 @@ where } #[cfg(unix)] - /// Start listening for unix domain connections on existing listener. + /// Start listening for unix domain (UDS) connections on existing listener. pub fn listen_uds( mut self, lst: std::os::unix::net::UnixListener, @@ -514,6 +541,7 @@ where let addr = format!("actix-web-service-{:?}", lst.local_addr()?); let on_connect_fn = self.on_connect_fn.clone(); + self.builder = self.builder.listen_uds(addr, lst, move || { let c = cfg.lock().unwrap(); let config = AppConfig::new( @@ -521,14 +549,23 @@ where socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); + pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .on_connect_optional(on_connect_fn.clone().map(|handler| { - move |arg: &_| (&*handler)(arg as &dyn std::any::Any) - })) - .finish(map_config(factory(), move |_| config.clone())), + { + let svc = HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout); + + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (&*handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| config.clone())) + }, ) })?; Ok(self) @@ -576,7 +613,7 @@ where } } -impl HttpServer +impl HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory,