From 6c480fae90b9515c0615ad009deae09183b88375 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Tue, 27 Feb 2018 11:31:54 -0800
Subject: [PATCH] added HttpRequest::encoding() method; fix urlencoded parsing
 with charset

---
 CHANGES.md                 |   2 +
 Cargo.toml                 |   3 +-
 examples/state/src/main.rs |   5 +-
 guide/src/qs_8.md          |   3 +-
 src/error.rs               |  22 ++++++-
 src/httprequest.rs         | 119 ++++++++++++++++++++++++++++++-------
 src/lib.rs                 |   1 +
 7 files changed, 127 insertions(+), 28 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index a406bdfa..04e1aed8 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -8,6 +8,8 @@
 
 * Simplify HttpServer type definition
 
+* Added HttpRequest::encoding() method
+
 * Added HttpRequest::mime_type() method
 
 * Added HttpRequest::uri_mut(), allows to modify request uri
diff --git a/Cargo.toml b/Cargo.toml
index b7999b74..bdafe012 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,7 +56,8 @@ serde_json = "1.0"
 sha1 = "0.4"
 smallvec = "0.6"
 time = "0.1"
-url = "1.6"
+encoding = "0.2"
+url = { version="1.7", features=["query_encoding"] }
 cookie = { version="0.10", features=["percent-encode", "secure"] }
 
 # io
diff --git a/examples/state/src/main.rs b/examples/state/src/main.rs
index a981c7fb..f40f779e 100644
--- a/examples/state/src/main.rs
+++ b/examples/state/src/main.rs
@@ -36,8 +36,7 @@ impl Actor for MyWebSocket {
     type Context = ws::WebsocketContext<Self, AppState>;
 }
 
-impl Handler<ws::Message> for MyWebSocket {
-    type Result = ();
+impl StreamHandler<ws::Message, ws::WsError> for MyWebSocket {
 
     fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
         self.counter += 1;
@@ -46,7 +45,7 @@ impl Handler<ws::Message> for MyWebSocket {
             ws::Message::Ping(msg) => ctx.pong(&msg),
             ws::Message::Text(text) => ctx.text(text),
             ws::Message::Binary(bin) => ctx.binary(bin),
-            ws::Message::Close(_) | ws::Message::Error => {
+            ws::Message::Close(_) => {
                 ctx.stop();
             }
             _ => (),
diff --git a/guide/src/qs_8.md b/guide/src/qs_8.md
index 2e2b5420..b19e94a4 100644
--- a/guide/src/qs_8.md
+++ b/guide/src/qs_8.md
@@ -130,8 +130,7 @@ impl Actor for Ws {
     type Context = ws::WebsocketContext<Self>;
 }
 
-impl Handler<ws::Message> for Ws {
-    type Result = ();
+impl StreamHandler<ws::Message, ws::WsError> for Ws {
 
     fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
         match msg {
diff --git a/src/error.rs b/src/error.rs
index 6c50db25..d0497073 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -335,12 +335,29 @@ pub enum ExpectError {
 }
 
 impl ResponseError for ExpectError {
-
     fn error_response(&self) -> HttpResponse {
         HTTPExpectationFailed.with_body("Unknown Expect")
     }
 }
 
+/// A set of errors that can occur during parsing content type
+#[derive(Fail, PartialEq, Debug)]
+pub enum ContentTypeError {
+    /// Can not parse content type
+    #[fail(display="Can not parse content type")]
+    ParseError,
+    /// Unknown content encoding
+    #[fail(display="Unknown content encoding")]
+    UnknownEncoding,
+}
+
+/// Return `BadRequest` for `ContentTypeError`
+impl ResponseError for ContentTypeError {
+    fn error_response(&self) -> HttpResponse {
+        HttpResponse::new(StatusCode::BAD_REQUEST, Body::Empty)
+    }
+}
+
 /// A set of errors that can occur during parsing urlencoded payloads
 #[derive(Fail, Debug)]
 pub enum UrlencodedError {
@@ -356,6 +373,9 @@ pub enum UrlencodedError {
     /// Content type error
     #[fail(display="Content type error")]
     ContentType,
+    /// Parse error
+    #[fail(display="Parse error")]
+    Parse,
     /// Payload error
     #[fail(display="Error that occur during reading payload: {}", _0)]
     Payload(#[cause] PayloadError),
diff --git a/src/httprequest.rs b/src/httprequest.rs
index aa8df4f5..ca70b6ed 100644
--- a/src/httprequest.rs
+++ b/src/httprequest.rs
@@ -11,6 +11,9 @@ use serde::de::DeserializeOwned;
 use mime::Mime;
 use failure;
 use url::{Url, form_urlencoded};
+use encoding::all::UTF_8;
+use encoding::EncodingRef;
+use encoding::label::encoding_from_whatwg_label;
 use http::{header, Uri, Method, Version, HeaderMap, Extensions};
 use tokio_io::AsyncRead;
 
@@ -21,7 +24,7 @@ use payload::Payload;
 use json::JsonBody;
 use multipart::Multipart;
 use helpers::SharedHttpMessage;
-use error::{ParseError, UrlGenerationError,
+use error::{ParseError, ContentTypeError, UrlGenerationError,
             CookieParseError, HttpRangeError, PayloadError, UrlencodedError};
 
 
@@ -389,17 +392,38 @@ impl<S> HttpRequest<S> {
         ""
     }
 
+    /// Get content type encoding
+    ///
+    /// UTF-8 is used by default if the request charset is not set.
+    pub fn encoding(&self) -> Result<EncodingRef, ContentTypeError> {
+        if let Some(mime_type) = self.mime_type()? {
+            if let Some(charset) = mime_type.get_param("charset") {
+                if let Some(enc) = encoding_from_whatwg_label(charset.as_str()) {
+                    Ok(enc)
+                } else {
+                    Err(ContentTypeError::UnknownEncoding)
+                }
+            } else {
+                Ok(UTF_8)
+            }
+        } else {
+            Ok(UTF_8)
+        }
+    }
+
     /// Convert the request content type to a known mime type.
-    pub fn mime_type(&self) -> Option<Mime> {
+    pub fn mime_type(&self) -> Result<Option<Mime>, ContentTypeError> {
         if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) {
             if let Ok(content_type) = content_type.to_str() {
                 return match content_type.parse() {
-                    Ok(mt) => Some(mt),
-                    Err(_) => None
+                    Ok(mt) => Ok(Some(mt)),
+                    Err(_) => Err(ContentTypeError::ParseError),
                 };
+            } else {
+                return Err(ContentTypeError::ParseError)
             }
         }
-        None
+        Ok(None)
     }
 
     /// Check if request requires connection upgrade
@@ -722,17 +746,10 @@ impl Future for UrlEncoded {
             }
 
             // check content type
-            let mut err = true;
-            if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
-                if let Ok(content_type) = content_type.to_str() {
-                    if content_type.to_lowercase() == "application/x-www-form-urlencoded" {
-                        err = false;
-                    }
-                }
-            }
-            if err {
-                return Err(UrlencodedError::ContentType);
+            if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" {
+                return Err(UrlencodedError::ContentType)
             }
+            let encoding = req.encoding().map_err(|_| UrlencodedError::ContentType)?;
 
             // future
             let limit = self.limit;
@@ -745,12 +762,14 @@ impl Future for UrlEncoded {
                         Ok(body)
                     }
                 })
-                .map(|body| {
+                .and_then(move |body| {
                     let mut m = HashMap::new();
-                    for (k, v) in form_urlencoded::parse(&body) {
+                    let parsed = form_urlencoded::parse_with_encoding(
+                        &body, Some(encoding), false).map_err(|_| UrlencodedError::Parse)?;
+                    for (k, v) in parsed {
                         m.insert(k.into(), v.into());
                     }
-                    m
+                    Ok(m)
                 });
             self.fut = Some(Box::new(fut));
         }
@@ -828,8 +847,11 @@ impl Future for RequestBody {
 mod tests {
     use super::*;
     use mime;
+    use encoding::Encoding;
+    use encoding::all::ISO_8859_2;
     use http::{Uri, HttpTryFrom};
     use std::str::FromStr;
+    use std::iter::FromIterator;
     use router::Pattern;
     use resource::Resource;
     use test::TestRequest;
@@ -856,17 +878,49 @@ mod tests {
     #[test]
     fn test_mime_type() {
         let req = TestRequest::with_header("content-type", "application/json").finish();
-        assert_eq!(req.mime_type(), Some(mime::APPLICATION_JSON));
+        assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON));
         let req = HttpRequest::default();
-        assert_eq!(req.mime_type(), None);
+        assert_eq!(req.mime_type().unwrap(), None);
         let req = TestRequest::with_header(
             "content-type", "application/json; charset=utf-8").finish();
-        let mt = req.mime_type().unwrap();
+        let mt = req.mime_type().unwrap().unwrap();
         assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8));
         assert_eq!(mt.type_(), mime::APPLICATION);
         assert_eq!(mt.subtype(), mime::JSON);
     }
 
+    #[test]
+    fn test_mime_type_error() {
+        let req = TestRequest::with_header(
+            "content-type", "applicationadfadsfasdflknadsfklnadsfjson").finish();
+        assert_eq!(Err(ContentTypeError::ParseError), req.mime_type());
+    }
+
+    #[test]
+    fn test_encoding() {
+        let req = HttpRequest::default();
+        assert_eq!(UTF_8.name(), req.encoding().unwrap().name());
+
+        let req = TestRequest::with_header(
+            "content-type", "application/json").finish();
+        assert_eq!(UTF_8.name(), req.encoding().unwrap().name());
+
+        let req = TestRequest::with_header(
+            "content-type", "application/json; charset=ISO-8859-2").finish();
+        assert_eq!(ISO_8859_2.name(), req.encoding().unwrap().name());
+    }
+
+    #[test]
+    fn test_encoding_error() {
+        let req = TestRequest::with_header(
+            "content-type", "applicatjson").finish();
+        assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err());
+
+        let req = TestRequest::with_header(
+            "content-type", "application/json; charset=kkkttktk").finish();
+        assert_eq!(Some(ContentTypeError::UnknownEncoding), req.encoding().err());
+    }
+
     #[test]
     fn test_uri_mut() {
         let mut req = HttpRequest::default();
@@ -1009,6 +1063,29 @@ mod tests {
         assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::ContentType);
     }
 
+    #[test]
+    fn test_urlencoded() {
+        let mut req = TestRequest::with_header(
+            header::CONTENT_TYPE, "application/x-www-form-urlencoded")
+            .header(header::CONTENT_LENGTH, "11")
+            .finish();
+        req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
+
+        let result = req.urlencoded().poll().ok().unwrap();
+        assert_eq!(result, Async::Ready(
+            HashMap::from_iter(vec![("hello".to_owned(), "world".to_owned())])));
+
+        let mut req = TestRequest::with_header(
+            header::CONTENT_TYPE, "application/x-www-form-urlencoded; charset=utf-8")
+            .header(header::CONTENT_LENGTH, "11")
+            .finish();
+        req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
+
+        let result = req.urlencoded().poll().ok().unwrap();
+        assert_eq!(result, Async::Ready(
+            HashMap::from_iter(vec![("hello".to_owned(), "world".to_owned())])));
+}
+
     #[test]
     fn test_request_body() {
         let req = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx").finish();
diff --git a/src/lib.rs b/src/lib.rs
index 6221afb9..91ec9e94 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -77,6 +77,7 @@ extern crate serde;
 extern crate serde_json;
 extern crate flate2;
 extern crate brotli2;
+extern crate encoding;
 extern crate percent_encoding;
 extern crate smallvec;
 extern crate num_cpus;