Correct unclear code (#1611)

`content_length = None` doesn't work if Upgrade headers may appear
before Content-Length headers.

As described in the comment at L224, it's intended to use
PayloadLength::Upgrade only for `Upgrade: websocket`.
This commit is contained in:
masnagam 2020-07-20 20:32:41 +09:00
parent 971ba3eee1
commit 7257ac8487
1 changed files with 6 additions and 9 deletions

View File

@ -46,7 +46,7 @@ impl<T: MessageType> Decoder for MessageDecoder<T> {
pub(crate) enum PayloadLength { pub(crate) enum PayloadLength {
Payload(PayloadType), Payload(PayloadType),
Upgrade, UpgradeWebSocket,
None, None,
} }
@ -65,7 +65,7 @@ pub(crate) trait MessageType: Sized {
raw_headers: &[HeaderIndex], raw_headers: &[HeaderIndex],
) -> Result<PayloadLength, ParseError> { ) -> Result<PayloadLength, ParseError> {
let mut ka = None; let mut ka = None;
let mut has_upgrade = false; let mut has_upgrade_websocket = false;
let mut expect = false; let mut expect = false;
let mut chunked = false; let mut chunked = false;
let mut content_length = None; let mut content_length = None;
@ -124,12 +124,9 @@ pub(crate) trait MessageType: Sized {
}; };
} }
header::UPGRADE => { header::UPGRADE => {
has_upgrade = true;
// check content-length, some clients (dart)
// sends "content-length: 0" with websocket upgrade
if let Ok(val) = value.to_str().map(|val| val.trim()) { if let Ok(val) = value.to_str().map(|val| val.trim()) {
if val.eq_ignore_ascii_case("websocket") { if val.eq_ignore_ascii_case("websocket") {
content_length = None; has_upgrade_websocket = true;
} }
} }
} }
@ -156,13 +153,13 @@ pub(crate) trait MessageType: Sized {
Ok(PayloadLength::Payload(PayloadType::Payload( Ok(PayloadLength::Payload(PayloadType::Payload(
PayloadDecoder::chunked(), PayloadDecoder::chunked(),
))) )))
} else if has_upgrade_websocket {
Ok(PayloadLength::UpgradeWebSocket)
} else if let Some(len) = content_length { } else if let Some(len) = content_length {
// Content-Length // Content-Length
Ok(PayloadLength::Payload(PayloadType::Payload( Ok(PayloadLength::Payload(PayloadType::Payload(
PayloadDecoder::length(len), PayloadDecoder::length(len),
))) )))
} else if has_upgrade {
Ok(PayloadLength::Upgrade)
} else { } else {
Ok(PayloadLength::None) Ok(PayloadLength::None)
} }
@ -222,7 +219,7 @@ impl MessageType for Request {
// payload decoder // payload decoder
let decoder = match length { let decoder = match length {
PayloadLength::Payload(pl) => pl, PayloadLength::Payload(pl) => pl,
PayloadLength::Upgrade => { PayloadLength::UpgradeWebSocket => {
// upgrade(websocket) // upgrade(websocket)
PayloadType::Stream(PayloadDecoder::eof()) PayloadType::Stream(PayloadDecoder::eof())
} }