From 320fa713c6bd2867d81c3e2cd7dff54a552af0a0 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Mon, 26 Oct 2020 10:21:31 +0000 Subject: [PATCH] add tests --- CHANGES.md | 2 + Cargo.toml | 6 +- actix-http/CHANGES.md | 7 +++ actix-http/src/service.rs | 11 +++- actix-http/tests/test_openssl.rs | 2 + actix-http/tests/test_server.rs | 2 + examples/on_connect.rs | 101 ++++++++----------------------- src/server.rs | 5 +- tests/test_on_connect.rs | 45 ++++++++++++++ 9 files changed, 96 insertions(+), 85 deletions(-) create mode 100644 tests/test_on_connect.rs diff --git a/CHANGES.md b/CHANGES.md index af34c3b49..7ce2296c6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,7 @@ * Implement `exclude_regex` for Logger middleware. [#1723] * Add request-local data extractor `web::ReqData`. [#1748] * Add `app_data` to `ServiceConfig`. [#1757] +* Expose `on_connect` for access to the connection stream before request is handled. [#1748] ### Changed * Print non-configured `Data` type when attempting extraction. [#1743] @@ -15,6 +16,7 @@ [#1743]: https://github.com/actix/actix-web/pull/1743 [#1748]: https://github.com/actix/actix-web/pull/1748 [#1750]: https://github.com/actix/actix-web/pull/1750 +[#1754]: https://github.com/actix/actix-web/pull/1754 ## 3.1.0 - 2020-09-29 diff --git a/Cargo.toml b/Cargo.toml index a294f0228..4fafc61c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,13 +66,12 @@ required-features = ["compress"] [[example]] name = "on_connect" -required-features = ["rustls"] +required-features = [] [[example]] name = "client" required-features = ["rustls"] - [dependencies] actix-codec = "0.3.0" actix-service = "1.0.6" @@ -114,12 +113,11 @@ tinyvec = { version = "1", features = ["alloc"] } actix = "0.10.0" actix-http = { version = "2.0.0", features = ["actors"] } rand = "0.7" -env_logger = "0.7" +env_logger = "0.8" serde_derive = "1.0" brotli2 = "0.3.2" flate2 = "1.0.13" criterion = "0.3" -webpki-roots = "0.20" [profile.release] lto = true diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 990c9c071..0afb63a6d 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,9 +1,16 @@ # Changes ## Unreleased - 2020-xx-xx +### Added +* Added more flexible `on_connect_ext` methods for on-connect handling. [#1754] + +### Changed * Upgrade `base64` to `0.13`. * Upgrade `pin-project` to `1.0`. +[#1754]: https://github.com/actix/actix-web/pull/1754 + + ## 2.0.0 - 2020-09-11 * No significant changes from `2.0.0-beta.4`. diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index 8393f1080..75745209c 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -694,9 +694,16 @@ where } else { panic!() }; - let (_, cfg, srv, on_connect, on_connect_data, peer_addr) = data.take().unwrap(); + let (_, cfg, srv, on_connect, on_connect_data, peer_addr) = + data.take().unwrap(); self.set(State::H2(Dispatcher::new( - srv, conn, on_connect, on_connect_data, cfg, None, peer_addr, + srv, + conn, + on_connect, + on_connect_data, + cfg, + None, + peer_addr, ))); self.poll(cx) } diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs index 795deacdc..05f01d240 100644 --- a/actix-http/tests/test_openssl.rs +++ b/actix-http/tests/test_openssl.rs @@ -411,8 +411,10 @@ async fn test_h2_on_connect() { let srv = test_server(move || { HttpService::build() .on_connect(|_| 10usize) + .on_connect_ext(|_, data| data.insert(20isize)) .h2(|req: Request| { assert!(req.extensions().contains::()); + assert!(req.extensions().contains::()); ok::<_, ()>(Response::Ok().finish()) }) .openssl(ssl_acceptor()) diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index 0375b6f66..de6368fda 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -663,8 +663,10 @@ async fn test_h1_on_connect() { let srv = test_server(|| { HttpService::build() .on_connect(|_| 10usize) + .on_connect_ext(|_, data| data.insert(20isize)) .h1(|req: Request| { assert!(req.extensions().contains::()); + assert!(req.extensions().contains::()); future::ok::<_, ()>(Response::Ok().finish()) }) .tcp() diff --git a/examples/on_connect.rs b/examples/on_connect.rs index 74dd0636d..bdad7e67e 100644 --- a/examples/on_connect.rs +++ b/examples/on_connect.rs @@ -4,100 +4,47 @@ //! For an example of extracting a client TLS certificate, see: //! -use std::{any::Any, env, fs::File, io::BufReader}; +use std::{any::Any, env, io, net::SocketAddr}; -use actix_tls::rustls::{ServerConfig, TlsStream}; -use actix_web::{ - dev::Extensions, rt::net::TcpStream, web, App, HttpResponse, HttpServer, Responder, -}; -use log::info; -use rust_tls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - AllowAnyAnonymousOrAuthenticatedClient, Certificate, RootCertStore, Session, -}; - -const CA_CERT: &str = "examples/certs/rootCA.pem"; -const SERVER_CERT: &str = "examples/certs/server-cert.pem"; -const SERVER_KEY: &str = "examples/certs/server-key.pem"; +use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer}; #[derive(Debug, Clone)] -struct ConnectionInfo(String); - -async fn route_whoami( - conn_info: web::ReqData, - client_cert: Option>, -) -> impl Responder { - if let Some(cert) = client_cert { - HttpResponse::Ok().body(format!("{:?}\n\n{:?}", &conn_info, &cert)) - } else { - HttpResponse::Unauthorized().body("No client certificate provided.") - } +struct ConnectionInfo { + bind: SocketAddr, + peer: SocketAddr, + ttl: Option, } -fn get_client_cert(connection: &dyn Any, data: &mut Extensions) { - if let Some(tls_socket) = connection.downcast_ref::>() { - info!("TLS on_connect"); +async fn route_whoami(conn_info: web::ReqData) -> String { + format!( + "Here is some info about your connection:\n\n{:#?}", + conn_info + ) +} - let (socket, tls_session) = tls_socket.get_ref(); - - let msg = format!( - "local_addr={:?}; peer_addr={:?}", - socket.local_addr(), - socket.peer_addr() - ); - - data.insert(ConnectionInfo(msg)); - - if let Some(mut certs) = tls_session.get_peer_certificates() { - info!("client certificate found"); - - // insert a `rustls::Certificate` into request data - data.insert(certs.pop().unwrap()); - } - } else if let Some(socket) = connection.downcast_ref::() { - info!("plaintext on_connect"); - - let msg = format!( - "local_addr={:?}; peer_addr={:?}", - socket.local_addr(), - socket.peer_addr() - ); - - data.insert(ConnectionInfo(msg)); +fn get_conn_info(connection: &dyn Any, data: &mut Extensions) { + if let Some(sock) = connection.downcast_ref::() { + data.insert(ConnectionInfo { + bind: sock.local_addr().unwrap(), + peer: sock.peer_addr().unwrap(), + ttl: sock.ttl().ok(), + }); } else { - unreachable!("socket should be TLS or plaintext"); + unreachable!("connection should only be plaintext since no TLS is set up"); } } #[actix_web::main] -async fn main() -> std::io::Result<()> { +async fn main() -> io::Result<()> { if env::var("RUST_LOG").is_err() { env::set_var("RUST_LOG", "info"); } env_logger::init(); - let ca_cert = &mut BufReader::new(File::open(CA_CERT)?); - - let mut cert_store = RootCertStore::empty(); - cert_store - .add_pem_file(ca_cert) - .expect("root CA not added to store"); - let client_auth = AllowAnyAnonymousOrAuthenticatedClient::new(cert_store); - - let mut config = ServerConfig::new(client_auth); - - let cert_file = &mut BufReader::new(File::open(SERVER_CERT)?); - let key_file = &mut BufReader::new(File::open(SERVER_KEY)?); - - let cert_chain = certs(cert_file).unwrap(); - let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - - HttpServer::new(|| App::new().route("/", web::get().to(route_whoami))) - .on_connect(get_client_cert) - .bind(("localhost", 8080))? - .bind_rustls(("localhost", 8443), config)? + HttpServer::new(|| App::new().default_service(web::to(route_whoami))) + .on_connect(get_conn_info) + .bind(("127.0.0.1", 8080))? .workers(1) .run() .await diff --git a/src/server.rs b/src/server.rs index 86aaf4265..70448fdce 100644 --- a/src/server.rs +++ b/src/server.rs @@ -103,8 +103,9 @@ where } } - /// Sets function that will be called once for each connection. - /// It will receive &Any, which contains underlying connection type. + /// Sets function that will be called once before each connection is handled. + /// It will receive a `&std::any::Any`, which contains underlying connection type and an + /// [Extensions] container so that request-local data can be passed to middleware and handlers. /// /// For example: /// - `actix_tls::openssl::SslStream` when using openssl. diff --git a/tests/test_on_connect.rs b/tests/test_on_connect.rs new file mode 100644 index 000000000..8f24c5ff9 --- /dev/null +++ b/tests/test_on_connect.rs @@ -0,0 +1,45 @@ +use std::{any::Any, env, io, net::SocketAddr}; + +use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer}; + +#[derive(Debug, Clone)] +struct ConnectionInfo { + bind: SocketAddr, + peer: SocketAddr, + ttl: Option, +} + +async fn route_whoami(conn_info: web::ReqData) -> String { + format!( + "Here is some info about your connection:\n\n{:#?}", + conn_info + ) +} + +fn get_conn_info(connection: &dyn Any, data: &mut Extensions) { + if let Some(sock) = connection.downcast_ref::() { + data.insert(ConnectionInfo { + bind: sock.local_addr().unwrap(), + peer: sock.peer_addr().unwrap(), + ttl: sock.ttl().ok(), + }); + } else { + unreachable!("connection should only be plaintext since no TLS is set up"); + } +} + +#[actix_web::main] +async fn main() -> io::Result<()> { + if env::var("RUST_LOG").is_err() { + env::set_var("RUST_LOG", "info"); + } + + env_logger::init(); + + HttpServer::new(|| App::new().default_service(web::to(route_whoami))) + .on_connect(get_conn_info) + .bind(("127.0.0.1", 8080))? + .workers(1) + .run() + .await +}