mirror of https://github.com/fafhrd91/actix-web
Merge branch 'master' into issue_490
This commit is contained in:
commit
fe62315c50
|
@ -6,6 +6,11 @@
|
|||
|
||||
* Added the ability to pass a custom `TlsConnector`.
|
||||
|
||||
### Fixed
|
||||
|
||||
* Handle socket read disconnect
|
||||
|
||||
|
||||
## [0.7.4] - 2018-08-23
|
||||
|
||||
### Added
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "actix-web"
|
||||
version = "0.7.4"
|
||||
version = "0.7.5"
|
||||
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
|
||||
description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust."
|
||||
readme = "README.md"
|
||||
|
|
|
@ -19,36 +19,28 @@ use tokio_timer::Delay;
|
|||
#[cfg(feature = "alpn")]
|
||||
use {
|
||||
openssl::ssl::{Error as SslError, SslConnector, SslMethod},
|
||||
tokio_openssl::SslConnectorExt
|
||||
tokio_openssl::SslConnectorExt,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "tls", not(feature = "alpn")))]
|
||||
use {
|
||||
native_tls::{Error as SslError, TlsConnector as NativeTlsConnector},
|
||||
tokio_tls::TlsConnector as SslConnector
|
||||
tokio_tls::TlsConnector as SslConnector,
|
||||
};
|
||||
|
||||
#[cfg(
|
||||
all(
|
||||
#[cfg(all(
|
||||
feature = "rust-tls",
|
||||
not(any(feature = "alpn", feature = "tls"))
|
||||
)
|
||||
)]
|
||||
))]
|
||||
use {
|
||||
rustls::ClientConfig,
|
||||
std::io::Error as SslError,
|
||||
std::sync::Arc,
|
||||
tokio_rustls::ClientConfigExt,
|
||||
webpki::DNSNameRef,
|
||||
webpki_roots,
|
||||
rustls::ClientConfig, std::io::Error as SslError, std::sync::Arc,
|
||||
tokio_rustls::ClientConfigExt, webpki::DNSNameRef, webpki_roots,
|
||||
};
|
||||
|
||||
#[cfg(
|
||||
all(
|
||||
#[cfg(all(
|
||||
feature = "rust-tls",
|
||||
not(any(feature = "alpn", feature = "tls"))
|
||||
)
|
||||
)]
|
||||
))]
|
||||
type SslConnector = Arc<ClientConfig>;
|
||||
|
||||
#[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))]
|
||||
|
@ -255,17 +247,19 @@ impl Default for ClientConnector {
|
|||
fn default() -> ClientConnector {
|
||||
let connector = {
|
||||
#[cfg(all(feature = "alpn"))]
|
||||
{ SslConnector::builder(SslMethod::tls()).unwrap().build() }
|
||||
{
|
||||
SslConnector::builder(SslMethod::tls()).unwrap().build()
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tls", not(feature = "alpn")))]
|
||||
{ NativeTlsConnector::builder().build().unwrap().into() }
|
||||
{
|
||||
NativeTlsConnector::builder().build().unwrap().into()
|
||||
}
|
||||
|
||||
#[cfg(
|
||||
all(
|
||||
#[cfg(all(
|
||||
feature = "rust-tls",
|
||||
not(any(feature = "alpn", feature = "tls"))
|
||||
)
|
||||
)]
|
||||
))]
|
||||
{
|
||||
let mut config = ClientConfig::new();
|
||||
config
|
||||
|
@ -275,7 +269,9 @@ impl Default for ClientConnector {
|
|||
}
|
||||
|
||||
#[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))]
|
||||
{ () }
|
||||
{
|
||||
()
|
||||
}
|
||||
};
|
||||
|
||||
ClientConnector::with_connector_impl(connector)
|
||||
|
@ -327,12 +323,10 @@ impl ClientConnector {
|
|||
Self::with_connector_impl(connector)
|
||||
}
|
||||
|
||||
#[cfg(
|
||||
all(
|
||||
#[cfg(all(
|
||||
feature = "rust-tls",
|
||||
not(any(feature = "alpn", feature = "tls"))
|
||||
)
|
||||
)]
|
||||
))]
|
||||
/// Create `ClientConnector` actor with custom `SslConnector` instance.
|
||||
///
|
||||
/// By default `ClientConnector` uses very a simple SSL configuration.
|
||||
|
@ -382,12 +376,10 @@ impl ClientConnector {
|
|||
Self::with_connector_impl(Arc::new(connector))
|
||||
}
|
||||
|
||||
#[cfg(
|
||||
all(
|
||||
#[cfg(all(
|
||||
feature = "tls",
|
||||
not(any(feature = "alpn", feature = "rust-tls"))
|
||||
)
|
||||
)]
|
||||
))]
|
||||
pub fn with_connector(connector: SslConnector) -> ClientConnector {
|
||||
// keep level of indirection for docstrings matching featureflags
|
||||
Self::with_connector_impl(connector)
|
||||
|
@ -772,12 +764,10 @@ impl ClientConnector {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(
|
||||
all(
|
||||
#[cfg(all(
|
||||
feature = "rust-tls",
|
||||
not(any(feature = "alpn", feature = "tls"))
|
||||
)
|
||||
)]
|
||||
))]
|
||||
match res {
|
||||
Err(err) => {
|
||||
let _ = waiter.tx.send(Err(err.into()));
|
||||
|
@ -1263,7 +1253,7 @@ impl AsyncWrite for Connection {
|
|||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use tokio_tls::{TlsStream};
|
||||
use tokio_tls::TlsStream;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
/// This is temp solution untile actix-net migration
|
||||
|
|
|
@ -50,7 +50,9 @@ impl HttpResponseParser {
|
|||
}
|
||||
Async::NotReady => {
|
||||
if buf.capacity() >= MAX_BUFFER_SIZE {
|
||||
return Err(HttpResponseParserError::Error(ParseError::TooLarge));
|
||||
return Err(HttpResponseParserError::Error(
|
||||
ParseError::TooLarge,
|
||||
));
|
||||
}
|
||||
// Parser needs more data.
|
||||
}
|
||||
|
@ -63,9 +65,7 @@ impl HttpResponseParser {
|
|||
}
|
||||
Ok(Async::Ready(_)) => (),
|
||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||
Err(err) => {
|
||||
return Err(HttpResponseParserError::Error(err.into()))
|
||||
}
|
||||
Err(err) => return Err(HttpResponseParserError::Error(err.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -236,7 +236,6 @@ macro_rules! FROM_STR {
|
|||
($type:ty) => {
|
||||
impl FromParam for $type {
|
||||
type Err = InternalError<<$type as FromStr>::Err>;
|
||||
|
||||
fn from_param(val: &str) -> Result<Self, Self::Err> {
|
||||
<$type as FromStr>::from_str(val)
|
||||
.map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))
|
||||
|
|
|
@ -160,8 +160,9 @@ where
|
|||
if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() {
|
||||
match kind {
|
||||
ProtocolKind::Http1 => {
|
||||
self.proto =
|
||||
Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf, is_eof)));
|
||||
self.proto = Some(HttpProtocol::H1(h1::Http1::new(
|
||||
settings, io, addr, buf, is_eof,
|
||||
)));
|
||||
return self.poll();
|
||||
}
|
||||
ProtocolKind::Http2 => {
|
||||
|
|
|
@ -22,13 +22,14 @@ use super::{HttpHandler, HttpHandlerTask, IoStream};
|
|||
const MAX_PIPELINED_MESSAGES: usize = 16;
|
||||
|
||||
bitflags! {
|
||||
struct Flags: u8 {
|
||||
pub struct Flags: u8 {
|
||||
const STARTED = 0b0000_0001;
|
||||
const ERROR = 0b0000_0010;
|
||||
const KEEPALIVE = 0b0000_0100;
|
||||
const SHUTDOWN = 0b0000_1000;
|
||||
const DISCONNECTED = 0b0001_0000;
|
||||
const POLLED = 0b0010_0000;
|
||||
const READ_DISCONNECTED = 0b0001_0000;
|
||||
const WRITE_DISCONNECTED = 0b0010_0000;
|
||||
const POLLED = 0b0100_0000;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -93,7 +94,11 @@ where
|
|||
buf: BytesMut, is_eof: bool,
|
||||
) -> Self {
|
||||
Http1 {
|
||||
flags: Flags::KEEPALIVE | if is_eof { Flags::DISCONNECTED } else { Flags::empty() },
|
||||
flags: if is_eof {
|
||||
Flags::READ_DISCONNECTED
|
||||
} else {
|
||||
Flags::KEEPALIVE
|
||||
},
|
||||
stream: H1Writer::new(stream, Rc::clone(&settings)),
|
||||
decoder: H1Decoder::new(),
|
||||
payload: None,
|
||||
|
@ -117,6 +122,13 @@ where
|
|||
|
||||
#[inline]
|
||||
fn can_read(&self) -> bool {
|
||||
if self
|
||||
.flags
|
||||
.intersects(Flags::ERROR | Flags::READ_DISCONNECTED)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(ref info) = self.payload {
|
||||
info.need_read() == PayloadStatus::Read
|
||||
} else {
|
||||
|
@ -125,6 +137,8 @@ where
|
|||
}
|
||||
|
||||
fn notify_disconnect(&mut self) {
|
||||
self.flags.insert(Flags::WRITE_DISCONNECTED);
|
||||
|
||||
// notify all tasks
|
||||
self.stream.disconnected();
|
||||
for task in &mut self.tasks {
|
||||
|
@ -163,6 +177,11 @@ where
|
|||
|
||||
// shutdown
|
||||
if self.flags.contains(Flags::SHUTDOWN) {
|
||||
if self.flags.intersects(
|
||||
Flags::ERROR | Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED,
|
||||
) {
|
||||
return Ok(Async::Ready(()));
|
||||
}
|
||||
match self.stream.poll_completed(true) {
|
||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||
Ok(Async::Ready(_)) => return Ok(Async::Ready(())),
|
||||
|
@ -197,11 +216,9 @@ where
|
|||
self.flags.insert(Flags::POLLED);
|
||||
return;
|
||||
}
|
||||
|
||||
// read io from socket
|
||||
if !self.flags.intersects(Flags::ERROR)
|
||||
&& self.tasks.len() < MAX_PIPELINED_MESSAGES
|
||||
&& self.can_read()
|
||||
{
|
||||
if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES {
|
||||
match self.stream.get_mut().read_available(&mut self.buf) {
|
||||
Ok(Async::Ready((read_some, disconnected))) => {
|
||||
if read_some {
|
||||
|
@ -209,7 +226,7 @@ where
|
|||
}
|
||||
if disconnected {
|
||||
// delay disconnect until all tasks have finished.
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
self.flags.insert(Flags::READ_DISCONNECTED);
|
||||
if self.tasks.is_empty() {
|
||||
self.client_disconnect();
|
||||
}
|
||||
|
@ -231,7 +248,10 @@ where
|
|||
let mut idx = 0;
|
||||
while idx < self.tasks.len() {
|
||||
// only one task can do io operation in http/1
|
||||
if !io && !self.tasks[idx].flags.contains(EntryFlags::EOF) {
|
||||
if !io
|
||||
&& !self.tasks[idx].flags.contains(EntryFlags::EOF)
|
||||
&& !self.flags.contains(Flags::WRITE_DISCONNECTED)
|
||||
{
|
||||
// io is corrupted, send buffer
|
||||
if self.tasks[idx].flags.contains(EntryFlags::ERROR) {
|
||||
if let Ok(Async::NotReady) = self.stream.poll_completed(true) {
|
||||
|
@ -295,7 +315,6 @@ where
|
|||
}
|
||||
|
||||
// cleanup finished tasks
|
||||
let max = self.tasks.len() >= MAX_PIPELINED_MESSAGES;
|
||||
while !self.tasks.is_empty() {
|
||||
if self.tasks[0]
|
||||
.flags
|
||||
|
@ -306,10 +325,6 @@ where
|
|||
break;
|
||||
}
|
||||
}
|
||||
// read more message
|
||||
if max && self.tasks.len() >= MAX_PIPELINED_MESSAGES {
|
||||
return Ok(Async::Ready(true));
|
||||
}
|
||||
|
||||
// check stream state
|
||||
if self.flags.contains(Flags::STARTED) {
|
||||
|
@ -332,8 +347,7 @@ where
|
|||
// deal with keep-alive and steam eof (client-side write shutdown)
|
||||
if self.tasks.is_empty() {
|
||||
// handle stream eof
|
||||
if self.flags.contains(Flags::DISCONNECTED) {
|
||||
self.client_disconnect();
|
||||
if self.flags.contains(Flags::READ_DISCONNECTED) {
|
||||
return Ok(Async::Ready(false));
|
||||
}
|
||||
// no keep-alive
|
||||
|
@ -451,7 +465,14 @@ where
|
|||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Ok(None) => {
|
||||
if self.flags.contains(Flags::READ_DISCONNECTED)
|
||||
&& self.tasks.is_empty()
|
||||
{
|
||||
self.client_disconnect();
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
self.flags.insert(Flags::ERROR);
|
||||
if let Some(mut payload) = self.payload.take() {
|
||||
|
@ -606,24 +627,36 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_req_parse() {
|
||||
fn test_req_parse1() {
|
||||
let buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n");
|
||||
let readbuf = BytesMut::new();
|
||||
let settings = Rc::new(wrk_settings());
|
||||
|
||||
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true);
|
||||
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, false);
|
||||
h1.poll_io();
|
||||
h1.poll_io();
|
||||
assert_eq!(h1.tasks.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_req_parse2() {
|
||||
let buf = Buffer::new("");
|
||||
let readbuf =
|
||||
BytesMut::from(Vec::<u8>::from(&b"GET /test HTTP/1.1\r\n\r\n"[..]));
|
||||
let settings = Rc::new(wrk_settings());
|
||||
|
||||
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true);
|
||||
h1.poll_io();
|
||||
assert_eq!(h1.tasks.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_req_parse_err() {
|
||||
let buf = Buffer::new("GET /test HTTP/1\r\n\r\n");
|
||||
let readbuf = BytesMut::new();
|
||||
let settings = Rc::new(wrk_settings());
|
||||
|
||||
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, true);
|
||||
let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf, false);
|
||||
h1.poll_io();
|
||||
h1.poll_io();
|
||||
assert!(h1.flags.contains(Flags::ERROR));
|
||||
|
|
|
@ -63,7 +63,9 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
|
|||
self.flags = Flags::KEEPALIVE;
|
||||
}
|
||||
|
||||
pub fn disconnected(&mut self) {}
|
||||
pub fn disconnected(&mut self) {
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
}
|
||||
|
||||
pub fn keepalive(&self) -> bool {
|
||||
self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE)
|
||||
|
@ -268,10 +270,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
|
|||
let pl: &[u8] = payload.as_ref();
|
||||
let n = match Self::write_data(&mut self.stream, pl) {
|
||||
Err(err) => {
|
||||
if err.kind() == io::ErrorKind::WriteZero {
|
||||
self.disconnected();
|
||||
}
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
Ok(val) => val,
|
||||
|
@ -315,14 +314,15 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
|
|||
|
||||
#[inline]
|
||||
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
|
||||
if self.flags.contains(Flags::DISCONNECTED) {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "disconnected"));
|
||||
}
|
||||
|
||||
if !self.buffer.is_empty() {
|
||||
let written = {
|
||||
match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) {
|
||||
Err(err) => {
|
||||
if err.kind() == io::ErrorKind::WriteZero {
|
||||
self.disconnected();
|
||||
}
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
Ok(val) => val,
|
||||
|
@ -339,7 +339,7 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
|
|||
self.stream.poll_flush()?;
|
||||
self.stream.shutdown()
|
||||
} else {
|
||||
self.stream.poll_flush()
|
||||
Ok(self.stream.poll_flush()?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue