Fix quality parse in Accept-Encoding HTTP header

This commit is contained in:
Arthur LE MOIGNE 2021-07-23 16:40:33 +02:00
parent c50eef6166
commit 73f412519d
No known key found for this signature in database
GPG Key ID: 25D15180A1D2E077
4 changed files with 211 additions and 69 deletions

View File

@ -1,6 +1,7 @@
//! Stream decoders. //! Stream decoders.
use std::{ use std::{
convert::TryFrom,
future::Future, future::Future,
io::{self, Write as _}, io::{self, Write as _},
pin::Pin, pin::Pin,
@ -80,7 +81,7 @@ where
let encoding = headers let encoding = headers
.get(&CONTENT_ENCODING) .get(&CONTENT_ENCODING)
.and_then(|val| val.to_str().ok()) .and_then(|val| val.to_str().ok())
.map(ContentEncoding::from) .and_then(|x| ContentEncoding::try_from(x).ok())
.unwrap_or(ContentEncoding::Identity); .unwrap_or(ContentEncoding::Identity);
Self::new(stream, encoding) Self::new(stream, encoding)

View File

@ -1,4 +1,4 @@
use std::{convert::Infallible, str::FromStr}; use std::{convert::TryFrom, error, fmt, str::FromStr};
use http::header::InvalidHeaderValue; use http::header::InvalidHeaderValue;
@ -8,6 +8,20 @@ use crate::{
HttpMessage, HttpMessage,
}; };
/// Error return when a content encoding is unknown.
///
/// Example: 'compress'
#[derive(Debug)]
pub struct ContentEncodingParseError;
impl fmt::Display for ContentEncodingParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unsupported content encoding")
}
}
impl error::Error for ContentEncodingParseError {}
/// Represents a supported content encoding. /// Represents a supported content encoding.
#[derive(Copy, Clone, PartialEq, Debug)] #[derive(Copy, Clone, PartialEq, Debug)]
pub enum ContentEncoding { pub enum ContentEncoding {
@ -48,18 +62,6 @@ impl ContentEncoding {
ContentEncoding::Identity | ContentEncoding::Auto => "identity", ContentEncoding::Identity | ContentEncoding::Auto => "identity",
} }
} }
/// Default Q-factor (quality) value.
#[inline]
pub fn quality(self) -> f64 {
match self {
ContentEncoding::Br => 1.1,
ContentEncoding::Gzip => 1.0,
ContentEncoding::Deflate => 0.9,
ContentEncoding::Identity | ContentEncoding::Auto => 0.1,
ContentEncoding::Zstd => 0.0,
}
}
} }
impl Default for ContentEncoding { impl Default for ContentEncoding {
@ -69,27 +71,29 @@ impl Default for ContentEncoding {
} }
impl FromStr for ContentEncoding { impl FromStr for ContentEncoding {
type Err = Infallible; type Err = ContentEncodingParseError;
fn from_str(val: &str) -> Result<Self, Self::Err> { fn from_str(val: &str) -> Result<Self, Self::Err> {
Ok(Self::from(val)) Self::try_from(val)
} }
} }
impl From<&str> for ContentEncoding { impl TryFrom<&str> for ContentEncoding {
fn from(val: &str) -> ContentEncoding { type Error = ContentEncodingParseError;
fn try_from(val: &str) -> Result<Self, Self::Error> {
let val = val.trim(); let val = val.trim();
if val.eq_ignore_ascii_case("br") { if val.eq_ignore_ascii_case("br") {
ContentEncoding::Br Ok(ContentEncoding::Br)
} else if val.eq_ignore_ascii_case("gzip") { } else if val.eq_ignore_ascii_case("gzip") {
ContentEncoding::Gzip Ok(ContentEncoding::Gzip)
} else if val.eq_ignore_ascii_case("deflate") { } else if val.eq_ignore_ascii_case("deflate") {
ContentEncoding::Deflate Ok(ContentEncoding::Deflate)
} else if val.eq_ignore_ascii_case("zstd") { } else if val.eq_ignore_ascii_case("zstd") {
ContentEncoding::Zstd Ok(ContentEncoding::Zstd)
} else { } else {
ContentEncoding::default() Err(ContentEncodingParseError)
} }
} }
} }

View File

@ -2,27 +2,27 @@
use std::{ use std::{
cmp, cmp,
convert::TryFrom,
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
pin::Pin, pin::Pin,
str::FromStr,
task::{Context, Poll}, task::{Context, Poll},
}; };
use actix_http::{ use actix_http::{
body::{MessageBody, ResponseBody}, body::{AnyBody, ResponseBody},
encoding::Encoder, encoding::Encoder,
http::header::{ContentEncoding, ACCEPT_ENCODING}, http::header::{ContentEncoding, ACCEPT_ENCODING},
}; };
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use actix_utils::future::{ok, Ready}; use actix_utils::future::{ok, Either, Ready};
use futures_core::ready; use futures_core::ready;
use pin_project::pin_project; use pin_project::pin_project;
use crate::{ use crate::{
dev::BodyEncoding, dev::BodyEncoding,
service::{ServiceRequest, ServiceResponse}, service::{ServiceRequest, ServiceResponse},
Error, Error, HttpResponse,
}; };
/// Middleware for compressing response payloads. /// Middleware for compressing response payloads.
@ -54,12 +54,11 @@ impl Default for Compress {
} }
} }
impl<S, B> Transform<S, ServiceRequest> for Compress impl<S> Transform<S, ServiceRequest> for Compress
where where
B: MessageBody, S: Service<ServiceRequest, Response = ServiceResponse<AnyBody>, Error = Error>,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
{ {
type Response = ServiceResponse<ResponseBody<Encoder<B>>>; type Response = ServiceResponse<ResponseBody<Encoder<AnyBody>>>;
type Error = Error; type Error = Error;
type Transform = CompressMiddleware<S>; type Transform = CompressMiddleware<S>;
type InitError = (); type InitError = ();
@ -78,56 +77,69 @@ pub struct CompressMiddleware<S> {
encoding: ContentEncoding, encoding: ContentEncoding,
} }
impl<S, B> Service<ServiceRequest> for CompressMiddleware<S> impl<S> Service<ServiceRequest> for CompressMiddleware<S>
where where
B: MessageBody, S: Service<ServiceRequest, Response = ServiceResponse<AnyBody>, Error = Error>,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
{ {
type Response = ServiceResponse<ResponseBody<Encoder<B>>>; type Response = ServiceResponse<ResponseBody<Encoder<AnyBody>>>;
type Error = Error; type Error = Error;
type Future = CompressResponse<S, B>; type Future = Either<CompressResponse<S>, Ready<Result<Self::Response, Self::Error>>>;
actix_service::forward_ready!(service); actix_service::forward_ready!(service);
#[allow(clippy::borrow_interior_mutable_const)] #[allow(clippy::borrow_interior_mutable_const)]
fn call(&self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
// negotiate content-encoding // negotiate content-encoding
let encoding = if let Some(val) = req.headers().get(&ACCEPT_ENCODING) { let encoding_result = req
if let Ok(enc) = val.to_str() { .headers()
AcceptEncoding::parse(enc, self.encoding) .get(&ACCEPT_ENCODING)
} else { .and_then(|val| val.to_str().ok())
ContentEncoding::Identity .map(|enc| AcceptEncoding::try_parse(enc, self.encoding));
}
} else {
ContentEncoding::Identity
};
CompressResponse { match encoding_result {
encoding, // Missing header => fallback to identity
fut: self.service.call(req), None => Either::left(CompressResponse {
_phantom: PhantomData, encoding: ContentEncoding::Identity,
fut: self.service.call(req),
_phantom: PhantomData,
}),
// Valid encoding
Some(Ok(encoding)) => Either::left(CompressResponse {
encoding,
fut: self.service.call(req),
_phantom: PhantomData,
}),
// There is an HTTP header but we cannot match what client as asked for
Some(Err(_)) => {
let res = HttpResponse::NotAcceptable().finish();
let enc = ContentEncoding::Identity;
Either::right(ok(req.into_response(res.map_body(move |head, body| {
Encoder::response(enc, head, ResponseBody::Body(body))
}))))
}
} }
} }
} }
#[pin_project] #[pin_project]
pub struct CompressResponse<S, B> pub struct CompressResponse<S>
where where
S: Service<ServiceRequest>, S: Service<ServiceRequest>,
B: MessageBody,
{ {
#[pin] #[pin]
fut: S::Future, fut: S::Future,
encoding: ContentEncoding, encoding: ContentEncoding,
_phantom: PhantomData<B>, _phantom: PhantomData<AnyBody>,
} }
impl<S, B> Future for CompressResponse<S, B> impl<S> Future for CompressResponse<S>
where where
B: MessageBody, S: Service<ServiceRequest, Response = ServiceResponse<AnyBody>, Error = Error>,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
{ {
type Output = Result<ServiceResponse<ResponseBody<Encoder<B>>>, Error>; type Output = Result<ServiceResponse<ResponseBody<Encoder<AnyBody>>>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();
@ -177,26 +189,56 @@ impl PartialOrd for AcceptEncoding {
impl PartialEq for AcceptEncoding { impl PartialEq for AcceptEncoding {
fn eq(&self, other: &AcceptEncoding) -> bool { fn eq(&self, other: &AcceptEncoding) -> bool {
self.quality == other.quality self.quality == other.quality && self.encoding == other.encoding
} }
} }
/// Parse qfactor from HTTP header
///
/// If parse fail, then fallback to default value which is 1
/// More details available here: https://developer.mozilla.org/en-US/docs/Glossary/Quality_values
fn parse_quality(parts: &[&str]) -> f64 {
for part in parts {
if part.starts_with("q=") {
return part[2..].parse().unwrap_or(1.0);
}
}
1.0
}
#[derive(Debug, PartialEq, Eq)]
enum AcceptEncodingError {
/// This error occurs when client only support compressed response
/// and server do not have any algorithm that match client accepted
/// algorithms
CompressionAlgorithmMismatch,
}
impl AcceptEncoding { impl AcceptEncoding {
fn new(tag: &str) -> Option<AcceptEncoding> { fn new(tag: &str) -> Option<AcceptEncoding> {
let parts: Vec<&str> = tag.split(';').collect(); let parts: Vec<&str> = tag.split(';').collect();
let encoding = match parts.len() { let encoding = match parts.len() {
0 => return None, 0 => return None,
_ => ContentEncoding::from(parts[0]), _ => match ContentEncoding::try_from(parts[0]) {
}; Err(_) => return None,
let quality = match parts.len() { Ok(x) => x,
1 => encoding.quality(), },
_ => f64::from_str(parts[1]).unwrap_or(0.0),
}; };
let quality = parse_quality(&parts[1..]);
if quality <= 0.0 || quality > 1.0 {
return None;
}
Some(AcceptEncoding { encoding, quality }) Some(AcceptEncoding { encoding, quality })
} }
/// Parse a raw Accept-Encoding header value into an ordered list. /// Parse a raw Accept-Encoding header value into an ordered list
pub fn parse(raw: &str, encoding: ContentEncoding) -> ContentEncoding { /// then return the best match based on middleware configuration.
pub fn try_parse(
raw: &str,
encoding: ContentEncoding,
) -> Result<ContentEncoding, AcceptEncodingError> {
let mut encodings = raw let mut encodings = raw
.replace(' ', "") .replace(' ', "")
.split(',') .split(',')
@ -206,13 +248,89 @@ impl AcceptEncoding {
encodings.sort(); encodings.sort();
for enc in encodings { for enc in encodings {
if encoding == ContentEncoding::Auto { if encoding == ContentEncoding::Auto || encoding == enc.encoding {
return enc.encoding; return Ok(enc.encoding);
} else if encoding == enc.encoding {
return encoding;
} }
} }
ContentEncoding::Identity // Special case if user cannot accept uncompressed data
// See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
if raw.contains("*;q=0") || raw.contains("identity;q=0") {
return Err(AcceptEncodingError::CompressionAlgorithmMismatch);
}
Ok(ContentEncoding::Identity)
}
}
#[cfg(test)]
mod tests {
use super::*;
macro_rules! assert_parse_eq {
($raw:expr, $result:expr) => {
assert_eq!(
AcceptEncoding::try_parse($raw, ContentEncoding::Auto),
Ok($result)
);
};
}
macro_rules! assert_parse_fail {
($raw:expr) => {
assert!(AcceptEncoding::try_parse($raw, ContentEncoding::Auto).is_err());
};
}
#[test]
fn test_parse_encoding() {
// Test simple case
assert_parse_eq!("br", ContentEncoding::Br);
assert_parse_eq!("gzip", ContentEncoding::Gzip);
assert_parse_eq!("deflate", ContentEncoding::Deflate);
assert_parse_eq!("zstd", ContentEncoding::Zstd);
// Test space, trim, missing values
assert_parse_eq!("br,,,,", ContentEncoding::Br);
assert_parse_eq!("gzip , br, zstd", ContentEncoding::Gzip);
// Test float number parsing
assert_parse_eq!("br;q=1 ,", ContentEncoding::Br);
assert_parse_eq!("br;q=1.0 , br", ContentEncoding::Br);
// Test wildcard
assert_parse_eq!("*", ContentEncoding::Identity);
assert_parse_eq!("*;q=1.0", ContentEncoding::Identity);
}
#[test]
fn test_parse_encoding_qfactor_ordering() {
assert_parse_eq!("gzip, br, zstd", ContentEncoding::Gzip);
assert_parse_eq!("zstd, br, gzip", ContentEncoding::Zstd);
assert_parse_eq!("gzip;q=0.4, br;q=0.6", ContentEncoding::Br);
assert_parse_eq!("gzip;q=0.8, br;q=0.4", ContentEncoding::Gzip);
}
#[test]
fn test_parse_encoding_qfactor_invalid() {
// Out of range
assert_parse_eq!("gzip;q=-5.0", ContentEncoding::Identity);
assert_parse_eq!("gzip;q=5.0", ContentEncoding::Identity);
// Disabled
assert_parse_eq!("gzip;q=0", ContentEncoding::Identity);
}
#[test]
fn test_parse_compression_required() {
// Check we fallback to identity if there is an unsuported compression algorithm
assert_parse_eq!("compress", ContentEncoding::Identity);
// User do not want any compression
assert_parse_fail!("compress, identity;q=0");
assert_parse_fail!("compress, identity;q=0.0");
assert_parse_fail!("compress, *;q=0");
assert_parse_fail!("compress, *;q=0.0");
} }
} }

View File

@ -1077,3 +1077,22 @@ async fn test_data_drop() {
assert_eq!(num.load(Ordering::SeqCst), 0); assert_eq!(num.load(Ordering::SeqCst), 0);
} }
#[actix_rt::test]
async fn test_accept_encoding_no_match() {
let srv = actix_test::start_with(actix_test::config().h1(), || {
App::new()
.wrap(Compress::default())
.service(web::resource("/").route(web::to(move || HttpResponse::Ok().finish())))
});
let response = srv
.get("/")
.append_header((ACCEPT_ENCODING, "compress, identity;q=0"))
.no_decompress()
.send()
.await
.unwrap();
assert_eq!(response.status().as_u16(), 406);
}