use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, io, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; use actix_rt::task::JoinHandle; use bytes::Bytes; use futures_core::future::LocalBoxFuture; use h2::client::SendRequest; use pin_project::pin_project; use crate::body::MessageBody; use crate::h1::ClientCodec; use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; use super::error::SendRequestError; use super::pool::Acquired; use super::{h1proto, h2proto}; pub(crate) enum ConnectionType { H1(Io), H2(H2Connection), } /// `H2Connection` has two parts: `SendRequest` and `Connection`. /// /// `Connection` is spawned as an async task on runtime and `H2Connection` holds a handle for /// this task. Therefore, it can wake up and quit the task when SendRequest is dropped. pub(crate) struct H2Connection { handle: JoinHandle<()>, sender: SendRequest, } impl H2Connection { pub(crate) fn new( sender: SendRequest, connection: h2::client::Connection, ) -> Self where Io: AsyncRead + AsyncWrite + Unpin + 'static, { let handle = actix_rt::spawn(async move { let _ = connection.await; }); Self { handle, sender } } } // cancel spawned connection task on drop. impl Drop for H2Connection { fn drop(&mut self) { self.handle.abort(); } } // only expose sender type to public. impl Deref for H2Connection { type Target = SendRequest; fn deref(&self) -> &Self::Target { &self.sender } } impl DerefMut for H2Connection { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.sender } } pub trait Connection { type Io: AsyncRead + AsyncWrite + Unpin; /// Send request and body fn send_request( self, head: H, body: B, ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> where B: MessageBody + 'static, H: Into + 'static; /// Send request, returns Response and Framed fn open_tunnel + 'static>( self, head: H, ) -> LocalBoxFuture< 'static, Result<(ResponseHead, Framed), SendRequestError>, >; } pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { /// Close connection fn close(self: Pin<&mut Self>); /// Release connection to the connection pool fn release(self: Pin<&mut Self>); } #[doc(hidden)] /// HTTP client connection pub struct IoConnection where T: AsyncWrite + Unpin + 'static, { io: Option>, created: time::Instant, pool: Option>, } impl fmt::Debug for IoConnection where T: AsyncWrite + Unpin + fmt::Debug + 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.io { Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), None => write!(f, "Connection(Empty)"), } } } impl IoConnection { pub(crate) fn new( io: ConnectionType, created: time::Instant, pool: Option>, ) -> Self { IoConnection { pool, created, io: Some(io), } } pub(crate) fn into_inner(self) -> (ConnectionType, time::Instant) { (self.io.unwrap(), self.created) } #[cfg(test)] pub(crate) fn into_parts(self) -> (ConnectionType, time::Instant, Acquired) { (self.io.unwrap(), self.created, self.pool.unwrap()) } async fn send_request>( mut self, head: H, body: B, ) -> Result<(ResponseHead, Payload), SendRequestError> { match self.io.take().unwrap() { ConnectionType::H1(io) => { h1proto::send_request(io, head.into(), body, self.created, self.pool) .await } ConnectionType::H2(io) => { h2proto::send_request(io, head.into(), body, self.created, self.pool) .await } } } /// Send request, returns Response and Framed async fn open_tunnel>( mut self, head: H, ) -> Result<(ResponseHead, Framed), SendRequestError> { match self.io.take().unwrap() { ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await, ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { pool.release(IoConnection::new( ConnectionType::H2(io), self.created, None, )); } Err(SendRequestError::TunnelNotSupported) } } } } #[allow(dead_code)] pub enum EitherIoConnection where A: AsyncRead + AsyncWrite + Unpin + 'static, B: AsyncRead + AsyncWrite + Unpin + 'static, { A(IoConnection), B(IoConnection), } impl Connection for EitherIoConnection where A: AsyncRead + AsyncWrite + Unpin + 'static, B: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = EitherIo; fn send_request( self, head: H, body: RB, ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> where RB: MessageBody + 'static, H: Into + 'static, { match self { EitherIoConnection::A(con) => Box::pin(con.send_request(head, body)), EitherIoConnection::B(con) => Box::pin(con.send_request(head, body)), } } /// Send request, returns Response and Framed fn open_tunnel + 'static>( self, head: H, ) -> LocalBoxFuture< 'static, Result<(ResponseHead, Framed), SendRequestError>, > { match self { EitherIoConnection::A(con) => Box::pin(async { let (head, framed) = con.open_tunnel(head).await?; Ok((head, framed.into_map_io(EitherIo::A))) }), EitherIoConnection::B(con) => Box::pin(async { let (head, framed) = con.open_tunnel(head).await?; Ok((head, framed.into_map_io(EitherIo::B))) }), } } } #[pin_project(project = EitherIoProj)] pub enum EitherIo { A(#[pin] A), B(#[pin] B), } impl AsyncRead for EitherIo where A: AsyncRead, B: AsyncRead, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_read(cx, buf), EitherIoProj::B(val) => val.poll_read(cx, buf), } } } impl AsyncWrite for EitherIo where A: AsyncWrite, B: AsyncWrite, { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_write(cx, buf), EitherIoProj::B(val) => val.poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_flush(cx), EitherIoProj::B(val) => val.poll_flush(cx), } } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { match self.project() { EitherIoProj::A(val) => val.poll_shutdown(cx), EitherIoProj::B(val) => val.poll_shutdown(cx), } } } #[cfg(test)] mod test { use std::net; use actix_rt::net::TcpStream; use super::*; #[actix_rt::test] async fn test_h2_connection_drop() { let addr = "127.0.0.1:0".parse::().unwrap(); let listener = net::TcpListener::bind(addr).unwrap(); let local = listener.local_addr().unwrap(); std::thread::spawn(move || while listener.accept().is_ok() {}); let tcp = TcpStream::connect(local).await.unwrap(); let (sender, connection) = h2::client::handshake(tcp).await.unwrap(); let conn = H2Connection::new(sender.clone(), connection); assert!(sender.clone().ready().await.is_ok()); assert!(h2::client::SendRequest::clone(&*conn).ready().await.is_ok()); drop(conn); match sender.ready().await { Ok(_) => panic!("connection should be gone and can not be ready"), Err(e) => assert!(e.is_io()), }; } }