Merge branch 'master' into issue_490

This commit is contained in:
Robert Gabriel Jakabosky 2018-09-01 14:41:15 +08:00 committed by GitHub
commit fe62315c50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 119 additions and 91 deletions

View File

@ -6,6 +6,11 @@
* Added the ability to pass a custom `TlsConnector`. * Added the ability to pass a custom `TlsConnector`.
### Fixed
* Handle socket read disconnect
## [0.7.4] - 2018-08-23 ## [0.7.4] - 2018-08-23
### Added ### Added

View File

@ -1,6 +1,6 @@
[package] [package]
name = "actix-web" name = "actix-web"
version = "0.7.4" version = "0.7.5"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust."
readme = "README.md" readme = "README.md"

View File

@ -19,36 +19,28 @@ use tokio_timer::Delay;
#[cfg(feature = "alpn")] #[cfg(feature = "alpn")]
use { use {
openssl::ssl::{Error as SslError, SslConnector, SslMethod}, openssl::ssl::{Error as SslError, SslConnector, SslMethod},
tokio_openssl::SslConnectorExt tokio_openssl::SslConnectorExt,
}; };
#[cfg(all(feature = "tls", not(feature = "alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
use { use {
native_tls::{Error as SslError, TlsConnector as NativeTlsConnector}, native_tls::{Error as SslError, TlsConnector as NativeTlsConnector},
tokio_tls::TlsConnector as SslConnector tokio_tls::TlsConnector as SslConnector,
}; };
#[cfg( #[cfg(all(
all( feature = "rust-tls",
feature = "rust-tls", not(any(feature = "alpn", feature = "tls"))
not(any(feature = "alpn", feature = "tls")) ))]
)
)]
use { use {
rustls::ClientConfig, rustls::ClientConfig, std::io::Error as SslError, std::sync::Arc,
std::io::Error as SslError, tokio_rustls::ClientConfigExt, webpki::DNSNameRef, webpki_roots,
std::sync::Arc,
tokio_rustls::ClientConfigExt,
webpki::DNSNameRef,
webpki_roots,
}; };
#[cfg( #[cfg(all(
all( feature = "rust-tls",
feature = "rust-tls", not(any(feature = "alpn", feature = "tls"))
not(any(feature = "alpn", feature = "tls")) ))]
)
)]
type SslConnector = Arc<ClientConfig>; type SslConnector = Arc<ClientConfig>;
#[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] #[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))]
@ -255,17 +247,19 @@ impl Default for ClientConnector {
fn default() -> ClientConnector { fn default() -> ClientConnector {
let connector = { let connector = {
#[cfg(all(feature = "alpn"))] #[cfg(all(feature = "alpn"))]
{ SslConnector::builder(SslMethod::tls()).unwrap().build() } {
SslConnector::builder(SslMethod::tls()).unwrap().build()
}
#[cfg(all(feature = "tls", not(feature = "alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
{ NativeTlsConnector::builder().build().unwrap().into() } {
NativeTlsConnector::builder().build().unwrap().into()
}
#[cfg( #[cfg(all(
all( feature = "rust-tls",
feature = "rust-tls", not(any(feature = "alpn", feature = "tls"))
not(any(feature = "alpn", feature = "tls")) ))]
)
)]
{ {
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
config config
@ -275,7 +269,9 @@ impl Default for ClientConnector {
} }
#[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] #[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))]
{ () } {
()
}
}; };
ClientConnector::with_connector_impl(connector) ClientConnector::with_connector_impl(connector)
@ -327,12 +323,10 @@ impl ClientConnector {
Self::with_connector_impl(connector) Self::with_connector_impl(connector)
} }
#[cfg( #[cfg(all(
all( feature = "rust-tls",
feature = "rust-tls", not(any(feature = "alpn", feature = "tls"))
not(any(feature = "alpn", feature = "tls")) ))]
)
)]
/// Create `ClientConnector` actor with custom `SslConnector` instance. /// Create `ClientConnector` actor with custom `SslConnector` instance.
/// ///
/// By default `ClientConnector` uses very a simple SSL configuration. /// By default `ClientConnector` uses very a simple SSL configuration.
@ -382,12 +376,10 @@ impl ClientConnector {
Self::with_connector_impl(Arc::new(connector)) Self::with_connector_impl(Arc::new(connector))
} }
#[cfg( #[cfg(all(
all( feature = "tls",
feature = "tls", not(any(feature = "alpn", feature = "rust-tls"))
not(any(feature = "alpn", feature = "rust-tls")) ))]
)
)]
pub fn with_connector(connector: SslConnector) -> ClientConnector { pub fn with_connector(connector: SslConnector) -> ClientConnector {
// keep level of indirection for docstrings matching featureflags // keep level of indirection for docstrings matching featureflags
Self::with_connector_impl(connector) Self::with_connector_impl(connector)
@ -772,12 +764,10 @@ impl ClientConnector {
} }
} }
#[cfg( #[cfg(all(
all( feature = "rust-tls",
feature = "rust-tls", not(any(feature = "alpn", feature = "tls"))
not(any(feature = "alpn", feature = "tls")) ))]
)
)]
match res { match res {
Err(err) => { Err(err) => {
let _ = waiter.tx.send(Err(err.into())); let _ = waiter.tx.send(Err(err.into()));
@ -1263,7 +1253,7 @@ impl AsyncWrite for Connection {
} }
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
use tokio_tls::{TlsStream}; use tokio_tls::TlsStream;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
/// This is temp solution untile actix-net migration /// This is temp solution untile actix-net migration

View File

@ -50,7 +50,9 @@ impl HttpResponseParser {
} }
Async::NotReady => { Async::NotReady => {
if buf.capacity() >= MAX_BUFFER_SIZE { if buf.capacity() >= MAX_BUFFER_SIZE {
return Err(HttpResponseParserError::Error(ParseError::TooLarge)); return Err(HttpResponseParserError::Error(
ParseError::TooLarge,
));
} }
// Parser needs more data. // Parser needs more data.
} }
@ -63,9 +65,7 @@ impl HttpResponseParser {
} }
Ok(Async::Ready(_)) => (), Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => { Err(err) => return Err(HttpResponseParserError::Error(err.into())),
return Err(HttpResponseParserError::Error(err.into()))
}
} }
} }
} }

View File

@ -236,7 +236,6 @@ macro_rules! FROM_STR {
($type:ty) => { ($type:ty) => {
impl FromParam for $type { impl FromParam for $type {
type Err = InternalError<<$type as FromStr>::Err>; type Err = InternalError<<$type as FromStr>::Err>;
fn from_param(val: &str) -> Result<Self, Self::Err> { fn from_param(val: &str) -> Result<Self, Self::Err> {
<$type as FromStr>::from_str(val) <$type as FromStr>::from_str(val)
.map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST)) .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))

View File

@ -160,8 +160,9 @@ where
if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() { if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() {
match kind { match kind {
ProtocolKind::Http1 => { ProtocolKind::Http1 => {
self.proto = self.proto = Some(HttpProtocol::H1(h1::Http1::new(
Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf, is_eof))); settings, io, addr, buf, is_eof,
)));
return self.poll(); return self.poll();
} }
ProtocolKind::Http2 => { ProtocolKind::Http2 => {

View File

@ -22,13 +22,14 @@ use super::{HttpHandler, HttpHandlerTask, IoStream};
const MAX_PIPELINED_MESSAGES: usize = 16; const MAX_PIPELINED_MESSAGES: usize = 16;
bitflags! { bitflags! {
struct Flags: u8 { pub struct Flags: u8 {
const STARTED = 0b0000_0001; const STARTED = 0b0000_0001;
const ERROR = 0b0000_0010; const ERROR = 0b0000_0010;
const KEEPALIVE = 0b0000_0100; const KEEPALIVE = 0b0000_0100;
const SHUTDOWN = 0b0000_1000; const SHUTDOWN = 0b0000_1000;
const DISCONNECTED = 0b0001_0000; const READ_DISCONNECTED = 0b0001_0000;
const POLLED = 0b0010_0000; const WRITE_DISCONNECTED = 0b0010_0000;
const POLLED = 0b0100_0000;
} }
} }
@ -93,7 +94,11 @@ where
buf: BytesMut, is_eof: bool, buf: BytesMut, is_eof: bool,
) -> Self { ) -> Self {
Http1 { Http1 {
flags: Flags::KEEPALIVE | if is_eof { Flags::DISCONNECTED } else { Flags::empty() }, flags: if is_eof {
Flags::READ_DISCONNECTED
} else {
Flags::KEEPALIVE
},
stream: H1Writer::new(stream, Rc::clone(&settings)), stream: H1Writer::new(stream, Rc::clone(&settings)),
decoder: H1Decoder::new(), decoder: H1Decoder::new(),
payload: None, payload: None,
@ -117,6 +122,13 @@ where
#[inline] #[inline]
fn can_read(&self) -> bool { fn can_read(&self) -> bool {
if self
.flags
.intersects(Flags::ERROR | Flags::READ_DISCONNECTED)
{
return false;
}
if let Some(ref info) = self.payload { if let Some(ref info) = self.payload {
info.need_read() == PayloadStatus::Read info.need_read() == PayloadStatus::Read
} else { } else {
@ -125,6 +137,8 @@ where
} }
fn notify_disconnect(&mut self) { fn notify_disconnect(&mut self) {
self.flags.insert(Flags::WRITE_DISCONNECTED);
// notify all tasks // notify all tasks
self.stream.disconnected(); self.stream.disconnected();
for task in &mut self.tasks { for task in &mut self.tasks {
@ -163,6 +177,11 @@ where
// shutdown // shutdown
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
if self.flags.intersects(
Flags::ERROR | Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED,
) {
return Ok(Async::Ready(()));
}
match self.stream.poll_completed(true) { match self.stream.poll_completed(true) {
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(_)) => return Ok(Async::Ready(())), Ok(Async::Ready(_)) => return Ok(Async::Ready(())),
@ -197,11 +216,9 @@ where
self.flags.insert(Flags::POLLED); self.flags.insert(Flags::POLLED);
return; return;
} }
// read io from socket // read io from socket
if !self.flags.intersects(Flags::ERROR) if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES {
&& self.tasks.len() < MAX_PIPELINED_MESSAGES
&& self.can_read()
{
match self.stream.get_mut().read_available(&mut self.buf) { match self.stream.get_mut().read_available(&mut self.buf) {
Ok(Async::Ready((read_some, disconnected))) => { Ok(Async::Ready((read_some, disconnected))) => {
if read_some { if read_some {
@ -209,7 +226,7 @@ where
} }
if disconnected { if disconnected {
// delay disconnect until all tasks have finished. // delay disconnect until all tasks have finished.
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECTED);
if self.tasks.is_empty() { if self.tasks.is_empty() {
self.client_disconnect(); self.client_disconnect();
} }
@ -231,7 +248,10 @@ where
let mut idx = 0; let mut idx = 0;
while idx < self.tasks.len() { while idx < self.tasks.len() {
// only one task can do io operation in http/1 // only one task can do io operation in http/1
if !io && !self.tasks[idx].flags.contains(EntryFlags::EOF) { if !io
&& !self.tasks[idx].flags.contains(EntryFlags::EOF)
&& !self.flags.contains(Flags::WRITE_DISCONNECTED)
{
// io is corrupted, send buffer // io is corrupted, send buffer
if self.tasks[idx].flags.contains(EntryFlags::ERROR) { if self.tasks[idx].flags.contains(EntryFlags::ERROR) {
if let Ok(Async::NotReady) = self.stream.poll_completed(true) { if let Ok(Async::NotReady) = self.stream.poll_completed(true) {
@ -295,7 +315,6 @@ where
} }
// cleanup finished tasks // cleanup finished tasks
let max = self.tasks.len() >= MAX_PIPELINED_MESSAGES;
while !self.tasks.is_empty() { while !self.tasks.is_empty() {
if self.tasks[0] if self.tasks[0]
.flags .flags
@ -306,10 +325,6 @@ where
break; break;
} }
} }
// read more message
if max && self.tasks.len() >= MAX_PIPELINED_MESSAGES {
return Ok(Async::Ready(true));
}
// check stream state // check stream state
if self.flags.contains(Flags::STARTED) { if self.flags.contains(Flags::STARTED) {
@ -332,8 +347,7 @@ where
// deal with keep-alive and steam eof (client-side write shutdown) // deal with keep-alive and steam eof (client-side write shutdown)
if self.tasks.is_empty() { if self.tasks.is_empty() {
// handle stream eof // handle stream eof
if self.flags.contains(Flags::DISCONNECTED) { if self.flags.contains(Flags::READ_DISCONNECTED) {
self.client_disconnect();
return Ok(Async::Ready(false)); return Ok(Async::Ready(false));
} }
// no keep-alive // no keep-alive
@ -451,7 +465,14 @@ where
break; break;
} }
} }
Ok(None) => break, Ok(None) => {
if self.flags.contains(Flags::READ_DISCONNECTED)
&& self.tasks.is_empty()
{
self.client_disconnect();
}
break;
}
Err(e) => { Err(e) => {
self.flags.insert(Flags::ERROR); self.flags.insert(Flags::ERROR);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
@ -606,24 +627,36 @@ mod tests {
} }
#[test] #[test]
fn test_req_parse() { fn test_req_parse1() {
let buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); let buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n");
let readbuf = BytesMut::new(); let readbuf = BytesMut::new();
let settings = Rc::new(wrk_settings()); let settings = Rc::new(wrk_settings());
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true); let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, false);
h1.poll_io(); h1.poll_io();
h1.poll_io(); h1.poll_io();
assert_eq!(h1.tasks.len(), 1); assert_eq!(h1.tasks.len(), 1);
} }
#[test]
fn test_req_parse2() {
let buf = Buffer::new("");
let readbuf =
BytesMut::from(Vec::<u8>::from(&b"GET /test HTTP/1.1\r\n\r\n"[..]));
let settings = Rc::new(wrk_settings());
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true);
h1.poll_io();
assert_eq!(h1.tasks.len(), 1);
}
#[test] #[test]
fn test_req_parse_err() { fn test_req_parse_err() {
let buf = Buffer::new("GET /test HTTP/1\r\n\r\n"); let buf = Buffer::new("GET /test HTTP/1\r\n\r\n");
let readbuf = BytesMut::new(); let readbuf = BytesMut::new();
let settings = Rc::new(wrk_settings()); let settings = Rc::new(wrk_settings());
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true); let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, false);
h1.poll_io(); h1.poll_io();
h1.poll_io(); h1.poll_io();
assert!(h1.flags.contains(Flags::ERROR)); assert!(h1.flags.contains(Flags::ERROR));

View File

@ -63,7 +63,9 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
self.flags = Flags::KEEPALIVE; self.flags = Flags::KEEPALIVE;
} }
pub fn disconnected(&mut self) {} pub fn disconnected(&mut self) {
self.flags.insert(Flags::DISCONNECTED);
}
pub fn keepalive(&self) -> bool { pub fn keepalive(&self) -> bool {
self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE)
@ -268,10 +270,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
let pl: &[u8] = payload.as_ref(); let pl: &[u8] = payload.as_ref();
let n = match Self::write_data(&mut self.stream, pl) { let n = match Self::write_data(&mut self.stream, pl) {
Err(err) => { Err(err) => {
if err.kind() == io::ErrorKind::WriteZero { self.disconnected();
self.disconnected();
}
return Err(err); return Err(err);
} }
Ok(val) => val, Ok(val) => val,
@ -315,14 +314,15 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
#[inline] #[inline]
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
if self.flags.contains(Flags::DISCONNECTED) {
return Err(io::Error::new(io::ErrorKind::Other, "disconnected"));
}
if !self.buffer.is_empty() { if !self.buffer.is_empty() {
let written = { let written = {
match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) { match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) {
Err(err) => { Err(err) => {
if err.kind() == io::ErrorKind::WriteZero { self.disconnected();
self.disconnected();
}
return Err(err); return Err(err);
} }
Ok(val) => val, Ok(val) => val,
@ -339,7 +339,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
self.stream.poll_flush()?; self.stream.poll_flush()?;
self.stream.shutdown() self.stream.shutdown()
} else { } else {
self.stream.poll_flush() Ok(self.stream.poll_flush()?)
} }
} }
} }