From 2d518318993bd1c5393136371acc0a74887c4e97 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 31 Aug 2018 17:24:13 -0700 Subject: [PATCH 1/5] handle socket read disconnect --- CHANGES.md | 5 ++++ src/server/h1.rs | 60 +++++++++++++++++++++++++----------------- src/server/h1writer.rs | 20 +++++++------- 3 files changed, 51 insertions(+), 34 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 34b0a9621..d99aa8ba2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,11 @@ * Added the ability to pass a custom `TlsConnector`. +### Fixed + +* Handle socket read disconnect + + ## [0.7.4] - 2018-08-23 ### Added diff --git a/src/server/h1.rs b/src/server/h1.rs index ae5dd4655..1acae26ec 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -22,13 +22,14 @@ use super::{HttpHandler, HttpHandlerTask, IoStream}; const MAX_PIPELINED_MESSAGES: usize = 16; bitflags! { - struct Flags: u8 { - const STARTED = 0b0000_0001; - const ERROR = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; - const SHUTDOWN = 0b0000_1000; - const DISCONNECTED = 0b0001_0000; - const POLLED = 0b0010_0000; + pub struct Flags: u8 { + const STARTED = 0b0000_0001; + const ERROR = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const SHUTDOWN = 0b0000_1000; + const READ_DISCONNECTED = 0b0001_0000; + const WRITE_DISCONNECTED = 0b0010_0000; + const POLLED = 0b0100_0000; } } @@ -93,7 +94,7 @@ where buf: BytesMut, is_eof: bool, ) -> Self { Http1 { - flags: Flags::KEEPALIVE | if is_eof { Flags::DISCONNECTED } else { Flags::empty() }, + flags: if is_eof { Flags::READ_DISCONNECTED } else { Flags::KEEPALIVE }, stream: H1Writer::new(stream, Rc::clone(&settings)), decoder: H1Decoder::new(), payload: None, @@ -117,6 +118,10 @@ where #[inline] fn can_read(&self) -> bool { + if self.flags.intersects(Flags::ERROR | Flags::READ_DISCONNECTED) { + return false + } + if let Some(ref info) = self.payload { info.need_read() == PayloadStatus::Read } else { @@ -125,6 +130,8 @@ where } fn notify_disconnect(&mut self) { + self.flags.insert(Flags::WRITE_DISCONNECTED); + // notify all tasks self.stream.disconnected(); for task in &mut self.tasks { @@ -163,11 +170,15 @@ where // shutdown if self.flags.contains(Flags::SHUTDOWN) { + if self.flags.intersects( + Flags::ERROR | Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) { + return Ok(Async::Ready(())) + } match self.stream.poll_completed(true) { Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::Ready(_)) => return Ok(Async::Ready(())), Err(err) => { - debug!("Error sending data: {}", err); + debug!("Error sendips ng data: {}", err); return Err(()); } } @@ -197,11 +208,9 @@ where self.flags.insert(Flags::POLLED); return; } + // read io from socket - if !self.flags.intersects(Flags::ERROR) - && self.tasks.len() < MAX_PIPELINED_MESSAGES - && self.can_read() - { + if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES { match self.stream.get_mut().read_available(&mut self.buf) { Ok(Async::Ready((read_some, disconnected))) => { if read_some { @@ -209,7 +218,7 @@ where } if disconnected { // delay disconnect until all tasks have finished. - self.flags.insert(Flags::DISCONNECTED); + self.flags.insert(Flags::READ_DISCONNECTED); if self.tasks.is_empty() { self.client_disconnect(); } @@ -231,7 +240,9 @@ where let mut idx = 0; while idx < self.tasks.len() { // only one task can do io operation in http/1 - if !io && !self.tasks[idx].flags.contains(EntryFlags::EOF) { + if !io && !self.tasks[idx].flags.contains(EntryFlags::EOF) + && !self.flags.contains(Flags::WRITE_DISCONNECTED) + { // io is corrupted, send buffer if self.tasks[idx].flags.contains(EntryFlags::ERROR) { if let Ok(Async::NotReady) = self.stream.poll_completed(true) { @@ -295,7 +306,6 @@ where } // cleanup finished tasks - let max = self.tasks.len() >= MAX_PIPELINED_MESSAGES; while !self.tasks.is_empty() { if self.tasks[0] .flags @@ -306,15 +316,13 @@ where break; } } - // read more message - if max && self.tasks.len() >= MAX_PIPELINED_MESSAGES { - return Ok(Async::Ready(true)); - } // check stream state if self.flags.contains(Flags::STARTED) { match self.stream.poll_completed(false) { - Ok(Async::NotReady) => return Ok(Async::NotReady), + Ok(Async::NotReady) => { + return Ok(Async::NotReady) + }, Err(err) => { debug!("Error sending data: {}", err); self.notify_disconnect(); @@ -332,8 +340,7 @@ where // deal with keep-alive and steam eof (client-side write shutdown) if self.tasks.is_empty() { // handle stream eof - if self.flags.contains(Flags::DISCONNECTED) { - self.client_disconnect(); + if self.flags.contains(Flags::READ_DISCONNECTED) { return Ok(Async::Ready(false)); } // no keep-alive @@ -451,7 +458,12 @@ where break; } } - Ok(None) => break, + Ok(None) => { + if self.flags.contains(Flags::READ_DISCONNECTED) && self.tasks.is_empty() { + self.client_disconnect(); + } + break + }, Err(e) => { self.flags.insert(Flags::ERROR); if let Some(mut payload) = self.payload.take() { diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index 8981f9df9..422f0ebc1 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -63,7 +63,9 @@ impl H1Writer { self.flags = Flags::KEEPALIVE; } - pub fn disconnected(&mut self) {} + pub fn disconnected(&mut self) { + self.flags.insert(Flags::DISCONNECTED); + } pub fn keepalive(&self) -> bool { self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) @@ -268,10 +270,7 @@ impl Writer for H1Writer { let pl: &[u8] = payload.as_ref(); let n = match Self::write_data(&mut self.stream, pl) { Err(err) => { - if err.kind() == io::ErrorKind::WriteZero { - self.disconnected(); - } - + self.disconnected(); return Err(err); } Ok(val) => val, @@ -315,14 +314,15 @@ impl Writer for H1Writer { #[inline] fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { + if self.flags.contains(Flags::DISCONNECTED) { + return Err(io::Error::new(io::ErrorKind::Other, "disconnected")); + } + if !self.buffer.is_empty() { let written = { match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) { Err(err) => { - if err.kind() == io::ErrorKind::WriteZero { - self.disconnected(); - } - + self.disconnected(); return Err(err); } Ok(val) => val, @@ -339,7 +339,7 @@ impl Writer for H1Writer { self.stream.poll_flush()?; self.stream.shutdown() } else { - self.stream.poll_flush() + Ok(self.stream.poll_flush()?) } } } From 3fa23f5e10ce89ee7f06bddf0a8b3ac35062cd39 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 31 Aug 2018 17:25:15 -0700 Subject: [PATCH 2/5] update version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index bc182b16e..631b48dc9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "0.7.4" +version = "0.7.5" authors = ["Nikolay Kim "] description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." readme = "README.md" From c313c003a4b8b3526b33f782996116263cba7140 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 31 Aug 2018 17:45:29 -0700 Subject: [PATCH 3/5] Fix typo --- src/server/h1.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/h1.rs b/src/server/h1.rs index 1acae26ec..dd8497101 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -178,7 +178,7 @@ where Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::Ready(_)) => return Ok(Async::Ready(())), Err(err) => { - debug!("Error sendips ng data: {}", err); + debug!("Error sending data: {}", err); return Err(()); } } From 0b42cae08254768d7b16ab95ffce1d2269ff0b05 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 31 Aug 2018 18:54:19 -0700 Subject: [PATCH 4/5] update tests --- src/server/h1.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/server/h1.rs b/src/server/h1.rs index 1acae26ec..652922973 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -291,9 +291,8 @@ where } else if !self.tasks[idx].flags.contains(EntryFlags::FINISHED) { match self.tasks[idx].pipe.poll_completed() { Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => { - self.tasks[idx].flags.insert(EntryFlags::FINISHED) - } + Ok(Async::Ready(_)) => + self.tasks[idx].flags.insert(EntryFlags::FINISHED), Err(err) => { self.notify_disconnect(); self.tasks[idx].flags.insert(EntryFlags::ERROR); @@ -618,24 +617,35 @@ mod tests { } #[test] - fn test_req_parse() { + fn test_req_parse1() { let buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); let readbuf = BytesMut::new(); let settings = Rc::new(wrk_settings()); - let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true); + let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, false); h1.poll_io(); h1.poll_io(); assert_eq!(h1.tasks.len(), 1); } + #[test] + fn test_req_parse2() { + let buf = Buffer::new(""); + let readbuf = BytesMut::from(Vec::::from(&b"GET /test HTTP/1.1\r\n\r\n"[..])); + let settings = Rc::new(wrk_settings()); + + let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true); + h1.poll_io(); + assert_eq!(h1.tasks.len(), 1); + } + #[test] fn test_req_parse_err() { let buf = Buffer::new("GET /test HTTP/1\r\n\r\n"); let readbuf = BytesMut::new(); let settings = Rc::new(wrk_settings()); - let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true); + let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, false); h1.poll_io(); h1.poll_io(); assert!(h1.flags.contains(Flags::ERROR)); From a2b170fec96d0d101dcd7e1abd31c5595d88e453 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 31 Aug 2018 18:56:21 -0700 Subject: [PATCH 5/5] fmt --- src/client/connector.rs | 86 ++++++++++++++++++----------------------- src/client/parser.rs | 8 ++-- src/param.rs | 1 - src/server/channel.rs | 5 ++- src/server/h1.rs | 41 +++++++++++++------- 5 files changed, 71 insertions(+), 70 deletions(-) diff --git a/src/client/connector.rs b/src/client/connector.rs index 430a0f752..694e03bc9 100644 --- a/src/client/connector.rs +++ b/src/client/connector.rs @@ -19,36 +19,28 @@ use tokio_timer::Delay; #[cfg(feature = "alpn")] use { openssl::ssl::{Error as SslError, SslConnector, SslMethod}, - tokio_openssl::SslConnectorExt + tokio_openssl::SslConnectorExt, }; #[cfg(all(feature = "tls", not(feature = "alpn")))] use { native_tls::{Error as SslError, TlsConnector as NativeTlsConnector}, - tokio_tls::TlsConnector as SslConnector + tokio_tls::TlsConnector as SslConnector, }; -#[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) -)] +#[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls")) +))] use { - rustls::ClientConfig, - std::io::Error as SslError, - std::sync::Arc, - tokio_rustls::ClientConfigExt, - webpki::DNSNameRef, - webpki_roots, + rustls::ClientConfig, std::io::Error as SslError, std::sync::Arc, + tokio_rustls::ClientConfigExt, webpki::DNSNameRef, webpki_roots, }; -#[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) -)] +#[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls")) +))] type SslConnector = Arc; #[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] @@ -255,17 +247,19 @@ impl Default for ClientConnector { fn default() -> ClientConnector { let connector = { #[cfg(all(feature = "alpn"))] - { SslConnector::builder(SslMethod::tls()).unwrap().build() } + { + SslConnector::builder(SslMethod::tls()).unwrap().build() + } #[cfg(all(feature = "tls", not(feature = "alpn")))] - { NativeTlsConnector::builder().build().unwrap().into() } + { + NativeTlsConnector::builder().build().unwrap().into() + } - #[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) - )] + #[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls")) + ))] { let mut config = ClientConfig::new(); config @@ -275,7 +269,9 @@ impl Default for ClientConnector { } #[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] - { () } + { + () + } }; ClientConnector::with_connector_impl(connector) @@ -327,12 +323,10 @@ impl ClientConnector { Self::with_connector_impl(connector) } - #[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) - )] + #[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls")) + ))] /// Create `ClientConnector` actor with custom `SslConnector` instance. /// /// By default `ClientConnector` uses very a simple SSL configuration. @@ -382,12 +376,10 @@ impl ClientConnector { Self::with_connector_impl(Arc::new(connector)) } - #[cfg( - all( - feature = "tls", - not(any(feature = "alpn", feature = "rust-tls")) - ) - )] + #[cfg(all( + feature = "tls", + not(any(feature = "alpn", feature = "rust-tls")) + ))] pub fn with_connector(connector: SslConnector) -> ClientConnector { // keep level of indirection for docstrings matching featureflags Self::with_connector_impl(connector) @@ -772,12 +764,10 @@ impl ClientConnector { } } - #[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) - )] + #[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls")) + ))] match res { Err(err) => { let _ = waiter.tx.send(Err(err.into())); @@ -1263,7 +1253,7 @@ impl AsyncWrite for Connection { } #[cfg(feature = "tls")] -use tokio_tls::{TlsStream}; +use tokio_tls::TlsStream; #[cfg(feature = "tls")] /// This is temp solution untile actix-net migration diff --git a/src/client/parser.rs b/src/client/parser.rs index 5fd81da25..0ee4598de 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -50,7 +50,9 @@ impl HttpResponseParser { } Async::NotReady => { if buf.capacity() >= MAX_BUFFER_SIZE { - return Err(HttpResponseParserError::Error(ParseError::TooLarge)); + return Err(HttpResponseParserError::Error( + ParseError::TooLarge, + )); } // Parser needs more data. } @@ -63,9 +65,7 @@ impl HttpResponseParser { } Ok(Async::Ready(_)) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => { - return Err(HttpResponseParserError::Error(err.into())) - } + Err(err) => return Err(HttpResponseParserError::Error(err.into())), } } } diff --git a/src/param.rs b/src/param.rs index 063159d72..d0664df99 100644 --- a/src/param.rs +++ b/src/param.rs @@ -236,7 +236,6 @@ macro_rules! FROM_STR { ($type:ty) => { impl FromParam for $type { type Err = InternalError<<$type as FromStr>::Err>; - fn from_param(val: &str) -> Result { <$type as FromStr>::from_str(val) .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST)) diff --git a/src/server/channel.rs b/src/server/channel.rs index 84f301513..bec1c4c87 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -160,8 +160,9 @@ where if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() { match kind { ProtocolKind::Http1 => { - self.proto = - Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf, is_eof))); + self.proto = Some(HttpProtocol::H1(h1::Http1::new( + settings, io, addr, buf, is_eof, + ))); return self.poll(); } ProtocolKind::Http2 => { diff --git a/src/server/h1.rs b/src/server/h1.rs index 652922973..f4875519e 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -94,7 +94,11 @@ where buf: BytesMut, is_eof: bool, ) -> Self { Http1 { - flags: if is_eof { Flags::READ_DISCONNECTED } else { Flags::KEEPALIVE }, + flags: if is_eof { + Flags::READ_DISCONNECTED + } else { + Flags::KEEPALIVE + }, stream: H1Writer::new(stream, Rc::clone(&settings)), decoder: H1Decoder::new(), payload: None, @@ -118,8 +122,11 @@ where #[inline] fn can_read(&self) -> bool { - if self.flags.intersects(Flags::ERROR | Flags::READ_DISCONNECTED) { - return false + if self + .flags + .intersects(Flags::ERROR | Flags::READ_DISCONNECTED) + { + return false; } if let Some(ref info) = self.payload { @@ -171,8 +178,9 @@ where // shutdown if self.flags.contains(Flags::SHUTDOWN) { if self.flags.intersects( - Flags::ERROR | Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) { - return Ok(Async::Ready(())) + Flags::ERROR | Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED, + ) { + return Ok(Async::Ready(())); } match self.stream.poll_completed(true) { Ok(Async::NotReady) => return Ok(Async::NotReady), @@ -240,7 +248,8 @@ where let mut idx = 0; while idx < self.tasks.len() { // only one task can do io operation in http/1 - if !io && !self.tasks[idx].flags.contains(EntryFlags::EOF) + if !io + && !self.tasks[idx].flags.contains(EntryFlags::EOF) && !self.flags.contains(Flags::WRITE_DISCONNECTED) { // io is corrupted, send buffer @@ -291,8 +300,9 @@ where } else if !self.tasks[idx].flags.contains(EntryFlags::FINISHED) { match self.tasks[idx].pipe.poll_completed() { Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => - self.tasks[idx].flags.insert(EntryFlags::FINISHED), + Ok(Async::Ready(_)) => { + self.tasks[idx].flags.insert(EntryFlags::FINISHED) + } Err(err) => { self.notify_disconnect(); self.tasks[idx].flags.insert(EntryFlags::ERROR); @@ -319,9 +329,7 @@ where // check stream state if self.flags.contains(Flags::STARTED) { match self.stream.poll_completed(false) { - Ok(Async::NotReady) => { - return Ok(Async::NotReady) - }, + Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); self.notify_disconnect(); @@ -458,11 +466,13 @@ where } } Ok(None) => { - if self.flags.contains(Flags::READ_DISCONNECTED) && self.tasks.is_empty() { + if self.flags.contains(Flags::READ_DISCONNECTED) + && self.tasks.is_empty() + { self.client_disconnect(); } - break - }, + break; + } Err(e) => { self.flags.insert(Flags::ERROR); if let Some(mut payload) = self.payload.take() { @@ -631,7 +641,8 @@ mod tests { #[test] fn test_req_parse2() { let buf = Buffer::new(""); - let readbuf = BytesMut::from(Vec::::from(&b"GET /test HTTP/1.1\r\n\r\n"[..])); + let readbuf = + BytesMut::from(Vec::::from(&b"GET /test HTTP/1.1\r\n\r\n"[..])); let settings = Rc::new(wrk_settings()); let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true);