From 080f232a0f675e0f9eb13cba8f57133008c7f3b5 Mon Sep 17 00:00:00 2001
From: Tessa Bradbury <teresa.bradbury@bugcrowd.com>
Date: Thu, 5 Jul 2018 19:34:13 +1000
Subject: [PATCH] Use StaticFile default handler when file is inaccessible
 (#357)

* Use Staticfile default handler on all error paths

* Return an error from StaticFiles::new() if directory doesn't exist
---
 CHANGES.md         |   4 ++
 src/application.rs |   2 +-
 src/error.rs       |  18 ++++++
 src/fs.rs          | 158 +++++++++++++++++++++++----------------------
 4 files changed, 103 insertions(+), 79 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index fa18d898..b77ae686 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -49,6 +49,10 @@
 
 * `HttpRequest::cookies()` returns `Ref<Vec<Cookie<'static>>>`
 
+* `StaticFiles::new()` returns `Result<StaticFiles<S>, Error>` instead of `StaticFiles<S>`
+
+* `StaticFiles` uses the default handler if the file does not exist
+
 
 ### Removed
 
diff --git a/src/application.rs b/src/application.rs
index aa9a74d1..de474dfc 100644
--- a/src/application.rs
+++ b/src/application.rs
@@ -587,7 +587,7 @@ where
     ///     let app = App::new()
     ///         .middleware(middleware::Logger::default())
     ///         .configure(config)  // <- register resources
-    ///         .handler("/static", fs::StaticFiles::new("."));
+    ///         .handler("/static", fs::StaticFiles::new(".").unwrap());
     /// }
     /// ```
     pub fn configure<F>(self, cfg: F) -> App<S>
diff --git a/src/error.rs b/src/error.rs
index 129de76b..c024561f 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -626,6 +626,24 @@ impl From<UrlParseError> for UrlGenerationError {
     }
 }
 
+/// Errors which can occur when serving static files.
+#[derive(Fail, Debug, PartialEq)]
+pub enum StaticFileError {
+    /// Path is not a directory
+    #[fail(display = "Path is not a directory. Unable to serve static files")]
+    IsNotDirectory,
+    /// Cannot render directory
+    #[fail(display = "Unable to render directory without index file")]
+    IsDirectory,
+}
+
+/// Return `NotFound` for `StaticFileError`
+impl ResponseError for StaticFileError {
+    fn error_response(&self) -> HttpResponse {
+        HttpResponse::new(StatusCode::NOT_FOUND)
+    }
+}
+
 /// Helper type that can wrap any error and generate custom response.
 ///
 /// In following example any `io::Error` will be converted into "BAD REQUEST"
diff --git a/src/fs.rs b/src/fs.rs
index b6ba4706..f42ef2e4 100644
--- a/src/fs.rs
+++ b/src/fs.rs
@@ -18,7 +18,7 @@ use mime;
 use mime_guess::{get_mime_type, guess_mime_type};
 use percent_encoding::{utf8_percent_encode, DEFAULT_ENCODE_SET};
 
-use error::Error;
+use error::{Error, StaticFileError};
 use handler::{AsyncResult, Handler, Responder, RouteHandler, WrapHandler};
 use header;
 use http::{ContentEncoding, Method, StatusCode};
@@ -555,13 +555,12 @@ fn directory_listing<S>(
 ///
 /// fn main() {
 ///     let app = App::new()
-///         .handler("/static", fs::StaticFiles::new("."))
+///         .handler("/static", fs::StaticFiles::new(".").unwrap())
 ///         .finish();
 /// }
 /// ```
 pub struct StaticFiles<S> {
     directory: PathBuf,
-    accessible: bool,
     index: Option<String>,
     show_index: bool,
     cpu_pool: CpuPool,
@@ -577,7 +576,7 @@ impl<S: 'static> StaticFiles<S> {
     /// `StaticFile` uses `CpuPool` for blocking filesystem operations.
     /// By default pool with 20 threads is used.
     /// Pool size can be changed by setting ACTIX_CPU_POOL environment variable.
-    pub fn new<T: Into<PathBuf>>(dir: T) -> StaticFiles<S> {
+    pub fn new<T: Into<PathBuf>>(dir: T) -> Result<StaticFiles<S>, Error> {
         // use default CpuPool
         let pool = { DEFAULT_CPUPOOL.lock().clone() };
 
@@ -586,27 +585,15 @@ impl<S: 'static> StaticFiles<S> {
 
     /// Create new `StaticFiles` instance for specified base directory and
     /// `CpuPool`.
-    pub fn with_pool<T: Into<PathBuf>>(dir: T, pool: CpuPool) -> StaticFiles<S> {
-        let dir = dir.into();
+    pub fn with_pool<T: Into<PathBuf>>(dir: T, pool: CpuPool) -> Result<StaticFiles<S>, Error> {
+        let dir = dir.into().canonicalize()?;
 
-        let (dir, access) = match dir.canonicalize() {
-            Ok(dir) => {
-                if dir.is_dir() {
-                    (dir, true)
-                } else {
-                    warn!("Is not directory `{:?}`", dir);
-                    (dir, false)
-                }
-            }
-            Err(err) => {
-                warn!("Static files directory `{:?}` error: {}", dir, err);
-                (dir, false)
-            }
-        };
+        if !dir.is_dir() {
+            return Err(StaticFileError::IsNotDirectory.into())
+        }
 
-        StaticFiles {
+        Ok(StaticFiles {
             directory: dir,
-            accessible: access,
             index: None,
             show_index: false,
             cpu_pool: pool,
@@ -616,7 +603,7 @@ impl<S: 'static> StaticFiles<S> {
             renderer: Box::new(directory_listing),
             _chunk_size: 0,
             _follow_symlinks: false,
-        }
+        })
     }
 
     /// Show files listing for directories.
@@ -652,56 +639,51 @@ impl<S: 'static> StaticFiles<S> {
         self.default = Box::new(WrapHandler::new(handler));
         self
     }
+
+    fn try_handle(&self, req: &HttpRequest<S>) -> Result<AsyncResult<HttpResponse>, Error> {
+        let tail: String = req.match_info().query("tail")?;
+        let relpath = PathBuf::from_param(tail.trim_left_matches('/'))?;
+
+        // full filepath
+        let path = self.directory.join(&relpath).canonicalize()?;
+
+        if path.is_dir() {
+            if let Some(ref redir_index) = self.index {
+                // TODO: Don't redirect, just return the index content.
+                // TODO: It'd be nice if there were a good usable URL manipulation
+                // library
+                let mut new_path: String = req.path().to_owned();
+                if !new_path.ends_with('/') {
+                    new_path.push('/');
+                }
+                new_path.push_str(redir_index);
+                HttpResponse::Found()
+                    .header(header::LOCATION, new_path.as_str())
+                    .finish()
+                    .respond_to(&req)
+            } else if self.show_index {
+                let dir = Directory::new(self.directory.clone(), path);
+                Ok((*self.renderer)(&dir, &req)?.into())
+            } else {
+                Err(StaticFileError::IsDirectory.into())
+            }
+        } else {
+            NamedFile::open(path)?
+                .set_cpu_pool(self.cpu_pool.clone())
+                .respond_to(&req)?
+                .respond_to(&req)
+        }
+    }
 }
 
 impl<S: 'static> Handler<S> for StaticFiles<S> {
     type Result = Result<AsyncResult<HttpResponse>, Error>;
 
     fn handle(&self, req: &HttpRequest<S>) -> Self::Result {
-        if !self.accessible {
+        self.try_handle(req).or_else(|e| {
+            debug!("StaticFiles: Failed to handle {}: {}", req.path(), e);
             Ok(self.default.handle(req))
-        } else {
-            let relpath = match req
-                .match_info()
-                .get("tail")
-                .map(|tail| PathBuf::from_param(tail.trim_left_matches('/')))
-            {
-                Some(Ok(path)) => path,
-                _ => {
-                    return Ok(self.default.handle(req));
-                }
-            };
-
-            // full filepath
-            let path = self.directory.join(&relpath).canonicalize()?;
-
-            if path.is_dir() {
-                if let Some(ref redir_index) = self.index {
-                    // TODO: Don't redirect, just return the index content.
-                    // TODO: It'd be nice if there were a good usable URL manipulation
-                    // library
-                    let mut new_path: String = req.path().to_owned();
-                    if !new_path.ends_with('/') {
-                        new_path.push('/');
-                    }
-                    new_path.push_str(redir_index);
-                    HttpResponse::Found()
-                        .header(header::LOCATION, new_path.as_str())
-                        .finish()
-                        .respond_to(&req)
-                } else if self.show_index {
-                    let dir = Directory::new(self.directory.clone(), path);
-                    Ok((*self.renderer)(&dir, &req)?.into())
-                } else {
-                    Ok(self.default.handle(req))
-                }
-            } else {
-                NamedFile::open(path)?
-                    .set_cpu_pool(self.cpu_pool.clone())
-                    .respond_to(&req)?
-                    .respond_to(&req)
-            }
-        }
+        })
     }
 }
 
@@ -806,6 +788,7 @@ mod tests {
 
     use super::*;
     use application::App;
+    use body::{Binary, Body};
     use http::{header, Method, StatusCode};
     use test::{self, TestRequest};
 
@@ -988,7 +971,7 @@ mod tests {
     #[test]
     fn test_named_file_ranges_status_code() {
         let mut srv = test::TestServer::with_factory(|| {
-            App::new().handler("test", StaticFiles::new(".").index_file("Cargo.toml"))
+            App::new().handler("test", StaticFiles::new(".").unwrap().index_file("Cargo.toml"))
         });
 
         // Valid range header
@@ -1018,7 +1001,7 @@ mod tests {
         let mut srv = test::TestServer::with_factory(|| {
             App::new().handler(
                 "test",
-                StaticFiles::new(".").index_file("tests/test.binary"),
+                StaticFiles::new(".").unwrap().index_file("tests/test.binary"),
             )
         });
 
@@ -1066,7 +1049,7 @@ mod tests {
         let mut srv = test::TestServer::with_factory(|| {
             App::new().handler(
                 "test",
-                StaticFiles::new(".").index_file("tests/test.binary"),
+                StaticFiles::new(".").unwrap().index_file("tests/test.binary"),
             )
         });
 
@@ -1183,14 +1166,12 @@ mod tests {
 
     #[test]
     fn test_static_files() {
-        let mut st = StaticFiles::new(".").show_files_listing();
-        st.accessible = false;
-        let req = TestRequest::default().finish();
+        let mut st = StaticFiles::new(".").unwrap().show_files_listing();
+        let req = TestRequest::with_uri("/missing").param("tail", "missing").finish();
         let resp = st.handle(&req).respond_to(&req).unwrap();
         let resp = resp.as_msg();
         assert_eq!(resp.status(), StatusCode::NOT_FOUND);
 
-        st.accessible = true;
         st.show_index = false;
         let req = TestRequest::default().finish();
         let resp = st.handle(&req).respond_to(&req).unwrap();
@@ -1210,9 +1191,30 @@ mod tests {
         assert!(format!("{:?}", resp.body()).contains("README.md"));
     }
 
+    #[test]
+    fn test_static_files_bad_directory() {
+        let st: Result<StaticFiles<()>, Error> = StaticFiles::new("missing");
+        assert!(st.is_err());
+
+        let st: Result<StaticFiles<()>, Error> = StaticFiles::new("Cargo.toml");
+        assert!(st.is_err());
+    }
+
+    #[test]
+    fn test_default_handler_file_missing() {
+        let st = StaticFiles::new(".").unwrap()
+            .default_handler(|_: &_| "default content");
+        let req = TestRequest::with_uri("/missing").param("tail", "missing").finish();
+
+        let resp = st.handle(&req).respond_to(&req).unwrap();
+        let resp = resp.as_msg();
+        assert_eq!(resp.status(), StatusCode::OK);
+        assert_eq!(resp.body(), &Body::Binary(Binary::Slice(b"default content")));
+    }
+
     #[test]
     fn test_redirect_to_index() {
-        let st = StaticFiles::new(".").index_file("index.html");
+        let st = StaticFiles::new(".").unwrap().index_file("index.html");
         let req = TestRequest::default().uri("/tests").finish();
 
         let resp = st.handle(&req).respond_to(&req).unwrap();
@@ -1235,7 +1237,7 @@ mod tests {
 
     #[test]
     fn test_redirect_to_index_nested() {
-        let st = StaticFiles::new(".").index_file("mod.rs");
+        let st = StaticFiles::new(".").unwrap().index_file("mod.rs");
         let req = TestRequest::default().uri("/src/client").finish();
         let resp = st.handle(&req).respond_to(&req).unwrap();
         let resp = resp.as_msg();
@@ -1251,7 +1253,7 @@ mod tests {
         let mut srv = test::TestServer::with_factory(|| {
             App::new()
                 .prefix("public")
-                .handler("/", StaticFiles::new(".").index_file("Cargo.toml"))
+                .handler("/", StaticFiles::new(".").unwrap().index_file("Cargo.toml"))
         });
 
         let request = srv.get().uri(srv.url("/public")).finish().unwrap();
@@ -1280,7 +1282,7 @@ mod tests {
     #[test]
     fn integration_redirect_to_index() {
         let mut srv = test::TestServer::with_factory(|| {
-            App::new().handler("test", StaticFiles::new(".").index_file("Cargo.toml"))
+            App::new().handler("test", StaticFiles::new(".").unwrap().index_file("Cargo.toml"))
         });
 
         let request = srv.get().uri(srv.url("/test")).finish().unwrap();
@@ -1309,7 +1311,7 @@ mod tests {
     #[test]
     fn integration_percent_encoded() {
         let mut srv = test::TestServer::with_factory(|| {
-            App::new().handler("test", StaticFiles::new(".").index_file("Cargo.toml"))
+            App::new().handler("test", StaticFiles::new(".").unwrap().index_file("Cargo.toml"))
         });
 
         let request = srv