poc multipart derive

This commit is contained in:
Rob Ede 2020-08-23 15:11:05 +01:00
parent 75d86a6beb
commit 9161f279de
No known key found for this signature in database
GPG Key ID: C2A3B36E841A91E6
11 changed files with 269 additions and 0 deletions

View File

@ -32,6 +32,7 @@ members = [
"actix-http",
"actix-files",
"actix-multipart",
"actix-multipart-derive",
"actix-web-actors",
"actix-web-codegen",
"test-server",

View File

@ -0,0 +1,20 @@
[package]
name = "actix-multipart-derive"
version = "0.1.0"
authors = ["Rob Ede <robjtede@icloud.com>"]
edition = "2018"
[lib]
proc-macro = true
[dependencies]
quote = "1"
syn = { version = "1", features = ["extra-traits"] }
# [dev-dependencies]
actix-multipart = "0.3.0-beta.1"
actix-web = "3.0.0-beta.3"
bytes = "0.5"
futures-util = "0.3"
serde = { version = "1", features = ["derive"] }
trybuild = "1"

View File

@ -0,0 +1,139 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Ident,
};
#[proc_macro_derive(MultipartForm, attributes(multipart))]
pub fn derive(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let name = ast.ident;
let b_name = format!("{}MultipartBuilder", name);
let b_ident = Ident::new(&b_name, name.span());
let fields = if let Data::Struct(DataStruct {
fields: Fields::Named(FieldsNamed { ref named, .. }),
..
}) = ast.data
{
named
} else {
unimplemented!()
};
let optioned = fields.iter().map(|f| {
let Field { ident, ty, .. } = f;
quote! { #ident: ::std::option::Option<#ty> }
});
let field_max_sizes = fields.iter().map(|f| {
let Field { ident, .. } = f;
// TODO: parse field attributes find
// #[multipart(max_size = n)]
quote! { stringify!(#ident) => Some(8096) }
});
let build_fields = fields.iter().map(|f| {
let Field { ident, .. } = f;
quote! { #ident: self.#ident.unwrap() }
});
let field_appending = fields.iter().map(|f| {
let Field { ident, ty, .. } = f;
quote! {
stringify!(#ident) => match builder.#ident {
Some(ref mut field) => {
field.append(chunk);
}
None => {
let mut field = #ty::default();
field.append(chunk);
builder.#ident.replace(field);
}
},
}
});
let expanded = quote! {
#[derive(Debug, Clone, Default)]
struct #b_ident {
#(#optioned,)*
}
impl #b_ident {
fn max_size(field: &str) -> Option<usize> {
match field {
#(#field_max_sizes,)*
_ => None,
}
}
fn build(self) -> Result<#name, ::actix_web::Error> {
Ok(Form {
#(#build_fields,)*
})
}
}
impl ::actix_web::FromRequest for #name {
type Error = ::actix_web::Error;
type Future = ::futures_util::future::LocalBoxFuture<'static, Result<Self, Self::Error>>;
type Config = ();
fn from_request(req: &::actix_web::HttpRequest, payload: &mut ::actix_web::dev::Payload) -> Self::Future {
use futures_util::future::FutureExt;
use futures_util::stream::StreamExt;
use actix_multipart::BuildFromBytes;
let pl = payload.take();
let req2 = req.clone();
async move {
let mut mp = ::actix_multipart::Multipart::new(req2.headers(), pl);
let mut builder = #b_ident::default();
while let Some(item) = mp.next().await {
let mut field = item?;
let headers = field.headers();
let cd = field.content_disposition().unwrap();
let name = cd.get_name().unwrap();
println!("FIELD: {}", name);
let mut size = 0;
while let Some(chunk) = field.next().await {
let chunk = chunk?;
size += chunk.len();
if (size > #b_ident::max_size(&name).unwrap_or(std::usize::MAX)) {
return Err(::actix_web::error::ErrorPayloadTooLarge("field is too large"));
}
match name {
#(#field_appending)*
_ => todo!("unknown field"),
}
}
println!();
}
builder.build()
}
.boxed_local()
}
}
};
expanded.into()
}

View File

@ -0,0 +1,27 @@
use actix_multipart_derive::MultipartForm;
use actix_web::{post, App, HttpServer};
use bytes::BytesMut;
#[derive(Debug, Clone, Default, MultipartForm)]
struct Form {
name: String,
#[multipart(max_size = 8096)]
file: BytesMut,
}
#[post("/")]
async fn no_params(form: Form) -> &'static str {
println!("{:?}", &form);
"Hello world!\r\n"
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
HttpServer::new(|| App::new().service(no_params))
.bind("127.0.0.1:8080")?
.workers(1)
.run()
.await
}

View File

@ -0,0 +1,19 @@
#[post("/")]
async fn no_params(form: Form) -> &'static str {
println!("{:?}", &form);
"Hello world!\r\n"
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info");
env_logger::init();
HttpServer::new(|| App::new().service(no_params))
.bind("127.0.0.1:8080")?
.workers(1)
.run()
.await
}

View File

@ -0,0 +1,9 @@
#[test]
fn compile_macros() {
let t = trybuild::TestCases::new();
t.pass("tests/trybuild/01-basic.rs");
t.pass("tests/trybuild/02-max-size.rs");
// t.pass("tests/trybuild/03-inert-filter.rs");
// t.compile_fail("tests/trybuild/02-only-async.rs");
}

View File

@ -0,0 +1,10 @@
use actix_multipart_derive::MultipartForm;
use bytes::BytesMut;
#[derive(Debug, Clone, Default, MultipartForm)]
struct Form {
name: String,
file: BytesMut,
}
fn main() {}

View File

@ -0,0 +1,11 @@
use actix_multipart_derive::MultipartForm;
use bytes::BytesMut;
#[derive(Debug, Clone, Default, MultipartForm)]
struct Form {
name: String,
#[multipart(max_size = 8096)]
file: BytesMut,
}
fn main () {}

View File

@ -0,0 +1,13 @@
use actix_multipart_derive::MultipartForm;
use serde::Deserialize;
#[derive(Debug, Clone, Default, Deserialize, MultipartForm)]
struct Form {
name: String,
#[multipart(max_size = 8096)]
#[serde(rename = "mFile")]
file: String,
}
fn main() {}

View File

@ -30,3 +30,4 @@ twoway = "0.2"
[dev-dependencies]
actix-rt = "1.0.0"
actix-http = "2.0.0-beta.3"
env_logger = "0.7"

View File

@ -6,3 +6,22 @@ mod server;
pub use self::error::MultipartError;
pub use self::server::{Field, Multipart};
use bytes::{BufMut, Bytes, BytesMut};
pub trait BuildFromBytes {
fn append(&mut self, next: Bytes);
}
impl BuildFromBytes for String {
fn append(&mut self, chunk: Bytes) {
let chunk_str = std::str::from_utf8(&chunk).expect("string field is not utf-8");
self.push_str(chunk_str);
}
}
impl BuildFromBytes for BytesMut {
fn append(&mut self, chunk: Bytes) {
self.put(&chunk[..]);
}
}