mirror of https://github.com/fafhrd91/actix-web
Merge b6893b13cd
into 3f9d88f859
This commit is contained in:
commit
aca3f9d2d1
|
@ -1320,6 +1320,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"libz-sys",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
|
@ -1922,6 +1923,18 @@ dependencies = [
|
|||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libz-sys"
|
||||
version = "1.1.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-hash-map"
|
||||
version = "0.5.6"
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
## Unreleased
|
||||
|
||||
- Add DEFLATE compression support for WebSocket.
|
||||
|
||||
## 3.11.0
|
||||
|
||||
- Update `brotli` dependency to `8`.
|
||||
|
|
|
@ -29,6 +29,7 @@ features = [
|
|||
"compress-brotli",
|
||||
"compress-gzip",
|
||||
"compress-zstd",
|
||||
"compress-ws-deflate",
|
||||
]
|
||||
|
||||
[package.metadata.cargo_check_external_types]
|
||||
|
@ -83,6 +84,7 @@ rustls-0_23 = ["__tls", "actix-tls/accept", "actix-tls/rustls-0_23"]
|
|||
compress-brotli = ["__compress", "dep:brotli"]
|
||||
compress-gzip = ["__compress", "dep:flate2"]
|
||||
compress-zstd = ["__compress", "dep:zstd"]
|
||||
compress-ws-deflate = ["dep:flate2", "flate2/zlib-default"]
|
||||
|
||||
# Internal (PRIVATE!) features used to aid testing and checking feature status.
|
||||
# Don't rely on these whatsoever. They are semver-exempt and may disappear at anytime.
|
||||
|
|
|
@ -2,18 +2,19 @@
|
|||
//!
|
||||
//! ## Crate Features
|
||||
//!
|
||||
//! | Feature | Functionality |
|
||||
//! | ------------------- | ------------------------------------------- |
|
||||
//! | `http2` | HTTP/2 support via [h2]. |
|
||||
//! | `openssl` | TLS support via [OpenSSL]. |
|
||||
//! | `rustls-0_20` | TLS support via rustls 0.20. |
|
||||
//! | `rustls-0_21` | TLS support via rustls 0.21. |
|
||||
//! | `rustls-0_22` | TLS support via rustls 0.22. |
|
||||
//! | `rustls-0_23` | TLS support via [rustls] 0.23. |
|
||||
//! | `compress-brotli` | Payload compression support: Brotli. |
|
||||
//! | `compress-gzip` | Payload compression support: Deflate, Gzip. |
|
||||
//! | `compress-zstd` | Payload compression support: Zstd. |
|
||||
//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. |
|
||||
//! | Feature | Functionality |
|
||||
//! | --------------------- | ------------------------------------------- |
|
||||
//! | `http2` | HTTP/2 support via [h2]. |
|
||||
//! | `openssl` | TLS support via [OpenSSL]. |
|
||||
//! | `rustls-0_20` | TLS support via rustls 0.20. |
|
||||
//! | `rustls-0_21` | TLS support via rustls 0.21. |
|
||||
//! | `rustls-0_22` | TLS support via rustls 0.22. |
|
||||
//! | `rustls-0_23` | TLS support via [rustls] 0.23. |
|
||||
//! | `compress-brotli` | Payload compression support: Brotli. |
|
||||
//! | `compress-gzip` | Payload compression support: Deflate, Gzip. |
|
||||
//! | `compress-zstd` | Payload compression support: Zstd. |
|
||||
//! | `compress-ws-deflate` | WebSocket DEFLATE compression support. |
|
||||
//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. |
|
||||
//!
|
||||
//! [h2]: https://crates.io/crates/h2
|
||||
//! [OpenSSL]: https://crates.io/crates/openssl
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
use bitflags::bitflags;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use bytestring::ByteString;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
use tokio_util::codec;
|
||||
use tracing::error;
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
use super::deflate::{
|
||||
DeflateCompressionContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG,
|
||||
};
|
||||
use super::{
|
||||
frame::Parser,
|
||||
proto::{CloseReason, OpCode},
|
||||
proto::{CloseReason, OpCode, RsvBits},
|
||||
ProtocolError,
|
||||
};
|
||||
|
||||
|
@ -66,13 +70,6 @@ pub enum Item {
|
|||
Last(Bytes),
|
||||
}
|
||||
|
||||
/// WebSocket protocol codec.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Codec {
|
||||
flags: Flags,
|
||||
max_size: usize,
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct Flags: u8 {
|
||||
|
@ -82,63 +79,122 @@ bitflags! {
|
|||
}
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
/// Create new WebSocket frames decoder.
|
||||
pub const fn new() -> Codec {
|
||||
Codec {
|
||||
max_size: 65_536,
|
||||
/// WebSocket message encoder.
|
||||
#[derive(Debug)]
|
||||
pub struct Encoder {
|
||||
flags: Flags,
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_compress: Option<DeflateCompressionContext>,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
/// Create new WebSocket frames encoder.
|
||||
pub const fn new() -> Encoder {
|
||||
Encoder {
|
||||
flags: Flags::SERVER,
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_compress: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set max frame size.
|
||||
/// Create new WebSocket frames encoder with `permessage-deflate` extension support.
|
||||
/// Compression context can be made from
|
||||
/// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context).
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
pub fn new_deflate(compress: DeflateCompressionContext) -> Encoder {
|
||||
Encoder {
|
||||
flags: Flags::SERVER,
|
||||
|
||||
deflate_compress: Some(compress),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set encoder to client mode.
|
||||
///
|
||||
/// By default max size is set to 64KiB.
|
||||
#[must_use = "This returns the a new Codec, without modifying the original."]
|
||||
pub fn max_size(mut self, size: usize) -> Self {
|
||||
self.max_size = size;
|
||||
/// By default encoder works in server mode.
|
||||
#[must_use = "This returns the a new Encoder, without modifying the original."]
|
||||
pub fn client_mode(mut self) -> Self {
|
||||
self.flags = Flags::empty();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set decoder to client mode.
|
||||
///
|
||||
/// By default decoder works in server mode.
|
||||
#[must_use = "This returns the a new Codec, without modifying the original."]
|
||||
pub fn client_mode(mut self) -> Self {
|
||||
self.flags.remove(Flags::SERVER);
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
fn set_client_mode_deflate(
|
||||
mut self,
|
||||
remote_no_context_takeover: bool,
|
||||
remote_max_window_bits: u8,
|
||||
) -> Self {
|
||||
self.deflate_compress = self
|
||||
.deflate_compress
|
||||
.map(|c| c.reset_with(remote_no_context_takeover, remote_max_window_bits));
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
fn process_payload(
|
||||
&mut self,
|
||||
fin: bool,
|
||||
bytes: Bytes,
|
||||
) -> Result<(Bytes, RsvBits), ProtocolError> {
|
||||
if let Some(compress) = &mut self.deflate_compress {
|
||||
Ok((compress.compress(fin, bytes)?, RSV_BIT_DEFLATE_FLAG))
|
||||
} else {
|
||||
Ok((bytes, RsvBits::empty()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "compress-ws-deflate"))]
|
||||
fn process_payload(
|
||||
&mut self,
|
||||
_fin: bool,
|
||||
bytes: Bytes,
|
||||
) -> Result<(Bytes, RsvBits), ProtocolError> {
|
||||
Ok((bytes, RsvBits::empty()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Codec {
|
||||
impl Default for Encoder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Message> for Codec {
|
||||
impl codec::Encoder<Message> for Encoder {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
Message::Text(txt) => Parser::write_message(
|
||||
dst,
|
||||
txt,
|
||||
OpCode::Text,
|
||||
true,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
),
|
||||
Message::Binary(bin) => Parser::write_message(
|
||||
dst,
|
||||
bin,
|
||||
OpCode::Binary,
|
||||
true,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
),
|
||||
Message::Text(txt) => {
|
||||
let (bytes, rsv_bits) = self.process_payload(true, txt.into_bytes())?;
|
||||
|
||||
Parser::write_message(
|
||||
dst,
|
||||
bytes,
|
||||
OpCode::Text,
|
||||
rsv_bits,
|
||||
true,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
)
|
||||
}
|
||||
Message::Binary(bin) => {
|
||||
let (bin, rsv_bits) = self.process_payload(true, bin)?;
|
||||
|
||||
Parser::write_message(
|
||||
dst,
|
||||
bin,
|
||||
OpCode::Binary,
|
||||
rsv_bits,
|
||||
true,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
)
|
||||
}
|
||||
Message::Ping(txt) => Parser::write_message(
|
||||
dst,
|
||||
txt,
|
||||
OpCode::Ping,
|
||||
RsvBits::empty(),
|
||||
true,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
),
|
||||
|
@ -146,22 +202,29 @@ impl Encoder<Message> for Codec {
|
|||
dst,
|
||||
txt,
|
||||
OpCode::Pong,
|
||||
RsvBits::empty(),
|
||||
true,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
),
|
||||
Message::Close(reason) => {
|
||||
Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
|
||||
}
|
||||
Message::Close(reason) => Parser::write_close(
|
||||
dst,
|
||||
reason,
|
||||
RsvBits::empty(),
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
),
|
||||
Message::Continuation(cont) => match cont {
|
||||
Item::FirstText(data) => {
|
||||
if self.flags.contains(Flags::W_CONTINUATION) {
|
||||
return Err(ProtocolError::ContinuationStarted);
|
||||
} else {
|
||||
let (data, rsv_bits) = self.process_payload(false, data)?;
|
||||
|
||||
self.flags.insert(Flags::W_CONTINUATION);
|
||||
Parser::write_message(
|
||||
dst,
|
||||
&data[..],
|
||||
data,
|
||||
OpCode::Text,
|
||||
rsv_bits,
|
||||
false,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
)
|
||||
|
@ -171,11 +234,14 @@ impl Encoder<Message> for Codec {
|
|||
if self.flags.contains(Flags::W_CONTINUATION) {
|
||||
return Err(ProtocolError::ContinuationStarted);
|
||||
} else {
|
||||
let (data, rsv_bits) = self.process_payload(false, data)?;
|
||||
|
||||
self.flags.insert(Flags::W_CONTINUATION);
|
||||
Parser::write_message(
|
||||
dst,
|
||||
&data[..],
|
||||
data,
|
||||
OpCode::Binary,
|
||||
rsv_bits,
|
||||
false,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
)
|
||||
|
@ -183,10 +249,13 @@ impl Encoder<Message> for Codec {
|
|||
}
|
||||
Item::Continue(data) => {
|
||||
if self.flags.contains(Flags::W_CONTINUATION) {
|
||||
let (data, rsv_bits) = self.process_payload(false, data)?;
|
||||
|
||||
Parser::write_message(
|
||||
dst,
|
||||
&data[..],
|
||||
data,
|
||||
OpCode::Continue,
|
||||
rsv_bits,
|
||||
false,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
)
|
||||
|
@ -197,10 +266,14 @@ impl Encoder<Message> for Codec {
|
|||
Item::Last(data) => {
|
||||
if self.flags.contains(Flags::W_CONTINUATION) {
|
||||
self.flags.remove(Flags::W_CONTINUATION);
|
||||
|
||||
let (data, rsv_bits) = self.process_payload(true, data)?;
|
||||
|
||||
Parser::write_message(
|
||||
dst,
|
||||
&data[..],
|
||||
data,
|
||||
OpCode::Continue,
|
||||
rsv_bits,
|
||||
true,
|
||||
!self.flags.contains(Flags::SERVER),
|
||||
)
|
||||
|
@ -215,20 +288,130 @@ impl Encoder<Message> for Codec {
|
|||
}
|
||||
}
|
||||
|
||||
impl Decoder for Codec {
|
||||
/// WebSocket message decoder.
|
||||
#[derive(Debug)]
|
||||
pub struct Decoder {
|
||||
flags: Flags,
|
||||
max_size: usize,
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_decompress: Option<DeflateDecompressionContext>,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
/// Create new WebSocket frames decoder.
|
||||
pub const fn new() -> Decoder {
|
||||
Decoder {
|
||||
flags: Flags::SERVER,
|
||||
max_size: 65_536,
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_decompress: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create new WebSocket frames decoder with `permessage-deflate` extension support.
|
||||
/// Decompression context can be made from
|
||||
/// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context).
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
pub fn new_deflate(decompress: DeflateDecompressionContext) -> Decoder {
|
||||
Decoder {
|
||||
flags: Flags::SERVER,
|
||||
max_size: 65_536,
|
||||
|
||||
deflate_decompress: Some(decompress),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set max frame size.
|
||||
///
|
||||
/// By default max size is set to 64KiB.
|
||||
#[must_use = "This returns the a new Decoder, without modifying the original."]
|
||||
pub fn max_size(mut self, size: usize) -> Self {
|
||||
self.max_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set decoder to client mode.
|
||||
///
|
||||
/// By default decoder works in server mode.
|
||||
#[must_use = "This returns the a new Decoder, without modifying the original."]
|
||||
pub fn client_mode(mut self) -> Self {
|
||||
self.flags = Flags::empty();
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
fn set_client_mode_deflate(
|
||||
mut self,
|
||||
local_no_context_takeover: bool,
|
||||
local_max_window_bits: u8,
|
||||
) -> Self {
|
||||
if let Some(decompress) = &mut self.deflate_decompress {
|
||||
decompress.reset_with(local_no_context_takeover, local_max_window_bits);
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
fn process_payload(
|
||||
&mut self,
|
||||
fin: bool,
|
||||
opcode: OpCode,
|
||||
rsv_bits: RsvBits,
|
||||
bytes: Option<Bytes>,
|
||||
) -> Result<Option<Bytes>, ProtocolError> {
|
||||
if let Some(bytes) = bytes {
|
||||
if let Some(decompress) = &mut self.deflate_decompress {
|
||||
Ok(Some(decompress.decompress(fin, opcode, rsv_bits, bytes)?))
|
||||
} else {
|
||||
Ok(Some(bytes))
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "compress-ws-deflate"))]
|
||||
fn process_payload(
|
||||
&mut self,
|
||||
_fin: bool,
|
||||
_opcode: OpCode,
|
||||
_rsv_bits: RsvBits,
|
||||
bytes: Option<Bytes>,
|
||||
) -> Result<Option<Bytes>, ProtocolError> {
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Decoder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl codec::Decoder for Decoder {
|
||||
type Item = Frame;
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
|
||||
Ok(Some((finished, opcode, payload))) => {
|
||||
Ok(Some((finished, opcode, rsv_bits, payload))) => {
|
||||
let payload = self.process_payload(
|
||||
finished,
|
||||
opcode,
|
||||
rsv_bits,
|
||||
payload.map(BytesMut::freeze),
|
||||
)?;
|
||||
|
||||
// continuation is not supported
|
||||
if !finished {
|
||||
return match opcode {
|
||||
OpCode::Continue => {
|
||||
if self.flags.contains(Flags::CONTINUATION) {
|
||||
Ok(Some(Frame::Continuation(Item::Continue(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
payload.unwrap_or_else(Bytes::new),
|
||||
))))
|
||||
} else {
|
||||
Err(ProtocolError::ContinuationNotStarted)
|
||||
|
@ -238,7 +421,7 @@ impl Decoder for Codec {
|
|||
if !self.flags.contains(Flags::CONTINUATION) {
|
||||
self.flags.insert(Flags::CONTINUATION);
|
||||
Ok(Some(Frame::Continuation(Item::FirstBinary(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
payload.unwrap_or_else(Bytes::new),
|
||||
))))
|
||||
} else {
|
||||
Err(ProtocolError::ContinuationStarted)
|
||||
|
@ -248,7 +431,7 @@ impl Decoder for Codec {
|
|||
if !self.flags.contains(Flags::CONTINUATION) {
|
||||
self.flags.insert(Flags::CONTINUATION);
|
||||
Ok(Some(Frame::Continuation(Item::FirstText(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
payload.unwrap_or_else(Bytes::new),
|
||||
))))
|
||||
} else {
|
||||
Err(ProtocolError::ContinuationStarted)
|
||||
|
@ -266,7 +449,7 @@ impl Decoder for Codec {
|
|||
if self.flags.contains(Flags::CONTINUATION) {
|
||||
self.flags.remove(Flags::CONTINUATION);
|
||||
Ok(Some(Frame::Continuation(Item::Last(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
payload.unwrap_or_else(Bytes::new),
|
||||
))))
|
||||
} else {
|
||||
Err(ProtocolError::ContinuationNotStarted)
|
||||
|
@ -281,18 +464,10 @@ impl Decoder for Codec {
|
|||
Ok(Some(Frame::Close(None)))
|
||||
}
|
||||
}
|
||||
OpCode::Ping => Ok(Some(Frame::Ping(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
))),
|
||||
OpCode::Pong => Ok(Some(Frame::Pong(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
))),
|
||||
OpCode::Binary => Ok(Some(Frame::Binary(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
))),
|
||||
OpCode::Text => Ok(Some(Frame::Text(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
))),
|
||||
OpCode::Ping => Ok(Some(Frame::Ping(payload.unwrap_or_else(Bytes::new)))),
|
||||
OpCode::Pong => Ok(Some(Frame::Pong(payload.unwrap_or_else(Bytes::new)))),
|
||||
OpCode::Binary => Ok(Some(Frame::Binary(payload.unwrap_or_else(Bytes::new)))),
|
||||
OpCode::Text => Ok(Some(Frame::Text(payload.unwrap_or_else(Bytes::new)))),
|
||||
}
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
|
@ -300,3 +475,130 @@ impl Decoder for Codec {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket protocol codec.
|
||||
/// This is essentially a combination of [`Encoder`] and [`Decoder`] and
|
||||
/// actual conversion behaviors are defined in both structs respectively.
|
||||
///
|
||||
/// # Note
|
||||
/// Cloning [`Codec`] creates a new codec with existing configurations
|
||||
/// and will not preserve the context information.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Codec {
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
}
|
||||
|
||||
impl Clone for Codec {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
encoder: Encoder {
|
||||
flags: self.encoder.flags & Flags::SERVER,
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_compress: self.encoder.deflate_compress.as_ref().map(|c| {
|
||||
DeflateCompressionContext::new(
|
||||
Some(c.compression_level),
|
||||
c.remote_no_context_takeover,
|
||||
c.remote_max_window_bits,
|
||||
)
|
||||
}),
|
||||
},
|
||||
decoder: Decoder {
|
||||
flags: self.decoder.flags & Flags::SERVER,
|
||||
max_size: self.decoder.max_size,
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_decompress: self.decoder.deflate_decompress.as_ref().map(|d| {
|
||||
DeflateDecompressionContext::new(
|
||||
d.local_no_context_takeover,
|
||||
d.local_max_window_bits,
|
||||
)
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
/// Create new WebSocket frames codec.
|
||||
pub fn new() -> Codec {
|
||||
Codec {
|
||||
encoder: Encoder::new(),
|
||||
decoder: Decoder::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create new WebSocket frames codec with DEFLATE compression.
|
||||
/// Both compression and decompression contexts can be made from
|
||||
/// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context).
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
pub fn new_deflate(
|
||||
compress: DeflateCompressionContext,
|
||||
decompress: DeflateDecompressionContext,
|
||||
) -> Codec {
|
||||
Codec {
|
||||
encoder: Encoder::new_deflate(compress),
|
||||
decoder: Decoder::new_deflate(decompress),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set max frame size.
|
||||
///
|
||||
/// By default max size is set to 64KiB.
|
||||
#[must_use = "This returns the a new Codec, without modifying the original."]
|
||||
pub fn max_size(self, size: usize) -> Self {
|
||||
let Self { encoder, decoder } = self;
|
||||
|
||||
Codec {
|
||||
encoder,
|
||||
decoder: decoder.max_size(size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set codec to client mode.
|
||||
///
|
||||
/// By default codec works in server mode.
|
||||
#[must_use = "This returns the a new Codec, without modifying the original."]
|
||||
pub fn client_mode(self) -> Self {
|
||||
let Self {
|
||||
mut encoder,
|
||||
mut decoder,
|
||||
} = self;
|
||||
|
||||
encoder = encoder.client_mode();
|
||||
decoder = decoder.client_mode();
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
{
|
||||
if let Some(decoder) = &decoder.deflate_decompress {
|
||||
encoder = encoder.set_client_mode_deflate(
|
||||
decoder.local_no_context_takeover,
|
||||
decoder.local_max_window_bits,
|
||||
);
|
||||
}
|
||||
if let Some(encoder) = &encoder.deflate_compress {
|
||||
decoder = decoder.set_client_mode_deflate(
|
||||
encoder.remote_no_context_takeover,
|
||||
encoder.remote_max_window_bits,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Self { encoder, decoder }
|
||||
}
|
||||
}
|
||||
|
||||
impl codec::Decoder for Codec {
|
||||
type Item = Frame;
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
self.decoder.decode(src)
|
||||
}
|
||||
}
|
||||
|
||||
impl codec::Encoder<Message> for Codec {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
self.encoder.encode(item, dst)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,850 @@
|
|||
//! WebSocket permessage-deflate compression implementation.
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use bytes::Bytes;
|
||||
pub use flate2::Compression as DeflateCompressionLevel;
|
||||
|
||||
use super::{OpCode, ProtocolError, RsvBits};
|
||||
use crate::header::{HeaderName, HeaderValue, TryIntoHeaderPair, SEC_WEBSOCKET_EXTENSIONS};
|
||||
|
||||
// NOTE: according to [RFC 7692 §7.1.2.1] window bit size should be within 8..=15
|
||||
// but we have to limit the range to 9..=15 because [flate2] only supports window bit within 9..=15.
|
||||
//
|
||||
// [RFC 6792 §7.1.2.1]: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1.2.1
|
||||
// [flate2]: https://docs.rs/flate2/latest/flate2/struct.Compress.html#method.new_with_window_bits
|
||||
const MAX_WINDOW_BITS_RANGE: std::ops::RangeInclusive<u8> = 9..=15;
|
||||
const DEFAULT_WINDOW_BITS: u8 = 15;
|
||||
|
||||
const BUF_SIZE: usize = 2048;
|
||||
|
||||
pub(super) const RSV_BIT_DEFLATE_FLAG: RsvBits = RsvBits::RSV1;
|
||||
|
||||
/// DEFLATE compression related handshake errors.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum DeflateHandshakeError {
|
||||
/// Unknown extension parameter given.
|
||||
UnknownWebSocketParameters,
|
||||
|
||||
/// Duplicate parameter found in single extension statement.
|
||||
DuplicateParameter(&'static str),
|
||||
|
||||
/// Max window bits size out of range. Should be in 9..=15
|
||||
MaxWindowBitsOutOfRange,
|
||||
|
||||
/// Multiple `permessage-deflate` statements found but failed to negotiate any.
|
||||
NoSuitableConfigurationFound,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DeflateHandshakeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::UnknownWebSocketParameters => {
|
||||
write!(f, "Unknown WebSocket `permessage-deflate` parameters.")
|
||||
}
|
||||
Self::DuplicateParameter(p) => {
|
||||
write!(f, "Duplicate WebSocket `permessage-deflate` parameter: {p}")
|
||||
}
|
||||
Self::MaxWindowBitsOutOfRange => write!(
|
||||
f,
|
||||
"Max window bits out of range. ({} to {} expected)",
|
||||
MAX_WINDOW_BITS_RANGE.start(),
|
||||
MAX_WINDOW_BITS_RANGE.end()
|
||||
),
|
||||
Self::NoSuitableConfigurationFound => write!(
|
||||
f,
|
||||
"No suitable WebSocket `permedia-deflate` parameter configurations found."
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for DeflateHandshakeError {}
|
||||
|
||||
/// Maximum size of client's DEFLATE sliding window.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum ClientMaxWindowBits {
|
||||
/// Unspecified. Indicates client will follow server configuration.
|
||||
NotSpecified,
|
||||
/// Specified size of client's DEFLATE sliding window size in bits, between 9 and 15.
|
||||
Specified(u8),
|
||||
}
|
||||
|
||||
/// Per-session DEFLATE configuration parameter.
|
||||
///
|
||||
/// It can be used both client and server side.
|
||||
/// At client side, it can be used to pass desired configuration to server.
|
||||
/// At server side, negotiated parameter will be sent to client with this.
|
||||
/// This can be represented in HTTP header form as it implements [`TryIntoHeaderPair`] trait.
|
||||
#[derive(Debug, Clone, Default, Eq, PartialEq)]
|
||||
pub struct DeflateSessionParameters {
|
||||
/// Disallow server from take over context.
|
||||
pub server_no_context_takeover: bool,
|
||||
/// Disallow client from take over context.
|
||||
pub client_no_context_takeover: bool,
|
||||
/// Maximum size of server's DEFLATE sliding window in bits, between 9 and 15.
|
||||
pub server_max_window_bits: Option<u8>,
|
||||
/// Maximum size of client's DEFLATE sliding window.
|
||||
pub client_max_window_bits: Option<ClientMaxWindowBits>,
|
||||
}
|
||||
|
||||
impl TryIntoHeaderPair for DeflateSessionParameters {
|
||||
type Error = Infallible;
|
||||
|
||||
fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> {
|
||||
let mut response_extension = vec!["permessage-deflate".to_owned()];
|
||||
|
||||
if self.server_no_context_takeover {
|
||||
response_extension.push("server_no_context_takeover".to_owned());
|
||||
}
|
||||
if self.client_no_context_takeover {
|
||||
response_extension.push("client_no_context_takeover".to_owned());
|
||||
}
|
||||
if let Some(server_max_window_bits) = self.server_max_window_bits {
|
||||
response_extension.push(format!("server_max_window_bits={server_max_window_bits}"));
|
||||
}
|
||||
if let Some(client_max_window_bits) = self.client_max_window_bits {
|
||||
match client_max_window_bits {
|
||||
ClientMaxWindowBits::NotSpecified => {
|
||||
response_extension.push("client_max_window_bits".to_string());
|
||||
}
|
||||
ClientMaxWindowBits::Specified(bits) => {
|
||||
response_extension.push(format!("client_max_window_bits={bits}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
SEC_WEBSOCKET_EXTENSIONS,
|
||||
HeaderValue::from_str(&response_extension.join("; ")).unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl DeflateSessionParameters {
|
||||
fn parse<'a>(
|
||||
extension_frags: impl Iterator<Item = &'a str>,
|
||||
) -> Result<Self, DeflateHandshakeError> {
|
||||
let mut client_max_window_bits = None;
|
||||
let mut server_max_window_bits = None;
|
||||
let mut client_no_context_takeover = None;
|
||||
let mut server_no_context_takeover = None;
|
||||
|
||||
let mut unknown_parameters = vec![];
|
||||
|
||||
for fragment in extension_frags {
|
||||
if fragment.is_empty() {
|
||||
continue;
|
||||
} else if fragment == "client_max_window_bits" {
|
||||
if client_max_window_bits.is_some() {
|
||||
return Err(DeflateHandshakeError::DuplicateParameter(
|
||||
"client_max_window_bits",
|
||||
));
|
||||
}
|
||||
client_max_window_bits = Some(ClientMaxWindowBits::NotSpecified);
|
||||
} else if let Some(value) = fragment.strip_prefix("client_max_window_bits=") {
|
||||
if client_max_window_bits.is_some() {
|
||||
return Err(DeflateHandshakeError::DuplicateParameter(
|
||||
"client_max_window_bits",
|
||||
));
|
||||
}
|
||||
let bits = value
|
||||
.parse::<u8>()
|
||||
.map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?;
|
||||
if !MAX_WINDOW_BITS_RANGE.contains(&bits) {
|
||||
return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange);
|
||||
}
|
||||
client_max_window_bits = Some(ClientMaxWindowBits::Specified(bits));
|
||||
} else if let Some(value) = fragment.strip_prefix("server_max_window_bits=") {
|
||||
if server_max_window_bits.is_some() {
|
||||
return Err(DeflateHandshakeError::DuplicateParameter(
|
||||
"server_max_window_bits",
|
||||
));
|
||||
}
|
||||
let bits = value
|
||||
.parse::<u8>()
|
||||
.map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?;
|
||||
if !MAX_WINDOW_BITS_RANGE.contains(&bits) {
|
||||
return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange);
|
||||
}
|
||||
server_max_window_bits = Some(bits);
|
||||
} else if fragment == "server_no_context_takeover" {
|
||||
if server_no_context_takeover.is_some() {
|
||||
return Err(DeflateHandshakeError::DuplicateParameter(
|
||||
"server_no_context_takeover",
|
||||
));
|
||||
}
|
||||
server_no_context_takeover = Some(true);
|
||||
} else if fragment == "client_no_context_takeover" {
|
||||
if client_no_context_takeover.is_some() {
|
||||
return Err(DeflateHandshakeError::DuplicateParameter(
|
||||
"client_no_context_takeover",
|
||||
));
|
||||
}
|
||||
client_no_context_takeover = Some(true);
|
||||
} else {
|
||||
unknown_parameters.push(fragment.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
if !unknown_parameters.is_empty() {
|
||||
Err(DeflateHandshakeError::UnknownWebSocketParameters)
|
||||
} else {
|
||||
Ok(DeflateSessionParameters {
|
||||
server_no_context_takeover: server_no_context_takeover.unwrap_or(false),
|
||||
client_no_context_takeover: client_no_context_takeover.unwrap_or(false),
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse desired parameters from `Sec-WebSocket-Extensions` header.
|
||||
/// The result may contain multiple values as it's possible to pass multiple parameters
|
||||
/// separated with comma.
|
||||
pub fn from_extension_header(header_value: &str) -> Vec<Result<Self, DeflateHandshakeError>> {
|
||||
let mut results = vec![];
|
||||
for extension in header_value.split(',').map(str::trim) {
|
||||
let mut fragments = extension.split(';').map(str::trim);
|
||||
if fragments.next() == Some("permessage-deflate") {
|
||||
results.push(Self::parse(fragments));
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Create compression and decompression context based on the parameter.
|
||||
pub fn create_context(
|
||||
&self,
|
||||
compression_level: Option<DeflateCompressionLevel>,
|
||||
is_client_mode: bool,
|
||||
) -> (DeflateCompressionContext, DeflateDecompressionContext) {
|
||||
let client_max_window_bits =
|
||||
if let Some(ClientMaxWindowBits::Specified(value)) = self.client_max_window_bits {
|
||||
value
|
||||
} else {
|
||||
DEFAULT_WINDOW_BITS
|
||||
};
|
||||
let server_max_window_bits = self.server_max_window_bits.unwrap_or(DEFAULT_WINDOW_BITS);
|
||||
|
||||
let (remote_no_context_takeover, remote_max_window_bits) = if is_client_mode {
|
||||
(self.server_no_context_takeover, server_max_window_bits)
|
||||
} else {
|
||||
(self.client_no_context_takeover, client_max_window_bits)
|
||||
};
|
||||
|
||||
let (local_no_context_takeover, local_max_window_bits) = if is_client_mode {
|
||||
(self.client_no_context_takeover, client_max_window_bits)
|
||||
} else {
|
||||
(self.server_no_context_takeover, server_max_window_bits)
|
||||
};
|
||||
|
||||
(
|
||||
DeflateCompressionContext::new(
|
||||
compression_level,
|
||||
remote_no_context_takeover,
|
||||
remote_max_window_bits,
|
||||
),
|
||||
DeflateDecompressionContext::new(local_no_context_takeover, local_max_window_bits),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-side DEFLATE configuration.
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||
pub struct DeflateServerConfig {
|
||||
/// DEFLATE compression level. See [`flate2::Compression`] for details.
|
||||
pub compression_level: Option<DeflateCompressionLevel>,
|
||||
/// Disallow server from take over context. Default is false.
|
||||
pub server_no_context_takeover: bool,
|
||||
/// Disallow client from take over context. Default is false.
|
||||
pub client_no_context_takeover: bool,
|
||||
/// Maximum size of server's DEFLATE sliding window in bits, between 9 and 15. Default is 15.
|
||||
pub server_max_window_bits: Option<u8>,
|
||||
/// Maximum size of client's DEFLATE sliding window in bits, between 9 and 15. Default is 15.
|
||||
pub client_max_window_bits: Option<u8>,
|
||||
}
|
||||
|
||||
impl DeflateServerConfig {
|
||||
/// Negotiate context parameters.
|
||||
/// Since parameters from the client may be incompatible with the server configuration,
|
||||
/// actual parameters could be adjusted here. Conversion rules are as follows:
|
||||
///
|
||||
/// ## server_no_context_takeover
|
||||
///
|
||||
/// | Config | Request | Response |
|
||||
/// | ------ | ------- | --------- |
|
||||
/// | false | false | false |
|
||||
/// | false | true | true |
|
||||
/// | true | false | true |
|
||||
/// | true | true | true |
|
||||
///
|
||||
/// ## client_no_context_takeover
|
||||
///
|
||||
/// | Config | Request | Response |
|
||||
/// | ------ | ------- | --------- |
|
||||
/// | false | false | false |
|
||||
/// | false | true | true |
|
||||
/// | true | false | true |
|
||||
/// | true | true | true |
|
||||
///
|
||||
/// ## server_max_window_bits
|
||||
///
|
||||
/// | Config | Request | Response |
|
||||
/// | ------------ | ------------ | -------- |
|
||||
/// | None | None | None |
|
||||
/// | None | 9 <= R <= 15 | R |
|
||||
/// | 9 <= C <= 15 | None | C |
|
||||
/// | 9 <= C <= 15 | 9 <= R <= C | R |
|
||||
/// | 9 <= C <= 15 | C <= R <= 15 | C |
|
||||
///
|
||||
/// ## client_max_window_bits
|
||||
///
|
||||
/// | Config | Request | Response |
|
||||
/// | ------------ | ------------ | -------- |
|
||||
/// | None | None | None |
|
||||
/// | None | Unspecified | None |
|
||||
/// | None | 9 <= R <= 15 | R |
|
||||
/// | 9 <= C <= 15 | None | None |
|
||||
/// | 9 <= C <= 15 | Unspecified | C |
|
||||
/// | 9 <= C <= 15 | 9 <= R <= C | R |
|
||||
/// | 9 <= C <= 15 | C <= R <= 15 | C |
|
||||
pub fn negotiate(&self, params: DeflateSessionParameters) -> DeflateSessionParameters {
|
||||
let server_no_context_takeover =
|
||||
if self.server_no_context_takeover && !params.server_no_context_takeover {
|
||||
true
|
||||
} else {
|
||||
params.server_no_context_takeover
|
||||
};
|
||||
|
||||
let client_no_context_takeover =
|
||||
if self.client_no_context_takeover && !params.client_no_context_takeover {
|
||||
true
|
||||
} else {
|
||||
params.client_no_context_takeover
|
||||
};
|
||||
|
||||
let server_max_window_bits =
|
||||
match (self.server_max_window_bits, params.server_max_window_bits) {
|
||||
(None, value) => value,
|
||||
(Some(config_value), None) => Some(config_value),
|
||||
(Some(config_value), Some(value)) => {
|
||||
if value > config_value {
|
||||
Some(config_value)
|
||||
} else {
|
||||
Some(value)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client_max_window_bits =
|
||||
match (self.client_max_window_bits, params.client_max_window_bits) {
|
||||
(None, None | Some(ClientMaxWindowBits::NotSpecified)) => None,
|
||||
(None, Some(ClientMaxWindowBits::Specified(value))) => Some(value),
|
||||
(Some(_), None) => None,
|
||||
(Some(config_value), Some(ClientMaxWindowBits::NotSpecified)) => Some(config_value),
|
||||
(Some(config_value), Some(ClientMaxWindowBits::Specified(value))) => {
|
||||
if value > config_value {
|
||||
Some(config_value)
|
||||
} else {
|
||||
Some(value)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
DeflateSessionParameters {
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits: client_max_window_bits.map(ClientMaxWindowBits::Specified),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DEFLATE decompression context.
|
||||
#[derive(Debug)]
|
||||
pub struct DeflateDecompressionContext {
|
||||
pub(super) local_no_context_takeover: bool,
|
||||
pub(super) local_max_window_bits: u8,
|
||||
|
||||
decompress: flate2::Decompress,
|
||||
|
||||
decode_continuation: bool,
|
||||
total_bytes_written: u64,
|
||||
total_bytes_read: u64,
|
||||
}
|
||||
|
||||
impl DeflateDecompressionContext {
|
||||
pub(super) fn new(local_no_context_takeover: bool, local_max_window_bits: u8) -> Self {
|
||||
Self {
|
||||
local_no_context_takeover,
|
||||
local_max_window_bits,
|
||||
|
||||
decompress: flate2::Decompress::new_with_window_bits(false, local_max_window_bits),
|
||||
|
||||
decode_continuation: false,
|
||||
total_bytes_written: 0,
|
||||
total_bytes_read: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn reset_with(
|
||||
&mut self,
|
||||
local_no_context_takeover: bool,
|
||||
local_max_window_bits: u8,
|
||||
) {
|
||||
*self = Self::new(local_no_context_takeover, local_max_window_bits);
|
||||
}
|
||||
|
||||
pub(super) fn decompress(
|
||||
&mut self,
|
||||
fin: bool,
|
||||
opcode: OpCode,
|
||||
rsv: RsvBits,
|
||||
payload: Bytes,
|
||||
) -> Result<Bytes, ProtocolError> {
|
||||
if !matches!(opcode, OpCode::Text | OpCode::Binary | OpCode::Continue)
|
||||
|| !rsv.contains(RSV_BIT_DEFLATE_FLAG)
|
||||
{
|
||||
return Ok(payload);
|
||||
}
|
||||
|
||||
if opcode == OpCode::Continue {
|
||||
if !self.decode_continuation {
|
||||
return Ok(payload);
|
||||
}
|
||||
} else {
|
||||
self.decode_continuation = true;
|
||||
}
|
||||
|
||||
let mut output: Vec<u8> = vec![];
|
||||
let mut buf = [0u8; BUF_SIZE];
|
||||
|
||||
let mut offset: usize = 0;
|
||||
loop {
|
||||
let res = if offset >= payload.len() {
|
||||
self.decompress
|
||||
.decompress(
|
||||
&[0x00, 0x00, 0xff, 0xff],
|
||||
&mut buf,
|
||||
flate2::FlushDecompress::Finish,
|
||||
)
|
||||
.map_err(|err| {
|
||||
self.reset();
|
||||
ProtocolError::Io(err.into())
|
||||
})?
|
||||
} else {
|
||||
self.decompress
|
||||
.decompress(&payload[offset..], &mut buf, flate2::FlushDecompress::None)
|
||||
.map_err(|err| {
|
||||
self.reset();
|
||||
ProtocolError::Io(err.into())
|
||||
})?
|
||||
};
|
||||
|
||||
let read = self.decompress.total_in() - self.total_bytes_read;
|
||||
let written = self.decompress.total_out() - self.total_bytes_written;
|
||||
|
||||
offset += read as usize;
|
||||
self.total_bytes_read += read;
|
||||
if written > 0 {
|
||||
output.extend(buf.iter().take(written as usize));
|
||||
self.total_bytes_written += written;
|
||||
}
|
||||
|
||||
if res != flate2::Status::Ok {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if fin {
|
||||
self.decode_continuation = false;
|
||||
if self.local_no_context_takeover {
|
||||
self.reset();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output.into())
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.decompress.reset(false);
|
||||
self.total_bytes_read = 0;
|
||||
self.total_bytes_written = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// DEFLATE compression context.
|
||||
#[derive(Debug)]
|
||||
pub struct DeflateCompressionContext {
|
||||
pub(super) compression_level: flate2::Compression,
|
||||
pub(super) remote_no_context_takeover: bool,
|
||||
pub(super) remote_max_window_bits: u8,
|
||||
|
||||
compress: flate2::Compress,
|
||||
total_bytes_written: u64,
|
||||
total_bytes_read: u64,
|
||||
}
|
||||
|
||||
impl DeflateCompressionContext {
|
||||
pub(super) fn new(
|
||||
compression_level: Option<flate2::Compression>,
|
||||
remote_no_context_takeover: bool,
|
||||
remote_max_window_bits: u8,
|
||||
) -> Self {
|
||||
let compression_level = compression_level.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
compression_level,
|
||||
remote_no_context_takeover,
|
||||
remote_max_window_bits,
|
||||
|
||||
compress: flate2::Compress::new_with_window_bits(
|
||||
compression_level,
|
||||
false,
|
||||
remote_max_window_bits,
|
||||
),
|
||||
|
||||
total_bytes_written: 0,
|
||||
total_bytes_read: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn reset_with(
|
||||
mut self,
|
||||
remote_no_context_takeover: bool,
|
||||
remote_max_window_bits: u8,
|
||||
) -> Self {
|
||||
self = Self::new(
|
||||
Some(self.compression_level),
|
||||
remote_no_context_takeover,
|
||||
remote_max_window_bits,
|
||||
);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub(super) fn compress(&mut self, fin: bool, payload: Bytes) -> Result<Bytes, ProtocolError> {
|
||||
let mut output = vec![];
|
||||
let mut buf = [0u8; BUF_SIZE];
|
||||
|
||||
loop {
|
||||
let bytes_in = self.compress.total_in() - self.total_bytes_read;
|
||||
let res = if bytes_in >= payload.len() as u64 {
|
||||
self.compress
|
||||
.compress(&[], &mut buf, flate2::FlushCompress::Sync)
|
||||
.map_err(|err| {
|
||||
self.reset();
|
||||
ProtocolError::Io(err.into())
|
||||
})?
|
||||
} else {
|
||||
self.compress
|
||||
.compress(
|
||||
&payload[bytes_in as usize..],
|
||||
&mut buf,
|
||||
flate2::FlushCompress::None,
|
||||
)
|
||||
.map_err(|err| {
|
||||
self.reset();
|
||||
ProtocolError::Io(err.into())
|
||||
})?
|
||||
};
|
||||
|
||||
let written = self.compress.total_out() - self.total_bytes_written;
|
||||
if written > 0 {
|
||||
output.extend(buf.iter().take(written as usize));
|
||||
self.total_bytes_written += written;
|
||||
}
|
||||
|
||||
if res != flate2::Status::Ok {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.total_bytes_read = self.compress.total_in();
|
||||
|
||||
if output.iter().rev().take(4).eq(&[0xff, 0xff, 0x00, 0x00]) {
|
||||
output.drain(output.len() - 4..);
|
||||
}
|
||||
|
||||
if fin && self.remote_no_context_takeover {
|
||||
self.reset();
|
||||
}
|
||||
|
||||
Ok(output.into())
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.compress.reset();
|
||||
self.total_bytes_read = 0;
|
||||
self.total_bytes_written = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::body::MessageBody;
|
||||
|
||||
#[test]
|
||||
fn test_session_parameters() {
|
||||
let extension = "abc, def, permessage-deflate";
|
||||
assert_eq!(
|
||||
DeflateSessionParameters::from_extension_header(extension),
|
||||
vec![Ok(DeflateSessionParameters::default())]
|
||||
);
|
||||
|
||||
let extension = "permessage-deflate; unknown_parameter";
|
||||
assert_eq!(
|
||||
DeflateSessionParameters::from_extension_header(extension),
|
||||
vec![Err(DeflateHandshakeError::UnknownWebSocketParameters)]
|
||||
);
|
||||
|
||||
let extension = "permessage-deflate; client_max_window_bits=9; client_max_window_bits=10";
|
||||
assert_eq!(
|
||||
DeflateSessionParameters::from_extension_header(extension),
|
||||
vec![Err(DeflateHandshakeError::DuplicateParameter(
|
||||
"client_max_window_bits"
|
||||
))]
|
||||
);
|
||||
|
||||
let extension = "permessage-deflate; server_max_window_bits=8";
|
||||
assert_eq!(
|
||||
DeflateSessionParameters::from_extension_header(extension),
|
||||
vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)]
|
||||
);
|
||||
|
||||
let extension = "permessage-deflate; server_max_window_bits=16";
|
||||
assert_eq!(
|
||||
DeflateSessionParameters::from_extension_header(extension),
|
||||
vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)]
|
||||
);
|
||||
|
||||
let extension = "permessage-deflate; client_max_window_bits; server_max_window_bits=15; \
|
||||
client_no_context_takeover; server_no_context_takeover, \
|
||||
permessage-deflate; client_max_window_bits=10";
|
||||
assert_eq!(
|
||||
DeflateSessionParameters::from_extension_header(extension),
|
||||
vec![
|
||||
Ok(DeflateSessionParameters {
|
||||
server_no_context_takeover: true,
|
||||
client_no_context_takeover: true,
|
||||
server_max_window_bits: Some(15),
|
||||
client_max_window_bits: Some(ClientMaxWindowBits::NotSpecified)
|
||||
}),
|
||||
Ok(DeflateSessionParameters {
|
||||
server_no_context_takeover: false,
|
||||
client_no_context_takeover: false,
|
||||
server_max_window_bits: None,
|
||||
client_max_window_bits: Some(ClientMaxWindowBits::Specified(10))
|
||||
})
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compress() {
|
||||
// With context takeover
|
||||
|
||||
let mut compress = DeflateCompressionContext::new(None, false, 15);
|
||||
assert_eq!(
|
||||
compress
|
||||
.compress(true, "Hello World".try_into_bytes().unwrap())
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
);
|
||||
assert_eq!(
|
||||
compress
|
||||
.compress(true, "Hello World".try_into_bytes().unwrap())
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"\xf2@0\x01\0")
|
||||
);
|
||||
|
||||
// Without context takeover
|
||||
|
||||
let mut compress = DeflateCompressionContext::new(None, true, 15);
|
||||
assert_eq!(
|
||||
compress
|
||||
.compress(true, "Hello World".try_into_bytes().unwrap())
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
);
|
||||
assert_eq!(
|
||||
compress
|
||||
.compress(true, "Hello World".try_into_bytes().unwrap())
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
);
|
||||
|
||||
// With continuation
|
||||
assert_eq!(
|
||||
compress
|
||||
.compress(false, "Hello World".try_into_bytes().unwrap())
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
);
|
||||
// Continuation keeps context.
|
||||
assert_eq!(
|
||||
compress
|
||||
.compress(true, "Hello World".try_into_bytes().unwrap())
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"\xf2@0\x01\0")
|
||||
);
|
||||
// after continuation, context resets
|
||||
assert_eq!(
|
||||
compress
|
||||
.compress(true, "Hello World".try_into_bytes().unwrap())
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompress() {
|
||||
// With context takeover
|
||||
|
||||
let mut decompress = DeflateDecompressionContext::new(false, 15);
|
||||
|
||||
// Without RSV1 bit, decompression does not happen.
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::empty(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
|
||||
// Control frames (such as ping/pong) are not decompressed
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Ping,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"Hello World")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
|
||||
// Successful decompression
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
|
||||
// Success subsequent decompression
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2@0\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
|
||||
// Invalid compression payload
|
||||
assert!(decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"Hello World")
|
||||
)
|
||||
.is_err());
|
||||
|
||||
// When there was error, context is reset.
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
|
||||
// Without context takeover
|
||||
|
||||
let mut decompress = DeflateDecompressionContext::new(true, 15);
|
||||
|
||||
// Successful decompression
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
|
||||
// Context has been reset.
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
|
||||
// With continuation
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
false,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
// Continuation keeps context.
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
true,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2@0\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
// When continuation has finished, context is reset.
|
||||
assert_eq!(
|
||||
decompress
|
||||
.decompress(
|
||||
false,
|
||||
OpCode::Text,
|
||||
RsvBits::RSV1,
|
||||
Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0")
|
||||
)
|
||||
.unwrap(),
|
||||
Bytes::from_static(b"Hello World")
|
||||
);
|
||||
}
|
||||
}
|
|
@ -5,7 +5,7 @@ use tracing::debug;
|
|||
|
||||
use super::{
|
||||
mask::apply_mask,
|
||||
proto::{CloseCode, CloseReason, OpCode},
|
||||
proto::{CloseCode, CloseReason, OpCode, RsvBits},
|
||||
ProtocolError,
|
||||
};
|
||||
|
||||
|
@ -17,7 +17,7 @@ impl Parser {
|
|||
fn parse_metadata(
|
||||
src: &[u8],
|
||||
server: bool,
|
||||
) -> Result<Option<(usize, bool, OpCode, usize, Option<[u8; 4]>)>, ProtocolError> {
|
||||
) -> Result<Option<(usize, bool, OpCode, RsvBits, usize, Option<[u8; 4]>)>, ProtocolError> {
|
||||
let chunk_len = src.len();
|
||||
|
||||
let mut idx = 2;
|
||||
|
@ -37,6 +37,9 @@ impl Parser {
|
|||
return Err(ProtocolError::MaskedFrame);
|
||||
}
|
||||
|
||||
// RSV bits
|
||||
let rsv_bits = RsvBits::from_bits((first & 0x70) >> 4).unwrap_or(RsvBits::empty());
|
||||
|
||||
// Op code
|
||||
let opcode = OpCode::from(first & 0x0F);
|
||||
|
||||
|
@ -79,7 +82,7 @@ impl Parser {
|
|||
None
|
||||
};
|
||||
|
||||
Ok(Some((idx, finished, opcode, length, mask)))
|
||||
Ok(Some((idx, finished, opcode, rsv_bits, length, mask)))
|
||||
}
|
||||
|
||||
/// Parse the input stream into a frame.
|
||||
|
@ -87,12 +90,13 @@ impl Parser {
|
|||
src: &mut BytesMut,
|
||||
server: bool,
|
||||
max_size: usize,
|
||||
) -> Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError> {
|
||||
) -> Result<Option<(bool, OpCode, RsvBits, Option<BytesMut>)>, ProtocolError> {
|
||||
// try to parse ws frame metadata
|
||||
let (idx, finished, opcode, length, mask) = match Parser::parse_metadata(src, server)? {
|
||||
None => return Ok(None),
|
||||
Some(res) => res,
|
||||
};
|
||||
let (idx, finished, opcode, rsv_bits, length, mask) =
|
||||
match Parser::parse_metadata(src, server)? {
|
||||
None => return Ok(None),
|
||||
Some(res) => res,
|
||||
};
|
||||
|
||||
// not enough data
|
||||
if src.len() < idx + length {
|
||||
|
@ -115,7 +119,7 @@ impl Parser {
|
|||
|
||||
// no need for body
|
||||
if length == 0 {
|
||||
return Ok(Some((finished, opcode, None)));
|
||||
return Ok(Some((finished, opcode, rsv_bits, None)));
|
||||
}
|
||||
|
||||
let mut data = src.split_to(length);
|
||||
|
@ -127,7 +131,7 @@ impl Parser {
|
|||
}
|
||||
OpCode::Close if length > 125 => {
|
||||
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
|
||||
return Ok(Some((true, OpCode::Close, None)));
|
||||
return Ok(Some((true, OpCode::Close, rsv_bits, None)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -137,7 +141,7 @@ impl Parser {
|
|||
apply_mask(&mut data, mask);
|
||||
}
|
||||
|
||||
Ok(Some((finished, opcode, Some(data))))
|
||||
Ok(Some((finished, opcode, rsv_bits, Some(data))))
|
||||
}
|
||||
|
||||
/// Parse the payload of a close frame.
|
||||
|
@ -161,15 +165,15 @@ impl Parser {
|
|||
dst: &mut BytesMut,
|
||||
pl: B,
|
||||
op: OpCode,
|
||||
rsv_bits: RsvBits,
|
||||
fin: bool,
|
||||
mask: bool,
|
||||
) {
|
||||
let payload = pl.as_ref();
|
||||
let one: u8 = if fin {
|
||||
0x80 | Into::<u8>::into(op)
|
||||
} else {
|
||||
op.into()
|
||||
};
|
||||
let fin_bits = if fin { 0x80 } else { 0x00 };
|
||||
let rsv_bits = rsv_bits.bits() << 4;
|
||||
|
||||
let one: u8 = fin_bits | rsv_bits | Into::<u8>::into(op);
|
||||
let payload_len = payload.len();
|
||||
let (two, p_len) = if mask {
|
||||
(0x80, payload_len + 4)
|
||||
|
@ -203,7 +207,12 @@ impl Parser {
|
|||
|
||||
/// Create a new Close control frame.
|
||||
#[inline]
|
||||
pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
|
||||
pub fn write_close(
|
||||
dst: &mut BytesMut,
|
||||
reason: Option<CloseReason>,
|
||||
rsv_bits: RsvBits,
|
||||
mask: bool,
|
||||
) {
|
||||
let payload = match reason {
|
||||
None => Vec::new(),
|
||||
Some(reason) => {
|
||||
|
@ -215,7 +224,7 @@ impl Parser {
|
|||
}
|
||||
};
|
||||
|
||||
Parser::write_message(dst, payload, OpCode::Close, true, mask)
|
||||
Parser::write_message(dst, payload, OpCode::Close, rsv_bits, true, mask)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -228,18 +237,22 @@ mod tests {
|
|||
struct F {
|
||||
finished: bool,
|
||||
opcode: OpCode,
|
||||
rsv_bits: RsvBits,
|
||||
payload: Bytes,
|
||||
}
|
||||
|
||||
fn is_none(frm: &Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>) -> bool {
|
||||
fn is_none(
|
||||
frm: &Result<Option<(bool, OpCode, RsvBits, Option<BytesMut>)>, ProtocolError>,
|
||||
) -> bool {
|
||||
matches!(*frm, Ok(None))
|
||||
}
|
||||
|
||||
fn extract(frm: Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>) -> F {
|
||||
fn extract(frm: Result<Option<(bool, OpCode, RsvBits, Option<BytesMut>)>, ProtocolError>) -> F {
|
||||
match frm {
|
||||
Ok(Some((finished, opcode, payload))) => F {
|
||||
Ok(Some((finished, opcode, rsv_bits, payload))) => F {
|
||||
finished,
|
||||
opcode,
|
||||
rsv_bits,
|
||||
payload: payload
|
||||
.map(|b| b.freeze())
|
||||
.unwrap_or_else(|| Bytes::from("")),
|
||||
|
@ -260,6 +273,17 @@ mod tests {
|
|||
assert!(!frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert_eq!(frame.payload.as_ref(), &b"1"[..]);
|
||||
|
||||
let mut buf = BytesMut::from(&[0b1111_0001u8, 0b0000_0001u8][..]);
|
||||
buf.extend(b"2");
|
||||
|
||||
let frame = extract(Parser::parse(&mut buf, false, 1024));
|
||||
assert!(frame.finished);
|
||||
assert_eq!(frame.opcode, OpCode::Text);
|
||||
assert_eq!(frame.payload.as_ref(), &b"2"[..]);
|
||||
assert!(frame.rsv_bits.contains(RsvBits::RSV1));
|
||||
assert!(frame.rsv_bits.contains(RsvBits::RSV2));
|
||||
assert!(frame.rsv_bits.contains(RsvBits::RSV3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -368,7 +392,14 @@ mod tests {
|
|||
#[test]
|
||||
fn test_ping_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);
|
||||
Parser::write_message(
|
||||
&mut buf,
|
||||
Vec::from("data"),
|
||||
OpCode::Ping,
|
||||
RsvBits::empty(),
|
||||
true,
|
||||
false,
|
||||
);
|
||||
|
||||
let mut v = vec![137u8, 4u8];
|
||||
v.extend(b"data");
|
||||
|
@ -378,7 +409,14 @@ mod tests {
|
|||
#[test]
|
||||
fn test_pong_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);
|
||||
Parser::write_message(
|
||||
&mut buf,
|
||||
Vec::from("data"),
|
||||
OpCode::Pong,
|
||||
RsvBits::empty(),
|
||||
true,
|
||||
false,
|
||||
);
|
||||
|
||||
let mut v = vec![138u8, 4u8];
|
||||
v.extend(b"data");
|
||||
|
@ -389,7 +427,7 @@ mod tests {
|
|||
fn test_close_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
let reason = (CloseCode::Normal, "data");
|
||||
Parser::write_close(&mut buf, Some(reason.into()), false);
|
||||
Parser::write_close(&mut buf, Some(reason.into()), RsvBits::empty(), false);
|
||||
|
||||
let mut v = vec![136u8, 6u8, 3u8, 232u8];
|
||||
v.extend(b"data");
|
||||
|
@ -399,7 +437,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_empty_close_frame() {
|
||||
let mut buf = BytesMut::new();
|
||||
Parser::write_close(&mut buf, None, false);
|
||||
Parser::write_close(&mut buf, None, RsvBits::empty(), false);
|
||||
assert_eq!(&buf[..], &vec![0x88, 0x00][..]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,16 +11,20 @@ use http::{header, Method, StatusCode};
|
|||
use crate::{body::BoxBody, header::HeaderValue, RequestHead, Response, ResponseBuilder};
|
||||
|
||||
mod codec;
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
mod deflate;
|
||||
mod dispatcher;
|
||||
mod frame;
|
||||
mod mask;
|
||||
mod proto;
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
pub use self::deflate::{DeflateCompressionLevel, DeflateServerConfig, DeflateSessionParameters};
|
||||
pub use self::{
|
||||
codec::{Codec, Frame, Item, Message},
|
||||
codec::{Codec, Decoder, Encoder, Frame, Item, Message},
|
||||
dispatcher::Dispatcher,
|
||||
frame::Parser,
|
||||
proto::{hash_key, CloseCode, CloseReason, OpCode},
|
||||
proto::{hash_key, CloseCode, CloseReason, OpCode, RsvBits},
|
||||
};
|
||||
|
||||
/// WebSocket protocol errors.
|
||||
|
@ -93,6 +97,11 @@ pub enum HandshakeError {
|
|||
/// WebSocket key is not set or wrong.
|
||||
#[display("unknown WebSocket key")]
|
||||
BadWebsocketKey,
|
||||
|
||||
/// Invalid `permessage-deflate` request.
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
#[display("invalid WebSocket `permessage-deflate` extension request")]
|
||||
BadDeflateRequest(deflate::DeflateHandshakeError),
|
||||
}
|
||||
|
||||
impl From<HandshakeError> for Response<BoxBody> {
|
||||
|
@ -135,6 +144,13 @@ impl From<HandshakeError> for Response<BoxBody> {
|
|||
res.head_mut().reason = Some("Handshake error");
|
||||
res
|
||||
}
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
HandshakeError::BadDeflateRequest(_) => {
|
||||
let mut res = Response::bad_request();
|
||||
res.head_mut().reason = Some("Invalid permessage-deflate request");
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -151,6 +167,69 @@ pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
|
|||
Ok(handshake_response(req))
|
||||
}
|
||||
|
||||
/// Verify WebSocket handshake request with DEFLATE compression configurations.
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
pub fn handshake_deflate(
|
||||
config: &deflate::DeflateServerConfig,
|
||||
req: &RequestHead,
|
||||
) -> Result<
|
||||
(
|
||||
ResponseBuilder,
|
||||
Option<(
|
||||
deflate::DeflateCompressionContext,
|
||||
deflate::DeflateDecompressionContext,
|
||||
)>,
|
||||
),
|
||||
HandshakeError,
|
||||
> {
|
||||
verify_handshake(req)?;
|
||||
|
||||
let mut available_configurations = vec![];
|
||||
for header in req.headers().get_all(header::SEC_WEBSOCKET_EXTENSIONS) {
|
||||
let Ok(header_str) = header.to_str() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
available_configurations.extend(deflate::DeflateSessionParameters::from_extension_header(
|
||||
header_str,
|
||||
));
|
||||
}
|
||||
|
||||
let mut selected_config = None;
|
||||
let mut selected_error = None;
|
||||
for config in available_configurations {
|
||||
match config {
|
||||
Ok(config) => {
|
||||
selected_config = Some(config);
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
if selected_error.is_none() {
|
||||
selected_error = Some(err);
|
||||
} else {
|
||||
selected_error =
|
||||
Some(deflate::DeflateHandshakeError::NoSuitableConfigurationFound);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(selected_error) = selected_error {
|
||||
Err(HandshakeError::BadDeflateRequest(selected_error))
|
||||
} else {
|
||||
let mut response = handshake_response(req);
|
||||
|
||||
if let Some(selected_config) = selected_config {
|
||||
let param = config.negotiate(selected_config);
|
||||
let contexts = param.create_context(config.compression_level, false);
|
||||
response.insert_header(param);
|
||||
Ok((response, Some(contexts)))
|
||||
} else {
|
||||
Ok((response, None))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify WebSocket handshake request.
|
||||
pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
|
||||
// WebSocket accepts only GET
|
||||
|
@ -196,6 +275,7 @@ pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
|
|||
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
|
||||
return Err(HandshakeError::BadWebsocketKey);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -222,6 +222,25 @@ impl<T: Into<String>> From<(CloseCode, T)> for CloseReason {
|
|||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// RSV bits defined in [RFC 6455 §5.2].
|
||||
/// Reserved for extensions and should be set to zero if no extensions are applicable.
|
||||
///
|
||||
/// [RFC 6455 §5.2]: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
|
||||
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
|
||||
pub struct RsvBits: u8 {
|
||||
const RSV1 = 0b0000_0100;
|
||||
const RSV2 = 0b0000_0010;
|
||||
const RSV3 = 0b0000_0001;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RsvBits {
|
||||
fn default() -> Self {
|
||||
Self::empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// The WebSocket GUID as stated in the spec.
|
||||
/// See <https://datatracker.ietf.org/doc/html/rfc6455#section-1.3>.
|
||||
static WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
|
|
|
@ -79,6 +79,8 @@ compress-brotli = ["actix-http/compress-brotli", "__compress"]
|
|||
compress-gzip = ["actix-http/compress-gzip", "__compress"]
|
||||
# Zstd algorithm content-encoding support
|
||||
compress-zstd = ["actix-http/compress-zstd", "__compress"]
|
||||
# Deflate compression for WebSocket
|
||||
compress-ws-deflate = ["actix-http/compress-ws-deflate"]
|
||||
|
||||
# Cookie parsing and cookie jar
|
||||
cookies = ["dep:cookie"]
|
||||
|
|
|
@ -30,6 +30,8 @@ use std::{fmt, net::SocketAddr, str};
|
|||
|
||||
use actix_codec::Framed;
|
||||
pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
pub use actix_http::ws::{DeflateCompressionLevel, DeflateSessionParameters};
|
||||
use actix_http::{ws, Payload, RequestHead};
|
||||
use actix_rt::time::timeout;
|
||||
use actix_service::Service as _;
|
||||
|
@ -59,6 +61,9 @@ pub struct WebsocketsRequest {
|
|||
server_mode: bool,
|
||||
config: ClientConfig,
|
||||
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_compression_level: Option<DeflateCompressionLevel>,
|
||||
|
||||
#[cfg(feature = "cookies")]
|
||||
cookies: Option<CookieJar>,
|
||||
}
|
||||
|
@ -94,6 +99,8 @@ impl WebsocketsRequest {
|
|||
protocols: None,
|
||||
max_size: 65_536,
|
||||
server_mode: false,
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
deflate_compression_level: None,
|
||||
#[cfg(feature = "cookies")]
|
||||
cookies: None,
|
||||
}
|
||||
|
@ -249,6 +256,22 @@ impl WebsocketsRequest {
|
|||
self.header(AUTHORIZATION, format!("Bearer {}", token))
|
||||
}
|
||||
|
||||
/// Enable DEFLATE compression
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
pub fn deflate(
|
||||
mut self,
|
||||
compression_level: Option<DeflateCompressionLevel>,
|
||||
params: DeflateSessionParameters,
|
||||
) -> Self {
|
||||
use actix_http::header::TryIntoHeaderPair;
|
||||
// Assume session parameters are always valid.
|
||||
let (key, value) = params.try_into_pair().unwrap();
|
||||
|
||||
self.deflate_compression_level = compression_level;
|
||||
|
||||
self.header(key, value)
|
||||
}
|
||||
|
||||
/// Complete request construction and connect to a WebSocket server.
|
||||
pub async fn connect(
|
||||
mut self,
|
||||
|
@ -409,17 +432,52 @@ impl WebsocketsRequest {
|
|||
return Err(WsClientError::MissingWebSocketAcceptHeader);
|
||||
};
|
||||
|
||||
// response and ws framed
|
||||
Ok((
|
||||
ClientResponse::new(head, Payload::None),
|
||||
framed.into_map_codec(|_| {
|
||||
if server_mode {
|
||||
ws::Codec::new().max_size(max_size)
|
||||
#[cfg(feature = "compress-ws-deflate")]
|
||||
let framed = {
|
||||
let selected_parameter = head
|
||||
.headers
|
||||
.get_all(header::SEC_WEBSOCKET_EXTENSIONS)
|
||||
.filter_map(|header| {
|
||||
if let Ok(header_str) = header.to_str() {
|
||||
Some(DeflateSessionParameters::from_extension_header(header_str))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.filter_map(Result::ok)
|
||||
.next();
|
||||
|
||||
framed.into_map_codec(move |_| {
|
||||
let codec = if let Some(parameter) = selected_parameter.clone() {
|
||||
let (compress, decompress) =
|
||||
parameter.create_context(self.deflate_compression_level, false);
|
||||
Codec::new_deflate(compress, decompress)
|
||||
} else {
|
||||
ws::Codec::new().max_size(max_size).client_mode()
|
||||
Codec::new()
|
||||
}
|
||||
}),
|
||||
))
|
||||
.max_size(max_size);
|
||||
|
||||
if server_mode {
|
||||
codec
|
||||
} else {
|
||||
codec.client_mode()
|
||||
}
|
||||
})
|
||||
};
|
||||
#[cfg(not(feature = "compress-ws-deflate"))]
|
||||
let framed = framed.into_map_codec(move |_| {
|
||||
let codec = Codec::new().max_size(max_size);
|
||||
|
||||
if server_mode {
|
||||
codec
|
||||
} else {
|
||||
codec.client_mode()
|
||||
}
|
||||
});
|
||||
|
||||
// response and ws framed
|
||||
Ok((ClientResponse::new(head, Payload::None), framed))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue