From b9d6bbd35752cd301726650c4cb6166d130bff2f Mon Sep 17 00:00:00 2001
From: Niklas Fiekas <niklas.fiekas@backscattering.de>
Date: Wed, 7 Mar 2018 17:49:30 +0100
Subject: [PATCH] filter cross-site upgrades in csrf middleware

---
 src/middleware/csrf.rs | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs
index 3177cfde..d9d4692b 100644
--- a/src/middleware/csrf.rs
+++ b/src/middleware/csrf.rs
@@ -116,6 +116,7 @@ pub struct CsrfFilter {
     origins: HashSet<String>,
     allow_xhr: bool,
     allow_missing_origin: bool,
+    allow_upgrade: bool,
 }
 
 impl CsrfFilter {
@@ -126,12 +127,16 @@ impl CsrfFilter {
                 origins: HashSet::new(),
                 allow_xhr: false,
                 allow_missing_origin: false,
+                allow_upgrade: false,
             }
         }
     }
 
     fn validate<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CsrfError> {
-        if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
+        let is_upgrade = req.headers().contains_key(header::UPGRADE);
+        let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
+
+        if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
             Ok(())
         } else if let Some(header) = origin(req.headers()) {
             match header {
@@ -213,6 +218,12 @@ impl CsrfFilterBuilder {
         self
     }
 
+    /// Allow cross-site upgrade requests (for example to open a WebSocket).
+    pub fn allow_upgrade(mut self) -> CsrfFilterBuilder {
+        self.csrf.allow_upgrade = true;
+        self
+    }
+
     /// Finishes building the `CsrfFilter` instance.
     pub fn finish(self) -> CsrfFilter {
         self.csrf