diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml
index 06b40172..f01355ac 100644
--- a/actix-http/Cargo.toml
+++ b/actix-http/Cargo.toml
@@ -107,7 +107,7 @@ actix-utils = "3"
 actix-rt = { version = "2.2", default-features = false }

 bitflags = "2"
-bytes = "1"
+bytes = "1.7"
 bytestring = "1"
 derive_more = { version = "2", features = ["as_ref", "deref", "deref_mut", "display", "error", "from"] }
 encoding_rs = "0.8"
diff --git a/actix-http/examples/actix-web.rs b/actix-http/examples/actix-web.rs
index e07abfd9..68e12301 100644
--- a/actix-http/examples/actix-web.rs
+++ b/actix-http/examples/actix-web.rs
@@ -1,19 +1,34 @@
+use std::sync::OnceLock;
+
 use actix_http::HttpService;
 use actix_server::Server;
 use actix_service::map_config;
 use actix_web::{dev::AppConfig, get, App, Responder};

+static MEDIUM: OnceLock<String> = OnceLock::new();
+static LARGE: OnceLock<String> = OnceLock::new();
+
 #[get("/")]
 async fn index() -> impl Responder {
     "Hello, world. From Actix Web!"
 }

+#[get("/large")]
+async fn large() -> &'static str {
+    LARGE.get_or_init(|| "123456890".repeat(1024 * 100))
+}
+
+#[get("/medium")]
+async fn medium() -> &'static str {
+    MEDIUM.get_or_init(|| "123456890".repeat(1024 * 5))
+}
+
 #[tokio::main(flavor = "current_thread")]
 async fn main() -> std::io::Result<()> {
     Server::build()
         .bind("hello-world", "127.0.0.1:8080", || {
             // construct actix-web app
-            let app = App::new().service(index);
+            let app = App::new().service(index).service(large).service(medium);

             HttpService::build()
                 // pass the app to service builder
diff --git a/actix-http/src/big_bytes.rs b/actix-http/src/big_bytes.rs
new file mode 100644
index 00000000..49839ab4
--- /dev/null
+++ b/actix-http/src/big_bytes.rs
@@ -0,0 +1,131 @@
+use std::collections::VecDeque;
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+
+// 64KB max capacity (arbitrarily chosen)
+const MAX_CAPACITY: usize = 1024 * 64;
+
+/// Write buffer that pairs a contiguous [`BytesMut`] for small writes with a queue of frozen
+/// [`Bytes`] chunks, so that large payloads can be queued without being copied.
+///
+/// Data is consumed front-to-back: queued chunks first, then the contiguous buffer.
+pub struct BigBytes {
+    /// Contiguous buffer that small writes (e.g. headers) are appended to.
+    buffer: BytesMut,
+    /// Queue of frozen chunks waiting to be written, oldest first.
+    frozen: VecDeque<Bytes>,
+    /// Sum of the lengths of all chunks currently in `frozen`.
+    frozen_len: usize,
+}
+
+impl BigBytes {
+
/// Initialize a new BigBytes with the internal buffer set to `capacity` capacity
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self {
+            buffer: BytesMut::with_capacity(capacity),
+            frozen: VecDeque::default(),
+            frozen_len: 0,
+        }
+    }
+
+    /// Clear the internal queue and buffer, resetting length to zero
+    ///
+    /// If the internal buffer capacity exceeds 64KB or new_capacity, whichever is greater, it will
+    /// be freed and a new buffer of capacity `new_capacity` will be allocated
+    pub fn clear(&mut self, new_capacity: usize) {
+        self.frozen.clear();
+        self.frozen_len = 0;
+        self.buffer.clear();
+
+        if self.buffer.capacity() > new_capacity.max(MAX_CAPACITY) {
+            self.buffer = BytesMut::with_capacity(new_capacity);
+        }
+    }
+
+    /// Return a mutable reference to the underlying buffer. This should only be used when dealing
+    /// with small allocations (e.g. writing headers)
+    pub fn buffer_mut(&mut self) -> &mut BytesMut {
+        &mut self.buffer
+    }
+
+    /// Return the total length of the bytes stored in BigBytes
+    pub fn total_len(&self) -> usize {
+        self.frozen_len + self.buffer.len()
+    }
+
+    /// Return whether there are no bytes present in the BigBytes
+    pub fn is_empty(&self) -> bool {
+        self.frozen_len == 0 && self.buffer.is_empty()
+    }
+
+    /// Queue `bytes` without copying. Any data pending in the internal buffer is frozen and queued
+    /// first so that write order is preserved; empty `bytes` are skipped.
+    pub fn put_bytes(&mut self, bytes: Bytes) {
+        if !self.buffer.is_empty() {
+            let current = self.buffer.split().freeze();
+            self.frozen_len += current.len();
+            self.frozen.push_back(current);
+        }
+
+        if !bytes.is_empty() {
+            self.frozen_len += bytes.len();
+            self.frozen.push_back(bytes);
+        }
+    }
+
+    /// Returns a slice of the frontmost buffer
+    ///
+    /// While there are bytes present in BigBytes, front_slice is guaranteed not to return an empty
+    /// slice.
+    pub fn front_slice(&self) -> &[u8] {
+        if let Some(front) = self.frozen.front() {
+            front
+        } else {
+            &self.buffer
+        }
+    }
+
+    /// Advances the first buffer by `count` bytes. If the first buffer is advanced to completion,
+    /// it is popped from the queue
+    pub fn advance(&mut self, count: usize) {
+        if let Some(front) = self.frozen.front_mut() {
+            front.advance(count);
+
+            if front.is_empty() {
+                self.frozen.pop_front();
+            }
+
+            self.frozen_len -= count;
+        } else {
+            self.buffer.advance(count);
+        }
+    }
+
+    /// Pops the front Bytes from the BigBytes, or splits and freezes the internal buffer if no
+    /// Bytes are present.
+    pub fn pop_front(&mut self) -> Option<Bytes> {
+        if let Some(front) = self.frozen.pop_front() {
+            self.frozen_len -= front.len();
+            Some(front)
+        } else if !self.buffer.is_empty() {
+            Some(self.buffer.split().freeze())
+        } else {
+            None
+        }
+    }
+
+    /// Drain the BigBytes, writing everything into the provided BytesMut
+    pub fn write_to(&mut self, dst: &mut BytesMut) {
+        dst.reserve(self.total_len());
+
+        for buf in &self.frozen {
+            dst.put_slice(buf);
+        }
+
+        dst.put_slice(&self.buffer.split());
+
+        self.frozen_len = 0;
+
+        self.frozen.clear();
+    }
+}
diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs
index 2b452f8f..b097ddf9 100644
--- a/actix-http/src/h1/codec.rs
+++ b/actix-http/src/h1/codec.rs
@@ -9,7 +9,10 @@ use super::{
     decoder::{self, PayloadDecoder, PayloadItem, PayloadType},
     encoder, Message, MessageType,
 };
-use crate::{body::BodySize, error::ParseError, ConnectionType, Request, Response, ServiceConfig};
+use crate::{
+    big_bytes::BigBytes, body::BodySize, error::ParseError, ConnectionType, Request, Response,
+    ServiceConfig,
+};

 bitflags!
{
     #[derive(Debug, Clone, Copy)]
@@ -146,14 +149,12 @@ impl Decoder for Codec {
     }
 }

-impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
-    type Error = io::Error;
-
-    fn encode(
+impl Codec {
+    pub(super) fn encode_bigbytes(
         &mut self,
         item: Message<(Response<()>, BodySize)>,
-        dst: &mut BytesMut,
-    ) -> Result<(), Self::Error> {
+        dst: &mut BigBytes,
+    ) -> std::io::Result<()> {
         match item {
             Message::Item((mut res, length)) => {
                 // set response version
@@ -172,7 +173,7 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {

                 // encode message
                 self.encoder.encode(
-                    dst,
+                    dst.buffer_mut(),
                     &mut res,
                     self.flags.contains(Flags::HEAD),
                     self.flags.contains(Flags::STREAM),
@@ -184,11 +185,11 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
             }

             Message::Chunk(Some(bytes)) => {
-                self.encoder.encode_chunk(bytes.as_ref(), dst)?;
+                self.encoder.encode_chunk_bigbytes(bytes, dst)?;
             }

             Message::Chunk(None) => {
-                self.encoder.encode_eof(dst)?;
+                self.encoder.encode_eof(dst.buffer_mut())?;
             }
         }

@@ -196,6 +197,23 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
     }
 }

+impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
+    type Error = io::Error;
+
+    fn encode(
+        &mut self,
+        item: Message<(Response<()>, BodySize)>,
+        dst: &mut BytesMut,
+    ) -> Result<(), Self::Error> {
+        let mut bigbytes = BigBytes::with_capacity(1024 * 8);
+        self.encode_bigbytes(item, &mut bigbytes)?;
+
+        bigbytes.write_to(dst);
+
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs
index 00b51360..40b47293 100644
--- a/actix-http/src/h1/dispatcher.rs
+++ b/actix-http/src/h1/dispatcher.rs
@@ -12,11 +12,11 @@ use actix_codec::{Framed, FramedParts};
 use actix_rt::time::sleep_until;
 use actix_service::Service;
 use bitflags::bitflags;
-use bytes::{Buf, BytesMut};
+use bytes::BytesMut;
 use futures_core::ready;
 use pin_project_lite::pin_project;
 use tokio::io::{AsyncRead, AsyncWrite};
-use tokio_util::codec::{Decoder as _, Encoder as _};
+use tokio_util::codec::Decoder as _;
 use tracing::{error,
trace};

 use super::{
@@ -27,6 +27,7 @@ use super::{
     Message, MessageType,
 };
 use crate::{
+    big_bytes::BigBytes,
     body::{BodySize, BoxBody, MessageBody},
     config::ServiceConfig,
     error::{DispatchError, ParseError, PayloadError},
@@ -165,7 +166,7 @@ pin_project! {
         pub(super) io: Option<T>,

         read_buf: BytesMut,
-        write_buf: BytesMut,
+        write_buf: BigBytes,
         codec: Codec,
     }
 }
@@ -277,7 +278,7 @@ where
                     io: Some(io),

                     read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
-                    write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
+                    write_buf: BigBytes::with_capacity(HW_BUFFER_SIZE),
                     codec: Codec::new(config),
                 },
             },
@@ -329,27 +330,23 @@ where
         let InnerDispatcherProj { io, write_buf, .. } = self.project();
         let mut io = Pin::new(io.as_mut().unwrap());

-        let len = write_buf.len();
-        let mut written = 0;
-
-        while written < len {
-            match io.as_mut().poll_write(cx, &write_buf[written..])? {
+        while !write_buf.is_empty() {
+            match io.as_mut().poll_write(cx, write_buf.front_slice())? {
                 Poll::Ready(0) => {
                     error!("write zero; closing");
                     return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
                 }

-                Poll::Ready(n) => written += n,
+                Poll::Ready(n) => write_buf.advance(n),

                 Poll::Pending => {
-                    write_buf.advance(written);
                     return Poll::Pending;
                 }
             }
         }

         // everything has written to I/O; clear buffer
-        write_buf.clear();
+        write_buf.clear(HW_BUFFER_SIZE);

         // flush the I/O and check if get blocked
         io.poll_flush(cx)
@@ -365,7 +362,7 @@ where
         let size = body.size();

         this.codec
-            .encode(Message::Item((res, size)), this.write_buf)
+            .encode_bigbytes(Message::Item((res, size)), this.write_buf)
             .map_err(|err| {
                 if let Some(mut payload) = this.payload.take() {
                     payload.set_error(PayloadError::Incomplete(None));
@@ -416,6 +413,7 @@ where
     fn send_continue(self: Pin<&mut Self>) {
         self.project()
             .write_buf
+            .buffer_mut()
             .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
     }

@@ -493,15 +491,16 @@ where
                 StateProj::SendPayload { mut body } => {
                     // keep populate writer
buffer until buffer size limit hit,
                     // get blocked or finished.
-                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
+                    while this.write_buf.total_len() < super::payload::MAX_BUFFER_SIZE {
                         match body.as_mut().poll_next(cx) {
                             Poll::Ready(Some(Ok(item))) => {
                                 this.codec
-                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
+                                    .encode_bigbytes(Message::Chunk(Some(item)), this.write_buf)?;
                             }

                             Poll::Ready(None) => {
-                                this.codec.encode(Message::Chunk(None), this.write_buf)?;
+                                this.codec
+                                    .encode_bigbytes(Message::Chunk(None), this.write_buf)?;

                                 // payload stream finished.
                                 // set state to None and handle next message
@@ -532,15 +532,16 @@ where

                     // keep populate writer buffer until buffer size limit hit,
                     // get blocked or finished.
-                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
+                    while this.write_buf.total_len() < super::payload::MAX_BUFFER_SIZE {
                         match body.as_mut().poll_next(cx) {
                             Poll::Ready(Some(Ok(item))) => {
                                 this.codec
-                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
+                                    .encode_bigbytes(Message::Chunk(Some(item)), this.write_buf)?;
                             }

                             Poll::Ready(None) => {
-                                this.codec.encode(Message::Chunk(None), this.write_buf)?;
+                                this.codec
+                                    .encode_bigbytes(Message::Chunk(None), this.write_buf)?;

                                 // payload stream finished
                                 // set state to None and handle next message
@@ -575,6 +576,7 @@ where
                         // to service call.
Poll::Ready(Ok(req)) => {
                             this.write_buf
+                                .buffer_mut()
                                 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
                             let fut = this.flow.service.call(req);
                             this.state.set(State::ServiceCall { fut });
@@ -1027,7 +1029,7 @@ where
                     mem::take(this.codec),
                     mem::take(this.read_buf),
                 );
-                parts.write_buf = mem::take(this.write_buf);
+                this.write_buf.write_to(&mut parts.write_buf);
                 let framed = Framed::from_parts(parts);
                 this.flow.upgrade.as_ref().unwrap().call((req, framed))
             }
diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs
index 77e34bcd..33949d2e 100644
--- a/actix-http/src/h1/encoder.rs
+++ b/actix-http/src/h1/encoder.rs
@@ -6,9 +6,10 @@ use std::{
     slice::from_raw_parts_mut,
 };

-use bytes::{BufMut, BytesMut};
+use bytes::{BufMut, Bytes, BytesMut};

 use crate::{
+    big_bytes::BigBytes,
     body::BodySize,
     header::{
         map::Value, HeaderMap, HeaderName, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
@@ -323,6 +324,14 @@ impl MessageEncoder {
         self.te.encode(msg, buf)
     }

+    pub(super) fn encode_chunk_bigbytes(
+        &mut self,
+        msg: Bytes,
+        buf: &mut BigBytes,
+    ) -> io::Result<bool> {
+        self.te.encode_bigbytes(msg, buf)
+    }
+
     /// Encode EOF.
     pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
         self.te.encode_eof(buf)
@@ -414,6 +423,63 @@ impl TransferEncoding {
         }
     }

+    #[inline]
+    /// Encode message.
Return `EOF` state of encoder
+    pub(super) fn encode_bigbytes(&mut self, msg: Bytes, buf: &mut BigBytes) -> io::Result<bool> {
+        match self.kind {
+            TransferEncodingKind::Eof => {
+                let eof = msg.is_empty();
+                if msg.len() > 1024 * 64 {
+                    buf.put_bytes(msg);
+                } else {
+                    buf.buffer_mut().extend_from_slice(&msg);
+                }
+                Ok(eof)
+            }
+            TransferEncodingKind::Chunked(ref mut eof) => {
+                if *eof {
+                    return Ok(true);
+                }
+
+                if msg.is_empty() {
+                    *eof = true;
+                    buf.buffer_mut().extend_from_slice(b"0\r\n\r\n");
+                } else {
+                    writeln!(helpers::MutWriter(buf.buffer_mut()), "{:X}\r", msg.len())
+                        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
+
+                    if msg.len() > 1024 * 64 {
+                        buf.put_bytes(msg);
+                    } else {
+                        buf.buffer_mut().reserve(msg.len() + 2);
+                        buf.buffer_mut().extend_from_slice(&msg);
+                    }
+                    buf.buffer_mut().extend_from_slice(b"\r\n");
+                }
+                Ok(*eof)
+            }
+            TransferEncodingKind::Length(ref mut remaining) => {
+                if *remaining > 0 {
+                    if msg.is_empty() {
+                        return Ok(*remaining == 0);
+                    }
+                    let len = cmp::min(*remaining, msg.len() as u64);
+
+                    if len > 1024 * 64 {
+                        buf.put_bytes(msg.slice(..len as usize));
+                    } else {
+                        buf.buffer_mut().extend_from_slice(&msg[..len as usize]);
+                    }
+
+                    *remaining -= len;
+                    Ok(*remaining == 0)
+                } else {
+                    Ok(true)
+                }
+            }
+        }
+    }
+
     /// Encode message.
Return `EOF` state of encoder
     #[inline]
     pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs
index 734e6e1e..cf29766e 100644
--- a/actix-http/src/lib.rs
+++ b/actix-http/src/lib.rs
@@ -31,6 +31,7 @@

 pub use http::{uri, uri::Uri, Method, StatusCode, Version};

+pub mod big_bytes;
 pub mod body;
 mod builder;
 mod config;
diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs
index ad487e40..0a2c8996 100644
--- a/actix-http/src/ws/codec.rs
+++ b/actix-http/src/ws/codec.rs
@@ -4,6 +4,8 @@ use bytestring::ByteString;
 use tokio_util::codec::{Decoder, Encoder};
 use tracing::error;

+use crate::big_bytes::BigBytes;
+
 use super::{
     frame::Parser,
     proto::{CloseReason, OpCode},
@@ -116,51 +118,55 @@ impl Default for Codec {
     }
 }

-impl Encoder<Message> for Codec {
-    type Error = ProtocolError;
-
-    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
+impl Codec {
+    pub fn encode_bigbytes(
+        &mut self,
+        item: Message,
+        dst: &mut BigBytes,
+    ) -> Result<(), ProtocolError> {
         match item {
-            Message::Text(txt) => Parser::write_message(
+            Message::Text(txt) => Parser::write_message_bigbytes(
                 dst,
-                txt,
+                txt.into_bytes(),
                 OpCode::Text,
                 true,
                 !self.flags.contains(Flags::SERVER),
             ),
-            Message::Binary(bin) => Parser::write_message(
+            Message::Binary(bin) => Parser::write_message_bigbytes(
                 dst,
                 bin,
                 OpCode::Binary,
                 true,
                 !self.flags.contains(Flags::SERVER),
             ),
-            Message::Ping(txt) => Parser::write_message(
+            Message::Ping(txt) => Parser::write_message_bigbytes(
                 dst,
                 txt,
                 OpCode::Ping,
                 true,
                 !self.flags.contains(Flags::SERVER),
             ),
-            Message::Pong(txt) => Parser::write_message(
+            Message::Pong(txt) => Parser::write_message_bigbytes(
                 dst,
                 txt,
                 OpCode::Pong,
                 true,
                 !self.flags.contains(Flags::SERVER),
             ),
-            Message::Close(reason) => {
-                Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
-            }
+            Message::Close(reason) => Parser::write_close(
+                dst.buffer_mut(),
+
reason,
+                !self.flags.contains(Flags::SERVER),
+            ),
             Message::Continuation(cont) => match cont {
                 Item::FirstText(data) => {
                     if self.flags.contains(Flags::W_CONTINUATION) {
                         return Err(ProtocolError::ContinuationStarted);
                     } else {
                         self.flags.insert(Flags::W_CONTINUATION);
-                        Parser::write_message(
+                        Parser::write_message_bigbytes(
                             dst,
-                            &data[..],
+                            data,
                             OpCode::Text,
                             false,
                             !self.flags.contains(Flags::SERVER),
@@ -172,9 +178,9 @@ impl Encoder<Message> for Codec {
                         return Err(ProtocolError::ContinuationStarted);
                     } else {
                         self.flags.insert(Flags::W_CONTINUATION);
-                        Parser::write_message(
+                        Parser::write_message_bigbytes(
                             dst,
-                            &data[..],
+                            data,
                             OpCode::Binary,
                             false,
                             !self.flags.contains(Flags::SERVER),
@@ -183,9 +189,9 @@ impl Encoder<Message> for Codec {
                 }
                 Item::Continue(data) => {
                     if self.flags.contains(Flags::W_CONTINUATION) {
-                        Parser::write_message(
+                        Parser::write_message_bigbytes(
                             dst,
-                            &data[..],
+                            data,
                             OpCode::Continue,
                             false,
                             !self.flags.contains(Flags::SERVER),
@@ -197,9 +203,9 @@ impl Encoder<Message> for Codec {
                 Item::Last(data) => {
                     if self.flags.contains(Flags::W_CONTINUATION) {
                         self.flags.remove(Flags::W_CONTINUATION);
-                        Parser::write_message(
+                        Parser::write_message_bigbytes(
                             dst,
-                            &data[..],
+                            data,
                             OpCode::Continue,
                             true,
                             !self.flags.contains(Flags::SERVER),
@@ -215,6 +221,20 @@ impl Encoder<Message> for Codec {
     }
 }

+impl Encoder<Message> for Codec {
+    type Error = ProtocolError;
+
+    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
+        let mut big_bytes = BigBytes::with_capacity(0);
+
+        self.encode_bigbytes(item, &mut big_bytes)?;
+
+        big_bytes.write_to(dst);
+
+        Ok(())
+    }
+}
+
 impl Decoder for Codec {
     type Item = Frame;
     type Error = ProtocolError;
diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs
index 35b3f8e6..22c0eee9 100644
--- a/actix-http/src/ws/frame.rs
+++ b/actix-http/src/ws/frame.rs
@@ -1,8 +1,10 @@
 use std::cmp::min;

-use bytes::{Buf, BufMut, BytesMut};
+use bytes::{Buf, BufMut, Bytes, BytesMut};
 use tracing::debug;

+use
crate::big_bytes::BigBytes;
+
 use super::{
     mask::apply_mask,
     proto::{CloseCode, CloseReason, OpCode},
@@ -156,51 +158,68 @@ impl Parser {
         }
     }

-    /// Generate binary representation
-    pub fn write_message<B: AsRef<[u8]>>(
-        dst: &mut BytesMut,
-        pl: B,
+    pub fn write_message_bigbytes(
+        dst: &mut BigBytes,
+        pl: Bytes,
         op: OpCode,
         fin: bool,
         mask: bool,
     ) {
-        let payload = pl.as_ref();
         let one: u8 = if fin {
             0x80 | Into::<u8>::into(op)
         } else {
             op.into()
         };
-        let payload_len = payload.len();
-        let (two, p_len) = if mask {
-            (0x80, payload_len + 4)
-        } else {
-            (0, payload_len)
-        };
+        let payload_len = pl.len();
+        let two = if mask { 0x80 } else { 0 };

         if payload_len < 126 {
-            dst.reserve(p_len + 2);
-            dst.put_slice(&[one, two | payload_len as u8]);
+            dst.buffer_mut().reserve(2);
+            dst.buffer_mut().put_slice(&[one, two | payload_len as u8]);
         } else if payload_len <= 65_535 {
-            dst.reserve(p_len + 4);
-            dst.put_slice(&[one, two | 126]);
-            dst.put_u16(payload_len as u16);
+            dst.buffer_mut().reserve(4);
+            dst.buffer_mut().put_slice(&[one, two | 126]);
+            dst.buffer_mut().put_u16(payload_len as u16);
         } else {
-            dst.reserve(p_len + 10);
-            dst.put_slice(&[one, two | 127]);
-            dst.put_u64(payload_len as u64);
+            dst.buffer_mut().reserve(10);
+            dst.buffer_mut().put_slice(&[one, two | 127]);
+            dst.buffer_mut().put_u64(payload_len as u64);
         };

         if mask {
             let mask = rand::random::<[u8; 4]>();
-            dst.put_slice(mask.as_ref());
-            dst.put_slice(payload.as_ref());
-            let pos = dst.len() - payload_len;
-            apply_mask(&mut dst[pos..], mask);
+            dst.buffer_mut().put_slice(mask.as_ref());
+
+            match pl.try_into_mut() {
+                // Avoid copying bytes by mutating in-place
+                Ok(mut pl_mut) => {
+                    apply_mask(&mut pl_mut, mask);
+                    dst.put_bytes(pl_mut.freeze());
+                }
+
+                // We need to copy the bytes anyway at this point, so put them in the buffer
+                // directly
+                Err(pl) => {
+                    dst.buffer_mut().reserve(pl.len());
+                    dst.buffer_mut().put_slice(pl.as_ref());
+                    let pos = dst.buffer_mut().len() - payload_len;
+                    apply_mask(&mut
dst.buffer_mut()[pos..], mask);
+                }
+            }
         } else {
-            dst.put_slice(payload.as_ref());
+            dst.put_bytes(pl)
         }
     }

+    /// Generate binary representation
+    pub fn write_message(dst: &mut BytesMut, pl: Bytes, op: OpCode, fin: bool, mask: bool) {
+        let mut big_bytes = BigBytes::with_capacity(0);
+
+        Self::write_message_bigbytes(&mut big_bytes, pl, op, fin, mask);
+
+        big_bytes.write_to(dst);
+    }
+
     /// Create a new Close control frame.
     #[inline]
     pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
@@ -215,7 +234,7 @@ impl Parser {
             }
         };

-        Parser::write_message(dst, payload, OpCode::Close, true, mask)
+        Parser::write_message(dst, Bytes::from(payload), OpCode::Close, true, mask)
     }
 }

@@ -368,7 +387,13 @@ mod tests {
     #[test]
     fn test_ping_frame() {
         let mut buf = BytesMut::new();
-        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);
+        Parser::write_message(
+            &mut buf,
+            Bytes::from(Vec::from("data")),
+            OpCode::Ping,
+            true,
+            false,
+        );

         let mut v = vec![137u8, 4u8];
         v.extend(b"data");
@@ -378,7 +403,13 @@ mod tests {
     #[test]
     fn test_pong_frame() {
         let mut buf = BytesMut::new();
-        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);
+        Parser::write_message(
+            &mut buf,
+            Bytes::from(Vec::from("data")),
+            OpCode::Pong,
+            true,
+            false,
+        );

         let mut v = vec![138u8, 4u8];
         v.extend(b"data");