second cleanup pass

This commit is contained in:
Rob Ede 2022-11-25 23:23:10 +00:00
parent e0de3e10b0
commit 447c42697f
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
11 changed files with 215 additions and 186 deletions

View File

@ -18,17 +18,17 @@ use syn::{parse_macro_input, Type};
#[darling(attributes(multipart), default)] #[darling(attributes(multipart), default)]
struct MultipartFormAttrs { struct MultipartFormAttrs {
deny_unknown_fields: bool, deny_unknown_fields: bool,
duplicate_action: DuplicateAction, duplicate_field: DuplicateField,
} }
#[derive(FromMeta)] #[derive(FromMeta)]
enum DuplicateAction { enum DuplicateField {
Ignore, Ignore,
Deny, Deny,
Replace, Replace,
} }
impl Default for DuplicateAction { impl Default for DuplicateField {
fn default() -> Self { fn default() -> Self {
Self::Ignore Self::Ignore
} }
@ -136,7 +136,7 @@ struct ParsedField<'t> {
/// ## Duplicate Fields /// ## Duplicate Fields
/// ///
/// You can change the behaviour for when multiple fields are received with the same name using the /// You can change the behaviour for when multiple fields are received with the same name using the
/// `#[multipart(duplicate_action = "")]` attribute: /// `#[multipart(duplicate_field = "")]` attribute:
/// ///
/// - "ignore": Extra fields are ignored (default). /// - "ignore": Extra fields are ignored (default).
/// - "replace": Each field is processed, but only the last one is persisted. /// - "replace": Each field is processed, but only the last one is persisted.
@ -147,7 +147,7 @@ struct ParsedField<'t> {
/// ``` /// ```
/// # use actix_multipart::form::MultipartForm; /// # use actix_multipart::form::MultipartForm;
/// #[derive(MultipartForm)] /// #[derive(MultipartForm)]
/// #[multipart(duplicate_action = "deny")] /// #[multipart(duplicate_field = "deny")]
/// struct Form { } /// struct Form { }
/// ``` /// ```
/// ///
@ -168,7 +168,7 @@ pub fn impl_multipart_form(input: proc_macro::TokenStream) -> proc_macro::TokenS
let attrs: MultipartFormAttrs = match MultipartFormAttrs::from_derive_input(&input) { let attrs: MultipartFormAttrs = match MultipartFormAttrs::from_derive_input(&input) {
Ok(attrs) => attrs, Ok(attrs) => attrs,
Err(e) => return e.write_errors().into(), Err(err) => return err.write_errors().into(),
}; };
// Parse the field attributes // Parse the field attributes
@ -198,7 +198,7 @@ pub fn impl_multipart_form(input: proc_macro::TokenStream) -> proc_macro::TokenS
.collect::<Result<Vec<_>, darling::Error>>() .collect::<Result<Vec<_>, darling::Error>>()
{ {
Ok(attrs) => attrs, Ok(attrs) => attrs,
Err(e) => return e.write_errors().into(), Err(err) => return err.write_errors().into(),
}; };
// Check that field names are unique // Check that field names are unique
@ -219,10 +219,10 @@ pub fn impl_multipart_form(input: proc_macro::TokenStream) -> proc_macro::TokenS
}; };
// Value for duplicate action // Value for duplicate action
let duplicate_action = match attrs.duplicate_action { let duplicate_field = match attrs.duplicate_field {
DuplicateAction::Ignore => quote!(::actix_multipart::form::DuplicateAction::Ignore), DuplicateField::Ignore => quote!(::actix_multipart::form::DuplicateField::Ignore),
DuplicateAction::Deny => quote!(::actix_multipart::form::DuplicateAction::Deny), DuplicateField::Deny => quote!(::actix_multipart::form::DuplicateField::Deny),
DuplicateAction::Replace => quote!(::actix_multipart::form::DuplicateAction::Replace), DuplicateField::Replace => quote!(::actix_multipart::form::DuplicateField::Replace),
}; };
// read_field() implementation // read_field() implementation
@ -232,7 +232,7 @@ pub fn impl_multipart_form(input: proc_macro::TokenStream) -> proc_macro::TokenS
let ty = &field.ty; let ty = &field.ty;
read_field_impl.extend(quote!( read_field_impl.extend(quote!(
#name => ::std::boxed::Box::pin( #name => ::std::boxed::Box::pin(
<#ty as ::actix_multipart::form::FieldGroupReader>::handle_field(req, field, limits, state, #duplicate_action) <#ty as ::actix_multipart::form::FieldGroupReader>::handle_field(req, field, limits, state, #duplicate_field)
), ),
)); ));
} }

View File

@ -9,6 +9,10 @@ repository = "https://github.com/actix/actix-web.git"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
edition = "2018" edition = "2018"
[package.metadata.docs.rs]
rustdoc-args = ["--cfg", "docsrs"]
all-features = true
[features] [features]
default = ["tempfile", "derive"] default = ["tempfile", "derive"]
derive = ["actix-multipart-derive"] derive = ["actix-multipart-derive"]
@ -19,31 +23,31 @@ name = "actix_multipart"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-multipart-derive = { version = "=0.4.0", optional = true }
actix-http = "3"
actix-utils = "3" actix-utils = "3"
actix-web = { version = "4", default-features = false } actix-web = { version = "4", default-features = false }
actix-http = "3"
actix-multipart-derive = { version = "=0.4.0", optional = true }
bytes = "1" bytes = "1"
derive_more = "0.99.5" derive_more = "0.99.5"
futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } futures-core = "0.3.17"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.17", default-features = false, features = ["std"] }
httparse = "1.3" httparse = "1.3"
local-waker = "0.1" local-waker = "0.1"
log = "0.4" log = "0.4"
memchr = "2.5" memchr = "2.5"
mime = "0.3" mime = "0.3"
serde = "1.0" serde = "1.0"
serde_plain = "1.0"
serde_json = "1.0" serde_json = "1.0"
serde_plain = "1.0"
# TODO(MSRV 1.60): replace with dep: prefix # TODO(MSRV 1.60): replace with dep: prefix
tempfile-dep = { package = "tempfile", version = "3.3.0", optional = true } tempfile-dep = { package = "tempfile", version = "3.3.0", optional = true }
tokio = { version = "1.8.4", features = ["sync"] } tokio = { version = "1.13.1", features = ["sync"] }
[dev-dependencies] [dev-dependencies]
actix-multipart-rfc7578 = "0.10"
actix-rt = "2.2" actix-rt = "2.2"
actix-test = "0.1.0" actix-test = "0.1.0"
awc = "3.0.1" awc = "3.0.1"
actix-multipart-rfc7578 = "0.10.0"
futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] } futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] }
tokio-stream = "0.1" tokio-stream = "0.1"

View File

@ -7,9 +7,9 @@ use actix_web::{
}; };
use derive_more::{Display, Error, From}; use derive_more::{Display, Error, From};
/// A set of errors that can occur during parsing multipart streams /// A set of errors that can occur during parsing multipart streams.
#[non_exhaustive]
#[derive(Debug, Display, From, Error)] #[derive(Debug, Display, From, Error)]
#[non_exhaustive]
pub enum MultipartError { pub enum MultipartError {
/// Content-Disposition header is not found or is not equal to "form-data". /// Content-Disposition header is not found or is not equal to "form-data".
/// ///
@ -51,7 +51,11 @@ pub enum MultipartError {
NotConsumed, NotConsumed,
/// An error from a field handler in a form /// An error from a field handler in a form
#[display(fmt = "An error occurred processing field `{field_name}`: {source}")] #[display(
fmt = "An error occurred processing field `{}`: {}",
field_name,
source
)]
Field { Field {
field_name: String, field_name: String,
source: actix_web::Error, source: actix_web::Error,

View File

@ -9,8 +9,7 @@ use crate::server::Multipart;
/// ///
/// Content-type: multipart/form-data; /// Content-type: multipart/form-data;
/// ///
/// ## Server example /// # Examples
///
/// ``` /// ```
/// use actix_web::{web, HttpResponse, Error}; /// use actix_web::{web, HttpResponse, Error};
/// use actix_multipart::Multipart; /// use actix_multipart::Multipart;

View File

@ -3,7 +3,7 @@
use actix_web::HttpRequest; use actix_web::HttpRequest;
use bytes::BytesMut; use bytes::BytesMut;
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::{FutureExt, TryStreamExt}; use futures_util::TryStreamExt as _;
use mime::Mime; use mime::Mime;
use crate::{ use crate::{
@ -16,9 +16,11 @@ use crate::{
pub struct Bytes { pub struct Bytes {
/// The data. /// The data.
pub data: bytes::Bytes, pub data: bytes::Bytes,
/// The value of the `content-type` header.
/// The value of the `Content-Type` header.
pub content_type: Option<Mime>, pub content_type: Option<Mime>,
/// The `filename` value in the `content-disposition` header.
/// The `filename` value in the `Content-Disposition` header.
pub file_name: Option<String>, pub file_name: Option<String>,
} }
@ -30,21 +32,22 @@ impl<'t> FieldReader<'t> for Bytes {
mut field: Field, mut field: Field,
limits: &'t mut Limits, limits: &'t mut Limits,
) -> Self::Future { ) -> Self::Future {
async move { Box::pin(async move {
let mut data = BytesMut::new(); let mut buf = BytesMut::new();
while let Some(chunk) = field.try_next().await? { while let Some(chunk) = field.try_next().await? {
limits.try_consume_limits(chunk.len(), true)?; limits.try_consume_limits(chunk.len(), true)?;
data.extend(chunk); buf.extend(chunk);
} }
Ok(Bytes { Ok(Bytes {
data: data.freeze(), data: buf.freeze(),
content_type: field.content_type().map(ToOwned::to_owned), content_type: field.content_type().map(ToOwned::to_owned),
file_name: field file_name: field
.content_disposition() .content_disposition()
.get_filename() .get_filename()
.map(str::to_owned), .map(str::to_owned),
}) })
} })
.boxed_local()
} }
} }

View File

@ -5,7 +5,6 @@ use std::sync::Arc;
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError}; use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
use derive_more::{Deref, DerefMut, Display, Error}; use derive_more::{Deref, DerefMut, Display, Error};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::FutureExt;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::{ use crate::{
@ -13,6 +12,8 @@ use crate::{
Field, MultipartError, Field, MultipartError,
}; };
use super::FieldErrorHandler;
/// Deserialize from JSON. /// Deserialize from JSON.
#[derive(Debug, Deref, DerefMut)] #[derive(Debug, Deref, DerefMut)]
pub struct Json<T: DeserializeOwned>(pub T); pub struct Json<T: DeserializeOwned>(pub T);
@ -23,11 +24,14 @@ impl<T: DeserializeOwned> Json<T> {
} }
} }
impl<'t, T: DeserializeOwned + 'static> FieldReader<'t> for Json<T> { impl<'t, T> FieldReader<'t> for Json<T>
where
T: DeserializeOwned + 'static,
{
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>; type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future { fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
async move { Box::pin(async move {
let config = JsonConfig::from_req(req); let config = JsonConfig::from_req(req);
let field_name = field.name().to_owned(); let field_name = field.name().to_owned();
@ -37,6 +41,7 @@ impl<'t, T: DeserializeOwned + 'static> FieldReader<'t> for Json<T> {
} else { } else {
false false
}; };
if !valid { if !valid {
return Err(MultipartError::Field { return Err(MultipartError::Field {
field_name, field_name,
@ -48,24 +53,23 @@ impl<'t, T: DeserializeOwned + 'static> FieldReader<'t> for Json<T> {
let bytes = Bytes::read_field(req, field, limits).await?; let bytes = Bytes::read_field(req, field, limits).await?;
Ok(Json(serde_json::from_slice(bytes.data.as_ref()).map_err( Ok(Json(serde_json::from_slice(bytes.data.as_ref()).map_err(
|e| MultipartError::Field { |err| MultipartError::Field {
field_name, field_name,
source: config.map_error(req, JsonFieldError::Deserialize(e)), source: config.map_error(req, JsonFieldError::Deserialize(err)),
}, },
)?)) )?))
} })
.boxed_local()
} }
} }
#[derive(Debug, Display, Error)] #[derive(Debug, Display, Error)]
#[non_exhaustive] #[non_exhaustive]
pub enum JsonFieldError { pub enum JsonFieldError {
/// Deserialize error /// Deserialize error.
#[display(fmt = "Json deserialize error: {}", _0)] #[display(fmt = "Json deserialize error: {}", _0)]
Deserialize(serde_json::Error), Deserialize(serde_json::Error),
/// Content type error /// Content type error.
#[display(fmt = "Content type error")] #[display(fmt = "Content type error")]
ContentType, ContentType,
} }
@ -79,8 +83,7 @@ impl ResponseError for JsonFieldError {
/// Configuration for the [`Json`] field reader. /// Configuration for the [`Json`] field reader.
#[derive(Clone)] #[derive(Clone)]
pub struct JsonConfig { pub struct JsonConfig {
#[allow(clippy::type_complexity)] err_handler: FieldErrorHandler<JsonFieldError>,
err_handler: Option<Arc<dyn Fn(JsonFieldError, &HttpRequest) -> Error + Send + Sync>>,
validate_content_type: bool, validate_content_type: bool,
} }
@ -131,9 +134,8 @@ impl Default for JsonConfig {
mod tests { mod tests {
use std::{collections::HashMap, io::Cursor}; use std::{collections::HashMap, io::Cursor};
use actix_http::StatusCode;
use actix_multipart_rfc7578::client::multipart; use actix_multipart_rfc7578::client::multipart;
use actix_web::{web, App, HttpResponse, Responder}; use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
use crate::form::{ use crate::form::{
json::{Json, JsonConfig}, json::{Json, JsonConfig},

View File

@ -7,11 +7,10 @@ use std::{
sync::Arc, sync::Arc,
}; };
use actix_web::{dev::Payload, error::PayloadError, web, Error, FromRequest, HttpRequest}; use actix_web::{dev, error::PayloadError, web, Error, FromRequest, HttpRequest};
use derive_more::{Deref, DerefMut}; use derive_more::{Deref, DerefMut};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::TryFutureExt; use futures_util::{TryFutureExt, TryStreamExt as _};
use futures_util::{FutureExt, TryStreamExt};
use crate::{Field, Multipart, MultipartError}; use crate::{Field, Multipart, MultipartError};
@ -26,6 +25,8 @@ pub mod text;
#[cfg(feature = "derive")] #[cfg(feature = "derive")]
pub use actix_multipart_derive::MultipartForm; pub use actix_multipart_derive::MultipartForm;
type FieldErrorHandler<T> = Option<Arc<dyn Fn(T, &HttpRequest) -> Error + Send + Sync>>;
/// Trait that data types to be used in a multipart form struct should implement. /// Trait that data types to be used in a multipart form struct should implement.
/// ///
/// It represents an asynchronous handler that processes a multipart field to produce `Self`. /// It represents an asynchronous handler that processes a multipart field to produce `Self`.
@ -47,16 +48,16 @@ pub struct State(pub HashMap<String, Box<dyn Any>>);
pub trait FieldGroupReader<'t>: Sized + Any { pub trait FieldGroupReader<'t>: Sized + Any {
type Future: Future<Output = Result<(), MultipartError>>; type Future: Future<Output = Result<(), MultipartError>>;
/// The form will call this function for each matching field /// The form will call this function for each matching field.
fn handle_field( fn handle_field(
req: &'t HttpRequest, req: &'t HttpRequest,
field: Field, field: Field,
limits: &'t mut Limits, limits: &'t mut Limits,
state: &'t mut State, state: &'t mut State,
duplicate_action: DuplicateAction, duplicate_field: DuplicateField,
) -> Self::Future; ) -> Self::Future;
/// Create `Self` from the group of processed fields /// Construct `Self` from the group of processed fields.
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError>; fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError>;
} }
@ -71,27 +72,28 @@ where
field: Field, field: Field,
limits: &'t mut Limits, limits: &'t mut Limits,
state: &'t mut State, state: &'t mut State,
duplicate_action: DuplicateAction, duplicate_field: DuplicateField,
) -> Self::Future { ) -> Self::Future {
if state.contains_key(field.name()) { if state.contains_key(field.name()) {
match duplicate_action { match duplicate_field {
DuplicateAction::Ignore => return ready(Ok(())).boxed_local(), DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
DuplicateAction::Deny => {
return ready(Err(MultipartError::DuplicateField( DuplicateField::Deny => {
return Box::pin(ready(Err(MultipartError::DuplicateField(
field.name().to_string(), field.name().to_string(),
))) ))))
.boxed_local()
} }
DuplicateAction::Replace => {}
DuplicateField::Replace => {}
} }
} }
async move {
Box::pin(async move {
let field_name = field.name().to_string(); let field_name = field.name().to_string();
let t = T::read_field(req, field, limits).await?; let t = T::read_field(req, field, limits).await?;
state.insert(field_name, Box::new(t)); state.insert(field_name, Box::new(t));
Ok(()) Ok(())
} })
.boxed_local()
} }
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> { fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
@ -110,21 +112,24 @@ where
field: Field, field: Field,
limits: &'t mut Limits, limits: &'t mut Limits,
state: &'t mut State, state: &'t mut State,
_duplicate_action: DuplicateAction, _duplicate_field: DuplicateField,
) -> Self::Future { ) -> Self::Future {
// Vec GroupReader always allows duplicates! Box::pin(async move {
async move { // Note: Vec GroupReader always allows duplicates
let field_name = field.name().to_string(); let field_name = field.name().to_string();
let vec = state let vec = state
.entry(field_name) .entry(field_name)
.or_insert_with(|| Box::new(Vec::<T>::new())) .or_insert_with(|| Box::new(Vec::<T>::new()))
.downcast_mut::<Vec<T>>() .downcast_mut::<Vec<T>>()
.unwrap(); .unwrap();
let item = T::read_field(req, field, limits).await?; let item = T::read_field(req, field, limits).await?;
vec.push(item); vec.push(item);
Ok(()) Ok(())
} })
.boxed_local()
} }
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> { fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
@ -146,27 +151,27 @@ where
field: Field, field: Field,
limits: &'t mut Limits, limits: &'t mut Limits,
state: &'t mut State, state: &'t mut State,
duplicate_action: DuplicateAction, duplicate_field: DuplicateField,
) -> Self::Future { ) -> Self::Future {
if state.contains_key(field.name()) { if state.contains_key(field.name()) {
match duplicate_action { match duplicate_field {
DuplicateAction::Ignore => return ready(Ok(())).boxed_local(), DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
DuplicateAction::Deny => {
return ready(Err(MultipartError::DuplicateField( DuplicateField::Deny => {
return Box::pin(ready(Err(MultipartError::DuplicateField(
field.name().to_string(), field.name().to_string(),
))) ))))
.boxed_local()
} }
DuplicateAction::Replace => {}
DuplicateField::Replace => {}
} }
} }
async move { Box::pin(async move {
let field_name = field.name().to_string(); let field_name = field.name().to_string();
let t = T::read_field(req, field, limits).await?; let t = T::read_field(req, field, limits).await?;
state.insert(field_name, Box::new(t)); state.insert(field_name, Box::new(t));
Ok(()) Ok(())
} })
.boxed_local()
} }
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> { fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
@ -199,11 +204,13 @@ pub trait MultipartFormTrait: Sized {
} }
#[doc(hidden)] #[doc(hidden)]
pub enum DuplicateAction { pub enum DuplicateField {
/// Additional fields are not processed /// Additional fields are not processed
Ignore, Ignore,
/// An error will be raised /// An error will be raised
Deny, Deny,
/// All fields will be processed, the last one will replace all previous /// All fields will be processed, the last one will replace all previous
Replace, Replace,
} }
@ -240,12 +247,14 @@ impl Limits {
.total_limit_remaining .total_limit_remaining
.checked_sub(bytes) .checked_sub(bytes)
.ok_or(MultipartError::Payload(PayloadError::Overflow))?; .ok_or(MultipartError::Payload(PayloadError::Overflow))?;
if in_memory { if in_memory {
self.memory_limit_remaining = self self.memory_limit_remaining = self
.memory_limit_remaining .memory_limit_remaining
.checked_sub(bytes) .checked_sub(bytes)
.ok_or(MultipartError::Payload(PayloadError::Overflow))?; .ok_or(MultipartError::Payload(PayloadError::Overflow))?;
} }
if let Some(field_limit) = self.field_limit_remaining { if let Some(field_limit) = self.field_limit_remaining {
self.field_limit_remaining = Some( self.field_limit_remaining = Some(
field_limit field_limit
@ -253,6 +262,7 @@ impl Limits {
.ok_or(MultipartError::Payload(PayloadError::Overflow))?, .ok_or(MultipartError::Payload(PayloadError::Overflow))?,
); );
} }
Ok(()) Ok(())
} }
} }
@ -260,8 +270,8 @@ impl Limits {
/// Typed `multipart/form-data` extractor. /// Typed `multipart/form-data` extractor.
/// ///
/// To extract typed data from a multipart stream, the inner type `T` must implement the /// To extract typed data from a multipart stream, the inner type `T` must implement the
/// [`MultipartFormTrait`] trait, you should use the [`macro@MultipartForm`] macro to derive this for /// [`MultipartFormTrait`] trait, you should use the [`macro@MultipartForm`] macro to derive this
/// your struct. /// for your struct.
/// ///
/// Use [`MultipartFormConfig`] to configure extraction options. /// Use [`MultipartFormConfig`] to configure extraction options.
#[derive(Deref, DerefMut)] #[derive(Deref, DerefMut)]
@ -282,14 +292,17 @@ where
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
let mut payload = Multipart::new(req.headers(), payload.take()); let mut payload = Multipart::new(req.headers(), payload.take());
let config = MultipartFormConfig::from_req(req); let config = MultipartFormConfig::from_req(req);
let mut limits = Limits::new(config.total_limit, config.memory_limit); let mut limits = Limits::new(config.total_limit, config.memory_limit);
let req = req.clone(); let req = req.clone();
let req2 = req.clone(); let req2 = req.clone();
let err_handler = config.err_handler.clone(); let err_handler = config.err_handler.clone();
Box::pin(
async move { async move {
let mut state = State::default(); let mut state = State::default();
// We need to ensure field limits are shared for all instances of this field name // We need to ensure field limits are shared for all instances of this field name
@ -310,14 +323,14 @@ where
let inner = T::from_state(state)?; let inner = T::from_state(state)?;
Ok(MultipartForm(inner)) Ok(MultipartForm(inner))
} }
.map_err(move |e| { .map_err(move |err| {
if let Some(handler) = err_handler { if let Some(handler) = err_handler {
(*handler)(e, &req2) (*handler)(err, &req2)
} else { } else {
e.into() err.into()
} }
}) }),
.boxed_local() )
} }
} }
@ -378,18 +391,13 @@ impl Default for MultipartFormConfig {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_http::encoding::Decoder; use actix_http::encoding::Decoder;
use actix_http::Payload;
use actix_multipart_rfc7578::client::multipart; use actix_multipart_rfc7578::client::multipart;
use actix_test::TestServer; use actix_test::TestServer;
use actix_web::http::StatusCode; use actix_web::{dev::Payload, http::StatusCode, web, App, HttpResponse, Responder};
use actix_web::{web, App, HttpResponse, Responder};
use awc::{Client, ClientResponse}; use awc::{Client, ClientResponse};
use super::MultipartForm; use super::MultipartForm;
use crate::form::bytes::Bytes; use crate::form::{bytes::Bytes, tempfile::Tempfile, text::Text, MultipartFormConfig};
use crate::form::tempfile::Tempfile;
use crate::form::text::Text;
use crate::form::MultipartFormConfig;
pub async fn send_form( pub async fn send_form(
srv: &TestServer, srv: &TestServer,
@ -404,8 +412,7 @@ mod tests {
.unwrap() .unwrap()
} }
/// Test `Option` fields /// Test `Option` fields.
#[derive(MultipartForm)] #[derive(MultipartForm)]
struct TestOptions { struct TestOptions {
field1: Option<Text<String>>, field1: Option<Text<String>>,
@ -430,8 +437,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }
/// Test `Vec` fields /// Test `Vec` fields.
#[derive(MultipartForm)] #[derive(MultipartForm)]
struct TestVec { struct TestVec {
list1: Vec<Text<String>>, list1: Vec<Text<String>>,
@ -463,8 +469,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }
/// Test the `rename` field attribute /// Test the `rename` field attribute.
#[derive(MultipartForm)] #[derive(MultipartForm)]
struct TestFieldRenaming { struct TestFieldRenaming {
#[multipart(rename = "renamed")] #[multipart(rename = "renamed")]
@ -498,8 +503,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }
/// Test the `deny_unknown_fields` struct attribute /// Test the `deny_unknown_fields` struct attribute.
#[derive(MultipartForm)] #[derive(MultipartForm)]
#[multipart(deny_unknown_fields)] #[multipart(deny_unknown_fields)]
struct TestDenyUnknown {} struct TestDenyUnknown {}
@ -534,22 +538,21 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }
/// Test the `duplicate_action` struct attribute /// Test the `duplicate_field` struct attribute.
#[derive(MultipartForm)] #[derive(MultipartForm)]
#[multipart(duplicate_action = "deny")] #[multipart(duplicate_field = "deny")]
struct TestDuplicateDeny { struct TestDuplicateDeny {
_field: Text<String>, _field: Text<String>,
} }
#[derive(MultipartForm)] #[derive(MultipartForm)]
#[multipart(duplicate_action = "replace")] #[multipart(duplicate_field = "replace")]
struct TestDuplicateReplace { struct TestDuplicateReplace {
field: Text<String>, field: Text<String>,
} }
#[derive(MultipartForm)] #[derive(MultipartForm)]
#[multipart(duplicate_action = "ignore")] #[multipart(duplicate_field = "ignore")]
struct TestDuplicateIgnore { struct TestDuplicateIgnore {
field: Text<String>, field: Text<String>,
} }
@ -573,7 +576,7 @@ mod tests {
} }
#[actix_rt::test] #[actix_rt::test]
async fn test_duplicate_action() { async fn test_duplicate_field() {
let srv = actix_test::start(|| { let srv = actix_test::start(|| {
App::new() App::new()
.route("/deny", web::post().to(test_duplicate_deny_route)) .route("/deny", web::post().to(test_duplicate_deny_route))
@ -600,8 +603,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }
/// Test the Limits /// Test the Limits.
#[derive(MultipartForm)] #[derive(MultipartForm)]
struct TestMemoryUploadLimits { struct TestMemoryUploadLimits {
field: Bytes, field: Bytes,

View File

@ -1,6 +1,7 @@
//! Writes a field to a temporary file on disk. //! Writes a field to a temporary file on disk.
use std::{ use std::{
io,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
}; };
@ -8,11 +9,12 @@ use std::{
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError}; use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
use derive_more::{Display, Error}; use derive_more::{Display, Error};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::{FutureExt as _, TryStreamExt as _}; use futures_util::TryStreamExt as _;
use mime::Mime; use mime::Mime;
use tempfile_dep::NamedTempFile; use tempfile::NamedTempFile;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use super::FieldErrorHandler;
use crate::{ use crate::{
form::{tempfile::TempfileError::FileIo, FieldReader, Limits}, form::{tempfile::TempfileError::FileIo, FieldReader, Limits},
Field, MultipartError, Field, MultipartError,
@ -23,10 +25,13 @@ use crate::{
pub struct Tempfile { pub struct Tempfile {
/// The temporary file on disk. /// The temporary file on disk.
pub file: NamedTempFile, pub file: NamedTempFile,
/// The value of the `content-type` header. /// The value of the `content-type` header.
pub content_type: Option<Mime>, pub content_type: Option<Mime>,
/// The `filename` value in the `content-disposition` header. /// The `filename` value in the `content-disposition` header.
pub file_name: Option<String>, pub file_name: Option<String>,
/// The size in bytes of the file. /// The size in bytes of the file.
pub size: usize, pub size: usize,
} }
@ -39,21 +44,18 @@ impl<'t> FieldReader<'t> for Tempfile {
mut field: Field, mut field: Field,
limits: &'t mut Limits, limits: &'t mut Limits,
) -> Self::Future { ) -> Self::Future {
async move { Box::pin(async move {
let config = TempfileConfig::from_req(req); let config = TempfileConfig::from_req(req);
let field_name = field.name().to_owned(); let field_name = field.name().to_owned();
let mut size = 0; let mut size = 0;
let file = if let Some(dir) = &config.directory { let file = config
NamedTempFile::new_in(dir) .create_tempfile()
} else { .map_err(|err| config.map_error(req, &field_name, FileIo(err)))?;
NamedTempFile::new()
}
.map_err(|e| config.map_error(req, &field_name, FileIo(e)))?;
let mut file_async = tokio::fs::File::from_std( let mut file_async = tokio::fs::File::from_std(
file.reopen() file.reopen()
.map_err(|e| config.map_error(req, &field_name, FileIo(e)))?, .map_err(|err| config.map_error(req, &field_name, FileIo(err)))?,
); );
while let Some(chunk) = field.try_next().await? { while let Some(chunk) = field.try_next().await? {
@ -62,12 +64,13 @@ impl<'t> FieldReader<'t> for Tempfile {
file_async file_async
.write_all(chunk.as_ref()) .write_all(chunk.as_ref())
.await .await
.map_err(|e| config.map_error(req, &field_name, FileIo(e)))?; .map_err(|err| config.map_error(req, &field_name, FileIo(err)))?;
} }
file_async file_async
.flush() .flush()
.await .await
.map_err(|e| config.map_error(req, &field_name, FileIo(e)))?; .map_err(|err| config.map_error(req, &field_name, FileIo(err)))?;
Ok(Tempfile { Ok(Tempfile {
file, file,
@ -78,15 +81,14 @@ impl<'t> FieldReader<'t> for Tempfile {
.map(str::to_owned), .map(str::to_owned),
size, size,
}) })
} })
.boxed_local()
} }
} }
#[derive(Debug, Display, Error)] #[derive(Debug, Display, Error)]
#[non_exhaustive] #[non_exhaustive]
pub enum TempfileError { pub enum TempfileError {
/// IO Error /// File I/O Error
#[display(fmt = "File I/O error: {}", _0)] #[display(fmt = "File I/O error: {}", _0)]
FileIo(std::io::Error), FileIo(std::io::Error),
} }
@ -100,11 +102,20 @@ impl ResponseError for TempfileError {
/// Configuration for the [`Tempfile`] field reader. /// Configuration for the [`Tempfile`] field reader.
#[derive(Clone)] #[derive(Clone)]
pub struct TempfileConfig { pub struct TempfileConfig {
#[allow(clippy::type_complexity)] err_handler: FieldErrorHandler<TempfileError>,
err_handler: Option<Arc<dyn Fn(TempfileError, &HttpRequest) -> Error + Send + Sync>>,
directory: Option<PathBuf>, directory: Option<PathBuf>,
} }
impl TempfileConfig {
fn create_tempfile(&self) -> io::Result<NamedTempFile> {
if let Some(dir) = self.directory.as_deref() {
NamedTempFile::new_in(dir)
} else {
NamedTempFile::new()
}
}
}
const DEFAULT_CONFIG: TempfileConfig = TempfileConfig { const DEFAULT_CONFIG: TempfileConfig = TempfileConfig {
err_handler: None, err_handler: None,
directory: None, directory: None,
@ -138,14 +149,17 @@ impl TempfileConfig {
} else { } else {
err.into() err.into()
}; };
MultipartError::Field { MultipartError::Field {
field_name: field_name.to_owned(), field_name: field_name.to_owned(),
source, source,
} }
} }
/// Set the directory tempfiles will be created in. /// Sets the directory that temp files will be created in.
pub fn directory<P: AsRef<Path>>(mut self, dir: P) -> Self { ///
/// The default temporary file location is platform dependent.
pub fn directory(mut self, dir: impl AsRef<Path>) -> Self {
self.directory = Some(dir.as_ref().to_owned()); self.directory = Some(dir.as_ref().to_owned());
self self
} }
@ -161,13 +175,10 @@ impl Default for TempfileConfig {
mod tests { mod tests {
use std::io::{Cursor, Read}; use std::io::{Cursor, Read};
use actix_http::StatusCode;
use actix_multipart_rfc7578::client::multipart; use actix_multipart_rfc7578::client::multipart;
use actix_web::{web, App, HttpResponse, Responder}; use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
use crate::form::tempfile::Tempfile; use crate::form::{tempfile::Tempfile, tests::send_form, MultipartForm};
use crate::form::tests::send_form;
use crate::form::MultipartForm;
#[derive(MultipartForm)] #[derive(MultipartForm)]
struct FileForm { struct FileForm {

View File

@ -1,13 +1,13 @@
//! Deserializes a field from plain text. //! Deserializes a field from plain text.
use std::sync::Arc; use std::{str, sync::Arc};
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError}; use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
use derive_more::{Deref, DerefMut, Display, Error}; use derive_more::{Deref, DerefMut, Display, Error};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use futures_util::future::FutureExt as _;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use super::FieldErrorHandler;
use crate::{ use crate::{
form::{bytes::Bytes, FieldReader, Limits}, form::{bytes::Bytes, FieldReader, Limits},
Field, MultipartError, Field, MultipartError,
@ -30,7 +30,7 @@ impl<'t, T: DeserializeOwned + 'static> FieldReader<'t> for Text<T> {
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>; type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future { fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
async move { Box::pin(async move {
let config = TextConfig::from_req(req); let config = TextConfig::from_req(req);
let field_name = field.name().to_owned(); let field_name = field.name().to_owned();
@ -42,6 +42,7 @@ impl<'t, T: DeserializeOwned + 'static> FieldReader<'t> for Text<T> {
// content type defaults to text/plain, so None should be considered valid // content type defaults to text/plain, so None should be considered valid
true true
}; };
if !valid && config.validate_content_type { if !valid && config.validate_content_type {
return Err(MultipartError::Field { return Err(MultipartError::Field {
field_name, field_name,
@ -52,36 +53,33 @@ impl<'t, T: DeserializeOwned + 'static> FieldReader<'t> for Text<T> {
let bytes = Bytes::read_field(req, field, limits).await?; let bytes = Bytes::read_field(req, field, limits).await?;
let text = std::str::from_utf8(bytes.data.as_ref()).map_err(|e| { let text = str::from_utf8(&bytes.data).map_err(|err| MultipartError::Field {
MultipartError::Field {
field_name: field_name.clone(), field_name: field_name.clone(),
source: config.map_error(req, TextError::Utf8Error(e)), source: config.map_error(req, TextError::Utf8Error(err)),
}
})?; })?;
Ok(Text(serde_plain::from_str(text).map_err(|e| { Ok(Text(serde_plain::from_str(text).map_err(|err| {
MultipartError::Field { MultipartError::Field {
field_name, field_name,
source: config.map_error(req, TextError::Deserialize(e)), source: config.map_error(req, TextError::Deserialize(err)),
} }
})?)) })?))
} })
.boxed_local()
} }
} }
#[derive(Debug, Display, Error)] #[derive(Debug, Display, Error)]
#[non_exhaustive] #[non_exhaustive]
pub enum TextError { pub enum TextError {
/// Utf8 error /// UTF-8 decoding error.
#[display(fmt = "Utf8 decoding error: {}", _0)] #[display(fmt = "UTF-8 decoding error: {}", _0)]
Utf8Error(std::str::Utf8Error), Utf8Error(str::Utf8Error),
/// Deserialize error /// Deserialize error.
#[display(fmt = "Plain text deserialize error: {}", _0)] #[display(fmt = "Plain text deserialize error: {}", _0)]
Deserialize(serde_plain::Error), Deserialize(serde_plain::Error),
/// Content type error /// Content type error.
#[display(fmt = "Content type error")] #[display(fmt = "Content type error")]
ContentType, ContentType,
} }
@ -95,8 +93,7 @@ impl ResponseError for TextError {
/// Configuration for the [`Text`] field reader. /// Configuration for the [`Text`] field reader.
#[derive(Clone)] #[derive(Clone)]
pub struct TextConfig { pub struct TextConfig {
#[allow(clippy::type_complexity)] err_handler: FieldErrorHandler<TextError>,
err_handler: Option<Arc<dyn Fn(TextError, &HttpRequest) -> Error + Send + Sync>>,
validate_content_type: bool, validate_content_type: bool,
} }
@ -131,6 +128,7 @@ impl TextConfig {
} }
/// Sets whether or not the field must have a valid `Content-Type` header to be parsed. /// Sets whether or not the field must have a valid `Content-Type` header to be parsed.
///
/// Note that an empty `Content-Type` is also accepted, as the multipart specification defines /// Note that an empty `Content-Type` is also accepted, as the multipart specification defines
/// `text/plain` as the default for text fields. /// `text/plain` as the default for text fields.
pub fn validate_content_type(mut self, validate_content_type: bool) -> Self { pub fn validate_content_type(mut self, validate_content_type: bool) -> Self {
@ -149,13 +147,14 @@ impl Default for TextConfig {
mod tests { mod tests {
use std::io::Cursor; use std::io::Cursor;
use actix_http::StatusCode;
use actix_multipart_rfc7578::client::multipart; use actix_multipart_rfc7578::client::multipart;
use actix_web::{web, App, HttpResponse, Responder}; use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
use crate::form::tests::send_form; use crate::form::{
use crate::form::text::{Text, TextConfig}; tests::send_form,
use crate::form::MultipartForm; text::{Text, TextConfig},
MultipartForm,
};
#[derive(MultipartForm)] #[derive(MultipartForm)]
struct TextForm { struct TextForm {

View File

@ -8,6 +8,7 @@
// This allows us to use the actix_multipart_derive within this crate's tests // This allows us to use the actix_multipart_derive within this crate's tests
#[cfg(test)] #[cfg(test)]
extern crate self as actix_multipart; extern crate self as actix_multipart;
extern crate tempfile_dep as tempfile;
mod error; mod error;
mod extractor; mod extractor;

View File

@ -270,7 +270,9 @@ impl InnerMultipart {
match field.borrow_mut().poll(safety) { match field.borrow_mut().poll(safety) {
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(_))) => continue, Poll::Ready(Some(Ok(_))) => continue,
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Some(Err(err)))
}
Poll::Ready(None) => true, Poll::Ready(None) => true,
} }
} }
@ -658,7 +660,7 @@ impl InnerField {
match res { match res {
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))), Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
Poll::Ready(None) => self.eof = true, Poll::Ready(None) => self.eof = true,
} }
} }
@ -673,7 +675,7 @@ impl InnerField {
} }
Poll::Ready(None) Poll::Ready(None)
} }
Err(e) => Poll::Ready(Some(Err(e))), Err(err) => Poll::Ready(Some(Err(err))),
} }
} else { } else {
Poll::Pending Poll::Pending
@ -794,7 +796,7 @@ impl PayloadBuffer {
loop { loop {
match Pin::new(&mut self.stream).poll_next(cx) { match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
Poll::Ready(Some(Err(e))) => return Err(e), Poll::Ready(Some(Err(err))) => return Err(err),
Poll::Ready(None) => { Poll::Ready(None) => {
self.eof = true; self.eof = true;
return Ok(()); return Ok(());
@ -863,10 +865,12 @@ mod tests {
use std::time::Duration; use std::time::Duration;
use actix_http::h1::Payload; use actix_http::h1::Payload;
use actix_web::http::header::{DispositionParam, DispositionType}; use actix_web::{
use actix_web::rt; http::header::{DispositionParam, DispositionType},
use actix_web::test::TestRequest; rt,
use actix_web::FromRequest; test::TestRequest,
FromRequest,
};
use bytes::Bytes; use bytes::Bytes;
use futures_util::{future::lazy, StreamExt}; use futures_util::{future::lazy, StreamExt};
use tokio::sync::mpsc; use tokio::sync::mpsc;