diff --git a/.cspell.yml b/.cspell.yml index 3aa9eea1..26044022 100644 --- a/.cspell.yml +++ b/.cspell.yml @@ -2,9 +2,13 @@ version: "0.2" words: - actix - addrs + - ALPN + - arrayvec + - bitflags - clippy - deque - itertools + - itoa - mptcp - MSRV - nonblocking @@ -13,6 +17,7 @@ words: - rcgen - Rustls - rustup + - smallvec - spki - uring - webpki diff --git a/Cargo.lock b/Cargo.lock index 2b79b44a..9de63203 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,33 @@ dependencies = [ "trybuild", ] +[[package]] +name = "actix-proxy-protocol" +version = "0.0.1" +dependencies = [ + "actix-codec", + "actix-rt", + "actix-server", + "actix-service", + "actix-utils", + "arrayvec", + "bitflags 2.9.3", + "bytes", + "const-str", + "crc32fast", + "futures-core", + "futures-util", + "hex", + "itoa", + "nom 8.0.0", + "once_cell", + "pretty_assertions", + "smallvec", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "actix-rt" version = "2.10.0" @@ -187,6 +214,12 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "async-stream" version = "0.3.6" @@ -357,7 +390,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" dependencies = [ - "nom", + "nom 7.1.3", ] [[package]] @@ -438,6 +471,12 @@ dependencies = [ "cc", ] +[[package]] +name = "const-str" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3618cccc083bb987a415d85c02ca6c9994ea5b44731ec28b9ecf09658655fba9" + [[package]] name = "core-foundation" version = "0.9.4" @@ -454,6 +493,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -536,6 +584,12 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "displaydoc" version = "0.2.5" @@ -764,6 +818,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "home" version = "0.5.11" @@ -1198,6 +1258,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1423,6 +1492,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "pretty_env_logger" version = "0.5.0" @@ -2963,6 +3042,12 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index a1de08cb..0e2081b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "actix-codec", "actix-macros", + "actix-proxy-protocol", "actix-rt", "actix-server", "actix-service", @@ -22,6 +23,7 @@ rust-version = "1.75" [patch.crates-io] actix-codec = { path = "actix-codec" } actix-macros = { path = "actix-macros" } +actix-proxy-protocol = { path = "actix-proxy-protocol" } actix-rt = { path = "actix-rt" } actix-server = { path = "actix-server" } actix-service = { path = "actix-service" } diff --git a/actix-proxy-protocol/CHANGES.md b/actix-proxy-protocol/CHANGES.md new file mode 100644 index 00000000..566bebdc --- /dev/null +++ b/actix-proxy-protocol/CHANGES.md @@ -0,0 +1,7 @@ +# Changes + +## Unreleased - 2022-xx-xx + +## 0.0.1 - 2022-xx-xx + +- delete me diff --git a/actix-proxy-protocol/Cargo.toml b/actix-proxy-protocol/Cargo.toml new file mode 100755 index 00000000..c64f031c --- /dev/null +++ b/actix-proxy-protocol/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "actix-proxy-protocol" +version = "0.0.1" +authors = ["Rob Ede "] +description = "PROXY protocol utilities" +keywords = ["proxy", "protocol", "network", "haproxy", "tcp"] +categories = ["network-programming", "asynchronous"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-net" +license.workspace = true +edition.workspace = true +rust-version.workspace = true + +[dependencies] +actix-service = "2" +actix-utils = "3" + +arrayvec = "0.7" +bitflags = "2" +crc32fast = "1" +futures-core = { version = "0.3.17", default-features = false, features = ["std"] } +futures-util = { version = "0.3.17", default-features = false, features = ["std"] } +itoa = "1" +nom = "8" +smallvec = "1" +tokio = { version = "1.13.1", features = ["sync", "io-util"] } +tracing = { version = "0.1.30", default-features = false, features = ["log"] } + +[dev-dependencies] +actix-codec = "0.5" +actix-rt = "2.6" +actix-server = "2" +bytes = "1" +const-str = "0.5" +futures-util = { version = "0.3.7", default-features = false, features = ["sink", "async-await-macro"] } +hex = "0.4" +once_cell = "1" +pretty_assertions = "1" +tokio = { version = "1.13.1", features = ["io-util", "rt-multi-thread", "macros", "fs"] } +tracing-subscriber = "0.3" diff --git a/actix-proxy-protocol/LICENSE-APACHE b/actix-proxy-protocol/LICENSE-APACHE new file mode 120000 index 00000000..965b606f --- /dev/null +++ b/actix-proxy-protocol/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-proxy-protocol/LICENSE-MIT b/actix-proxy-protocol/LICENSE-MIT new file mode 120000 index 00000000..76219eb7 --- /dev/null +++ b/actix-proxy-protocol/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-proxy-protocol/README.md b/actix-proxy-protocol/README.md new file mode 100644 index 00000000..fed0fb06 --- /dev/null +++ b/actix-proxy-protocol/README.md @@ -0,0 +1,17 @@ +# actix-proxy-protocol + +> Implementation of the [PROXY protocol]. + +[![crates.io](https://img.shields.io/crates/v/actix-proxy-protocol?label=latest)](https://crates.io/crates/actix-proxy-protocol) +[![Documentation](https://docs.rs/actix-proxy-protocol/badge.svg?version=0.1.0)](https://docs.rs/actix-proxy-protocol/0.1.0) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) +![License](https://img.shields.io/crates/l/actix-proxy-protocol.svg) +[![Dependency Status](https://deps.rs/crate/actix-proxy-protocol/0.1.0/status.svg)](https://deps.rs/crate/actix-proxy-protocol/0.1.0) +![Downloads](https://img.shields.io/crates/d/actix-proxy-protocol.svg) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) + +## Resources + +- [Examples](./examples) + +[proxy protocol]: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt diff --git a/actix-proxy-protocol/examples/proxy-server.rs b/actix-proxy-protocol/examples/proxy-server.rs new file mode 100644 index 00000000..e26c518f --- /dev/null +++ b/actix-proxy-protocol/examples/proxy-server.rs @@ -0,0 +1,173 @@ +//! Adds PROXY protocol v1 prelude to connections. + +#![allow(unused)] + +use std::{ + io, mem, + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use actix_proxy_protocol::{tlv, v1, v2, AddressFamily, Command, TransportProtocol}; +use actix_rt::net::TcpStream; +use actix_server::Server; +use actix_service::{fn_service, ServiceFactoryExt as _}; +use arrayvec::ArrayVec; +use bytes::BytesMut; +use const_str::concat_bytes; +use once_cell::sync::Lazy; +use tokio::io::{ + copy_bidirectional, AsyncBufReadExt as _, AsyncReadExt as _, AsyncWriteExt as _, BufReader, +}; + +static UPSTREAM: Lazy = Lazy::new(|| SocketAddr::from(([127, 0, 0, 1], 8080))); + +/* +NOTES: +108 byte buffer on receiver side is enough for any PROXY header +after PROXY, receive until CRLF, *then* decode parts +TLV = type-length-value + +TO DO: +handle UNKNOWN transport +v2 UNSPEC mode +AF_UNIX socket +*/ + +fn extend_with_ip_bytes(buf: &mut Vec, ip: IpAddr) { + match ip { + IpAddr::V4(ip) => buf.extend_from_slice(&ip.octets()), + IpAddr::V6(ip) => buf.extend_from_slice(&ip.octets()), + } +} + +async fn wrap_with_proxy_protocol_v1(mut stream: TcpStream) -> io::Result<()> { + let mut upstream = TcpStream::connect(("127.0.0.1", 8080)).await?; + + tracing::info!( + "PROXYv1 {} -> {}", + stream.peer_addr().unwrap(), + UPSTREAM.to_string(), + ); + + let proxy_header = v1::Header::new( + AddressFamily::Inet, + SocketAddr::from(([127, 0, 0, 1], 8081)), + *UPSTREAM, + ); + + proxy_header.write_to_tokio(&mut upstream).await?; + + let (_bytes_read, _bytes_written) = copy_bidirectional(&mut stream, &mut upstream).await?; + + Ok(()) +} + +async fn wrap_with_proxy_protocol_v2(mut stream: TcpStream) -> io::Result<()> { + let mut upstream = TcpStream::connect(("127.0.0.1", 8080)).await?; + + tracing::info!( + "PROXYv2 {} -> {}", + stream.peer_addr().unwrap(), + UPSTREAM.to_string(), + ); + + let mut proxy_header = v2::Header::new_tcp_ipv4_proxy(([127, 0, 0, 1], 8082), *UPSTREAM); + + proxy_header.add_typed_tlv(tlv::UniqueId::new("4269")); // UNIQUE_ID + proxy_header.add_typed_tlv(tlv::Noop::new("NOOP m8")); // NOOP + proxy_header.add_typed_tlv(tlv::Authority::new("localhost")); // NOOP + proxy_header.add_typed_tlv(tlv::Alpn::new("http/1.1")); // NOOP + proxy_header.add_crc23c_checksum(); + + proxy_header.write_to_tokio(&mut upstream).await?; + + let (_bytes_read, _bytes_written) = copy_bidirectional(&mut stream, &mut upstream).await?; + + Ok(()) +} + +async fn unwrap_proxy_protocol(mut stream: TcpStream) -> io::Result<()> { + let mut upstream = TcpStream::connect(("127.0.0.1", 8080)).await?; + + tracing::info!( + "PROXY unwrap {} -> {}", + stream.peer_addr().unwrap(), + UPSTREAM.to_string(), + ); + + let mut header = [0; 12]; + stream.peek(&mut header).await?; + + eprintln!("header: {}", String::from_utf8_lossy(&header)); + + if &header[..v1::SIGNATURE.len()] == v1::SIGNATURE.as_bytes() { + tracing::info!("v1"); + + let mut stream = BufReader::new(stream); + let mut buf = Vec::with_capacity(v1::MAX_HEADER_SIZE); + let _len = stream.read_until(b'\n', &mut buf).await?; + + eprintln!("{}", String::from_utf8_lossy(&buf)); + + let (rest, header) = match v1::Header::try_from_bytes(&buf) { + Ok((rest, header)) => (rest, header), + Err(err) => { + match err { + nom::Err::Incomplete(needed) => todo!(), + nom::Err::Error(err) => { + eprintln!( + "err {:?}, input: {}", + err.code, + String::from_utf8_lossy(err.input) + ) + } + nom::Err::Failure(_) => todo!(), + } + return Ok(()); + } + }; + eprintln!("{:02X?} - {:?}", rest, header); + + let (_bytes_read, _bytes_written) = copy_bidirectional(&mut stream, &mut upstream).await?; + } else if header == v2::SIGNATURE { + tracing::info!("v2"); + let (_bytes_read, _bytes_written) = copy_bidirectional(&mut stream, &mut upstream).await?; + } else { + tracing::warn!("No proxy header; closing"); + } + + Ok(()) +} + +fn start_server() -> io::Result { + tracing::info!("proxying to 127.0.0.1:8080"); + + Ok(Server::build() + .bind("proxy-protocol-v1", ("127.0.0.1", 8081), move || { + fn_service(wrap_with_proxy_protocol_v1) + .map_err(|err| tracing::error!("service error: {err:?}")) + })? + .bind("proxy-protocol-v2", ("127.0.0.1", 8082), move || { + fn_service(wrap_with_proxy_protocol_v2) + .map_err(|err| tracing::error!("service error: {err:?}")) + })? + .bind("proxy-protocol-unwrap", ("127.0.0.1", 8083), move || { + fn_service(unwrap_proxy_protocol) + .map_err(|err| tracing::error!("service error: {err:?}")) + })? + .workers(2) + .run()) +} + +#[tokio::main] +async fn main() -> io::Result<()> { + tracing_subscriber::fmt::fmt().without_time().init(); + + start_server()?.await?; + + Ok(()) +} diff --git a/actix-proxy-protocol/src/lib.rs b/actix-proxy-protocol/src/lib.rs new file mode 100644 index 00000000..32bc075c --- /dev/null +++ b/actix-proxy-protocol/src/lib.rs @@ -0,0 +1,156 @@ +//! PROXY protocol. + +#![expect(dead_code)] +#![doc(html_logo_url = "https://actix.rs/img/logo.png")] +#![doc(html_favicon_url = "https://actix.rs/favicon.ico")] + +pub mod tlv; +pub mod v1; +pub mod v2; + +/// PROXY Protocol Version. +#[derive(Debug, Clone, Copy)] +enum Version { + /// Human-readable header format (Version 1) + V1, + + /// Binary header format (Version 2) + V2, +} + +impl Version { + const fn signature(&self) -> &'static [u8] { + match self { + Version::V1 => v1::SIGNATURE.as_bytes(), + Version::V2 => v2::SIGNATURE.as_slice(), + } + } + + const fn v2_hi(&self) -> u8 { + (match self { + Version::V1 => panic!("v1 not supported in PROXY v2"), + Version::V2 => 0x2, + }) << 4 + } +} + +/// Command +/// +/// other values are unassigned and must not be emitted by senders. Receivers +/// must drop connections presenting unexpected values here. +#[derive(Debug, Clone, Copy)] +pub enum Command { + /// \x0 : LOCAL : the connection was established on purpose by the proxy + /// without being relayed. The connection endpoints are the sender and the + /// receiver. Such connections exist when the proxy sends health-checks to the + /// server. The receiver must accept this connection as valid and must use the + /// real connection endpoints and discard the protocol block including the + /// family which is ignored. + Local, + + /// \x1 : PROXY : the connection was established on behalf of another node, + /// and reflects the original connection endpoints. The receiver must then use + /// the information provided in the protocol block to get original the address. + Proxy, +} + +impl Command { + const fn v2_lo(&self) -> u8 { + match self { + Command::Local => 0x0, + Command::Proxy => 0x1, + } + } +} + +/// Address Family. +/// +/// maps to the original socket family without necessarily +/// matching the values internally used by the system. +/// +/// other values are unspecified and must not be emitted in version 2 of this +/// protocol and must be rejected as invalid by receivers. +#[derive(Debug, Clone, Copy)] +pub enum AddressFamily { + /// 0x0 : AF_UNSPEC : the connection is forwarded for an unknown, unspecified + /// or unsupported protocol. The sender should use this family when sending + /// LOCAL commands or when dealing with unsupported protocol families. The + /// receiver is free to accept the connection anyway and use the real endpoint + /// addresses or to reject it. The receiver should ignore address information. + Unspecified, + + /// 0x1 : AF_INET : the forwarded connection uses the AF_INET address family + /// (IPv4). The addresses are exactly 4 bytes each in network byte order, + /// followed by transport protocol information (typically ports). + Inet, + + /// 0x2 : AF_INET6 : the forwarded connection uses the AF_INET6 address family + /// (IPv6). The addresses are exactly 16 bytes each in network byte order, + /// followed by transport protocol information (typically ports). + Inet6, + + /// 0x3 : AF_UNIX : the forwarded connection uses the AF_UNIX address family + /// (UNIX). The addresses are exactly 108 bytes each. + Unix, +} + +impl AddressFamily { + pub(crate) fn v1_str(&self) -> &'static str { + match self { + AddressFamily::Inet => "TCP4", + AddressFamily::Inet6 => "TCP6", + af => panic!("{:?} is not supported in PROXY v1", af), + } + } + + const fn v2_hi(&self) -> u8 { + (match self { + AddressFamily::Unspecified => 0x0, + AddressFamily::Inet => 0x1, + AddressFamily::Inet6 => 0x2, + AddressFamily::Unix => 0x3, + }) << 4 + } +} + +/// Transport Protocol. +/// +/// other values are unspecified and must not be emitted in version 2 of this +/// protocol and must be rejected as invalid by receivers. +#[derive(Debug, Clone, Copy)] +pub enum TransportProtocol { + /// 0x0 : UNSPEC : the connection is forwarded for an unknown, unspecified + /// or unsupported protocol. The sender should use this family when sending + /// LOCAL commands or when dealing with unsupported protocol families. The + /// receiver is free to accept the connection anyway and use the real endpoint + /// addresses or to reject it. The receiver should ignore address information. + Unspecified, + + /// 0x1 : STREAM : the forwarded connection uses a SOCK_STREAM protocol (eg: + /// TCP or UNIX_STREAM). When used with AF_INET/AF_INET6 (TCP), the addresses + /// are followed by the source and destination ports represented on 2 bytes + /// each in network byte order. + Stream, + + /// 0x2 : DGRAM : the forwarded connection uses a SOCK_DGRAM protocol (eg: + /// UDP or UNIX_DGRAM). When used with AF_INET/AF_INET6 (UDP), the addresses + /// are followed by the source and destination ports represented on 2 bytes + /// each in network byte order. + Datagram, +} + +impl TransportProtocol { + const fn v2_lo(&self) -> u8 { + match self { + TransportProtocol::Unspecified => 0x0, + TransportProtocol::Stream => 0x1, + TransportProtocol::Datagram => 0x2, + } + } +} + +#[derive(Debug)] +enum ProxyProtocolHeader { + V1(v1::Header), + V2(v2::Header), +} diff --git a/actix-proxy-protocol/src/tlv.rs b/actix-proxy-protocol/src/tlv.rs new file mode 100644 index 00000000..04cd7173 --- /dev/null +++ b/actix-proxy-protocol/src/tlv.rs @@ -0,0 +1,292 @@ +use std::{borrow::Cow, convert::TryFrom, str}; + +const PP2_TYPE_ALPN: u8 = 0x01; // done +const PP2_TYPE_AUTHORITY: u8 = 0x02; // done +const PP2_TYPE_CRC32C: u8 = 0x03; // done +const PP2_TYPE_NOOP: u8 = 0x04; // done +const PP2_TYPE_UNIQUE_ID: u8 = 0x05; // done +const PP2_TYPE_SSL: u8 = 0x20; +const PP2_SUBTYPE_SSL_VERSION: u8 = 0x21; +const PP2_SUBTYPE_SSL_CN: u8 = 0x22; +const PP2_SUBTYPE_SSL_CIPHER: u8 = 0x23; +const PP2_SUBTYPE_SSL_SIG_ALG: u8 = 0x24; +const PP2_SUBTYPE_SSL_KEY_ALG: u8 = 0x25; +const PP2_TYPE_NETNS: u8 = 0x30; + +pub trait Tlv: Sized { + const TYPE: u8; + + fn try_from_value(value: &[u8]) -> Option; + + fn value_bytes(&self) -> Cow<'_, [u8]>; + + fn try_from_parts(typ: u8, value: &[u8]) -> Option { + if typ != Self::TYPE { + return None; + } + + Self::try_from_value(value) + } +} + +/// Application-Layer Protocol Negotiation (ALPN). It is a byte sequence defining +/// the upper layer protocol in use over the connection. The most common use case +/// will be to pass the exact copy of the ALPN extension of the Transport Layer +/// Security (TLS) protocol as defined by RFC7301 [9]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Alpn { + alpn: Vec, +} + +impl Alpn { + /// + /// + /// # Panics + /// Panics if `alpn` is empty (i.e., has length of 0). + pub fn new(alpn: impl Into>) -> Self { + let alpn = alpn.into(); + + assert!(!alpn.is_empty(), "ALPN TLV value cannot be empty"); + + Self { alpn } + } +} + +impl Tlv for Alpn { + const TYPE: u8 = PP2_TYPE_ALPN; + + fn try_from_value(value: &[u8]) -> Option { + Some(Self { + alpn: value.to_owned(), + }) + } + + fn value_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(&self.alpn) + } +} + +/// Contains the host name value passed by the client, as an UTF8-encoded string. +/// In case of TLS being used on the client connection, this is the exact copy of +/// the "server_name" extension as defined by RFC3546 [10], section 3.1, often +/// referred to as "SNI". There are probably other situations where an authority +/// can be mentioned on a connection without TLS being involved at all. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Authority { + authority: String, +} + +impl Authority { + /// A UTF-8 + /// + /// # Panics + /// Panics if `authority` is an empty string. + pub fn new(authority: impl Into) -> Self { + let authority = authority.into(); + + assert!(!authority.is_empty(), "Authority TLV value cannot be empty"); + + Self { authority } + } +} + +impl Tlv for Authority { + const TYPE: u8 = PP2_TYPE_AUTHORITY; + + fn try_from_value(value: &[u8]) -> Option { + Some(Self { + authority: str::from_utf8(value).ok()?.to_owned(), + }) + } + + fn value_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(self.authority.as_bytes()) + } +} + +/// The value of the type PP2_TYPE_CRC32C is a 32-bit number storing the CRC32c +/// checksum of the PROXY protocol header. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct Crc32c { + pub(crate) checksum: u32, +} + +impl Tlv for Crc32c { + const TYPE: u8 = PP2_TYPE_CRC32C; + + fn try_from_value(value: &[u8]) -> Option { + let checksum_bytes = <[u8; 4]>::try_from(value).ok()?; + + Some(Self { + checksum: u32::from_be_bytes(checksum_bytes), + }) + } + + fn value_bytes(&self) -> Cow<'_, [u8]> { + Cow::Owned(self.checksum.to_be_bytes().to_vec()) + } +} + +/// The TLV of this type should be ignored when parsed. The value is zero or more +/// bytes. Can be used for data padding or alignment. Note that it can be used +/// to align only by 3 or more bytes because a TLV can not be smaller than that. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Noop { + value: Vec, +} + +impl Noop { + /// + /// + /// # Panics + /// Panics if `value` is empty (i.e., has length of 0). + pub fn new(value: impl Into>) -> Self { + let value = value.into(); + + assert!(!value.is_empty(), "Noop TLV `value` cannot be empty"); + + Self { value } + } +} + +impl Tlv for Noop { + const TYPE: u8 = PP2_TYPE_NOOP; + + fn try_from_value(value: &[u8]) -> Option { + Some(Self { + value: value.to_owned(), + }) + } + + fn value_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(&self.value) + } +} + +/// The value of the type PP2_TYPE_UNIQUE_ID is an opaque byte sequence of up to +/// 128 bytes generated by the upstream proxy that uniquely identifies the +/// connection. +/// +/// The unique ID can be used to easily correlate connections across multiple +/// layers of proxies, without needing to look up IP addresses and port numbers. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UniqueId { + value: Vec, +} + +impl UniqueId { + /// + /// + /// # Panics + /// Panics if `value` is 0 bytes or larger than 128 bytes. + pub fn new(id: impl Into>) -> Self { + let value = id.into(); + + assert!(!value.is_empty(), "UniqueId TLV `value` cannot be empty"); + assert!( + value.len() < 128, + "UniqueId TLV `value` cannot be larger than 128 bytes" + ); + + Self { value } + } +} + +impl Tlv for UniqueId { + const TYPE: u8 = PP2_TYPE_UNIQUE_ID; + + fn try_from_value(value: &[u8]) -> Option { + Some(Self { + value: value.to_owned(), + }) + } + + fn value_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(&self.value) + } +} + +bitflags::bitflags! { + #[derive(Debug, Clone, PartialEq, Eq)] + struct SslClientFlags: u8 { + const PP2_CLIENT_SSL = 0x01; + const PP2_CLIENT_CERT_CONN = 0x02; + const PP2_CLIENT_CERT_SESS = 0x04; + } +} + +/// TLS (SSL). +/// +/// Very broken atm. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ssl { + /// The field is made of a bit field indicating which element is present. + /// + /// Note, that each of these elements may lead to extra data being appended to + /// this TLV using a second level of TLV encapsulation. It is thus possible to + /// find multiple TLV values after this field. The total length of the pp2_tlv_ssl + /// TLV will reflect this. + client: SslClientFlags, + + /// The field will be zero if the client presented a certificate + /// and it was successfully verified, and non-zero otherwise. + verify: bool, + + /// Sub-TLVs. + tlvs: Vec, +} + +impl Tlv for Ssl { + const TYPE: u8 = PP2_TYPE_SSL; + + fn try_from_value(_value: &[u8]) -> Option { + /// The PP2_CLIENT_SSL flag indicates that the client connected over SSL/TLS. When + /// this field is present, the US-ASCII string representation of the TLS version is + /// appended at the end of the field in the TLV format using the type + /// PP2_SUBTYPE_SSL_VERSION. + const PP2_CLIENT_SSL: u8 = 0x01; + + /// PP2_CLIENT_CERT_CONN indicates that the client provided a certificate over the + /// current connection. + const PP2_CLIENT_CERT_CONN: u8 = 0x02; + + /// PP2_CLIENT_CERT_SESS indicates that the client provided a + /// certificate at least once over the TLS session this connection belongs to. + const PP2_CLIENT_CERT_SESS: u8 = 0x04; + + // TODO: finish parsing + + None + } + + fn value_bytes(&self) -> Cow<'_, [u8]> { + Cow::Borrowed(&[]) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SslTlv {} + +#[cfg(test)] +mod tests { + use super::*; + + // #[test] + // #[should_panic] + // fn tlv_zero_len() { + // Tlv::new(0x00, vec![]); + // } + + #[test] + fn tlv_as_crc32c() { + // noop + assert_eq!(Crc32c::try_from_parts(0x04, &[0x00]), None); + + assert_eq!( + Crc32c::try_from_parts(0x03, &[0x08, 0x70, 0x17, 0x7b]), + Some(Crc32c { + checksum: 141563771 + }) + ); + } +} diff --git a/actix-proxy-protocol/src/v1.rs b/actix-proxy-protocol/src/v1.rs new file mode 100644 index 00000000..3cb29019 --- /dev/null +++ b/actix-proxy-protocol/src/v1.rs @@ -0,0 +1,148 @@ +use std::{fmt, io, net::SocketAddr}; + +use arrayvec::ArrayVec; +use nom::{IResult, Parser as _}; +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; + +use crate::AddressFamily; + +pub const SIGNATURE: &str = "PROXY"; +pub const MAX_HEADER_SIZE: usize = 107; + +#[derive(Debug, Clone)] +pub struct Header { + /// Address family. + af: AddressFamily, + + /// Source address. + src: SocketAddr, + + /// Destination address. + dst: SocketAddr, +} + +impl Header { + pub const fn new(af: AddressFamily, src: SocketAddr, dst: SocketAddr) -> Self { + Self { af, src, dst } + } + + pub const fn new_inet(src: SocketAddr, dst: SocketAddr) -> Self { + Self::new(AddressFamily::Inet, src, dst) + } + + pub const fn new_inet6(src: SocketAddr, dst: SocketAddr) -> Self { + Self::new(AddressFamily::Inet6, src, dst) + } + + pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> { + write!(wrt, "{self}") + } + + pub async fn write_to_tokio(&self, wrt: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { + // max length of a V1 header is 107 bytes + let mut buf = ArrayVec::<_, MAX_HEADER_SIZE>::new(); + self.write_to(&mut buf)?; + wrt.write_all(&buf).await + } + + pub fn try_from_bytes(slice: &[u8]) -> IResult<&[u8], Self> { + parsing::parse(slice) + } +} + +impl fmt::Display for Header { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{proto_sig} {af} {src_ip} {dst_ip} {src_port} {dst_port}\r\n", + proto_sig = SIGNATURE, + af = self.af.v1_str(), + src_ip = self.src.ip(), + dst_ip = self.dst.ip(), + src_port = itoa::Buffer::new().format(self.src.port()), + dst_port = itoa::Buffer::new().format(self.dst.port()), + ) + } +} + +mod parsing { + use std::{ + net::{Ipv4Addr, SocketAddrV4}, + str::{self, FromStr}, + }; + + use nom::{ + branch::alt, + bytes::complete::{tag, take_while}, + character::complete::char, + combinator::{map, map_res}, + IResult, + }; + + use super::*; + + /// Parses a number from serialized representation (as bytes). + fn parse_number(input: &[u8]) -> IResult<&[u8], T> { + map_res(take_while(|c: u8| c.is_ascii_digit()), |s: &[u8]| { + let s = str::from_utf8(s).map_err(|_| "utf8 error")?; + let val = s.parse::().map_err(|_| "u8 parse error")?; + Ok::<_, Box>(val) + }) + .parse(input) + } + + /// Parses an address family. + fn parse_address_family(input: &[u8]) -> IResult<&[u8], AddressFamily> { + map_res(alt((tag("TCP4"), tag("TCP6"))), |af: &[u8]| match af { + b"TCP4" => Ok(AddressFamily::Inet), + b"TCP6" => Ok(AddressFamily::Inet6), + _ => Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid address family", + )), + }) + .parse(input) + } + + /// Parses an IPv4 address from serialized representation (as bytes). + fn parse_ipv4(input: &[u8]) -> IResult<&[u8], Ipv4Addr> { + map( + ( + parse_number::, + char('.'), + parse_number::, + char('.'), + parse_number::, + char('.'), + parse_number::, + ), + |(a, _, b, _, c, _, d)| Ipv4Addr::new(a, b, c, d), + ) + .parse(input) + } + + /// Parses an IPv4 address from ASCII bytes. + pub(super) fn parse(input: &[u8]) -> IResult<&[u8], Header> { + map( + ( + tag(SIGNATURE), + char(' '), + parse_address_family, + char(' '), + parse_ipv4, + char(' '), + parse_ipv4, + char(' '), + parse_number::, + char(' '), + parse_number::, + ), + |(_, _, af, _, src_ip, _, dst_ip, _, src_port, _, dst_port)| Header { + af, + src: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)), + dst: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)), + }, + ) + .parse(input) + } +} diff --git a/actix-proxy-protocol/src/v2.rs b/actix-proxy-protocol/src/v2.rs new file mode 100644 index 00000000..f5681e16 --- /dev/null +++ b/actix-proxy-protocol/src/v2.rs @@ -0,0 +1,304 @@ +use std::{ + io, + net::{IpAddr, SocketAddr}, +}; + +use smallvec::{SmallVec, ToSmallVec as _}; +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; + +use crate::{ + tlv::{Crc32c, Tlv}, + AddressFamily, Command, TransportProtocol, Version, +}; + +pub const SIGNATURE: [u8; 12] = [ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, +]; + +#[derive(Debug, Clone)] +pub struct Header { + command: Command, + transport_protocol: TransportProtocol, + address_family: AddressFamily, + src: SocketAddr, + dst: SocketAddr, + tlvs: SmallVec<[(u8, SmallVec<[u8; 16]>); 4]>, +} + +impl Header { + pub fn new( + command: Command, + transport_protocol: TransportProtocol, + address_family: AddressFamily, + src: impl Into, + dst: impl Into, + ) -> Self { + Self { + command, + transport_protocol, + address_family, + src: src.into(), + dst: dst.into(), + tlvs: SmallVec::new(), + } + } + + pub fn new_tcp_ipv4_proxy(src: impl Into, dst: impl Into) -> Self { + Self::new( + Command::Proxy, + TransportProtocol::Stream, + AddressFamily::Inet, + src, + dst, + ) + } + + pub fn add_tlv(&mut self, typ: u8, value: impl AsRef<[u8]>) { + self.tlvs.push((typ, SmallVec::from_slice(value.as_ref()))); + } + + pub fn add_typed_tlv(&mut self, tlv: T) { + self.add_tlv(T::TYPE, tlv.value_bytes()); + } + + fn v2_len(&self) -> u16 { + let addr_len = if self.src.is_ipv4() { + 4 + 2 // 4b IPv4 + 2b port number + } else { + 16 + 2 // 16b IPv6 + 2b port number + }; + + (addr_len * 2) + + self + .tlvs + .iter() + .map(|(_, value)| 1 + 2 + value.len() as u16) + .sum::() + } + + pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> { + // PROXY v2 signature + wrt.write_all(&SIGNATURE)?; + + // version | command + wrt.write_all(&[Version::V2.v2_hi() | self.command.v2_lo()])?; + + // address family | transport protocol + wrt.write_all(&[self.address_family.v2_hi() | self.transport_protocol.v2_lo()])?; + + // rest-of-header length + wrt.write_all(&self.v2_len().to_be_bytes())?; + + tracing::debug!("proxy rest-of-header len: {}", self.v2_len()); + + fn write_ip_bytes_to(wrt: &mut impl io::Write, ip: IpAddr) -> io::Result<()> { + match ip { + IpAddr::V4(ip) => wrt.write_all(&ip.octets()), + IpAddr::V6(ip) => wrt.write_all(&ip.octets()), + } + } + + // L3 (IP) address + write_ip_bytes_to(wrt, self.src.ip())?; + write_ip_bytes_to(wrt, self.dst.ip())?; + + // L4 ports + wrt.write_all(&self.src.port().to_be_bytes())?; + wrt.write_all(&self.dst.port().to_be_bytes())?; + + // TLVs + for (typ, value) in &self.tlvs { + wrt.write_all(&[*typ])?; + wrt.write_all(&(value.len() as u16).to_be_bytes())?; + wrt.write_all(value)?; + } + + Ok(()) + } + + pub async fn write_to_tokio(&self, wrt: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { + let buf = self.to_vec(); + wrt.write_all(&buf).await + } + + fn to_vec(&self) -> Vec { + // TODO: figure out cap + let mut buf = Vec::with_capacity(64); + self.write_to(&mut buf).unwrap(); + buf + } + + pub fn has_tlv(&self) -> bool { + self.tlvs.iter().any(|&(typ, _)| typ == T::TYPE) + } + + /// Calculates and adds a crc32c TLV to the PROXY header. + /// + /// Uses method defined in spec. + /// + /// If this is not called last thing it will be wrong. + pub fn add_crc23c_checksum(&mut self) { + // don't add a checksum if it is already set + if self.has_tlv::() { + return; + } + + // When the checksum is supported by the sender after constructing the header + // the sender MUST: + // - initialize the checksum field to '0's. + // - calculate the CRC32c checksum of the PROXY header as described in RFC4960, + // Appendix B [8]. + // - put the resultant value into the checksum field, and leave the rest of + // the bits unchanged. + + // add zeroed checksum field to TLVs + self.add_typed_tlv(Crc32c::default()); + + // write PROXY header to buffer + let mut buf = Vec::new(); + self.write_to(&mut buf).unwrap(); + + // calculate CRC on buffer and update CRC TLV + let crc_calc = crc32fast::hash(&buf); + self.tlvs.last_mut().unwrap().1 = crc_calc.to_be_bytes().to_smallvec(); + + tracing::debug!("checksum is {}", crc_calc); + } + + pub fn validate_crc32c_tlv(&self) -> Option { + // extract crc32c TLV or exit early if none is present + let crc_sent = self + .tlvs + .iter() + .filter_map(|(typ, value)| Crc32c::try_from_parts(*typ, value)) + .next()?; + + // If the checksum is provided as part of the PROXY header and the checksum + // functionality is supported by the receiver, the receiver MUST: + // - store the received CRC32c checksum value aside. + // - replace the 32 bits of the checksum field in the received PROXY header with + // all '0's and calculate a CRC32c checksum value of the whole PROXY header. + // - verify that the calculated CRC32c checksum is the same as the received + // CRC32c checksum. If it is not, the receiver MUST treat the TCP connection + // providing the header as invalid. + // The default procedure for handling an invalid TCP connection is to abort it. + + let mut this = self.clone(); + for (typ, value) in this.tlvs.iter_mut() { + if Crc32c::try_from_parts(*typ, value).is_some() { + value.fill(0); + } + } + + let mut buf = Vec::new(); + this.write_to(&mut buf).unwrap(); + let crc_calc = crc32fast::hash(&buf); + + Some(crc_sent.checksum == crc_calc) + } +} + +#[cfg(test)] +mod tests { + use std::net::Ipv6Addr; + + use const_str::hex; + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn write_v2_no_tlvs() { + let mut exp = Vec::new(); + exp.extend_from_slice(&SIGNATURE); // 0-11 + exp.extend_from_slice(&[0x21, 0x11]); // 12-13 + exp.extend_from_slice(&[0x00, 0x0C]); // 14-15 + exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 2]); // 16-23 + exp.extend_from_slice(&[0x04, 0xd2, 0x00, 80]); // 24-27 + + let header = Header::new( + Command::Proxy, + TransportProtocol::Stream, + AddressFamily::Inet, + SocketAddr::from(([127, 0, 0, 1], 1234)), + SocketAddr::from(([127, 0, 0, 2], 80)), + ); + + assert_eq!(header.v2_len(), 12); + assert_eq!(header.to_vec(), exp); + } + + #[test] + fn write_v2_ipv6_tlv_noop() { + let mut exp = Vec::new(); + exp.extend_from_slice(&SIGNATURE); // 0-11 + exp.extend_from_slice(&[0x20, 0x11]); // 12-13 + exp.extend_from_slice(&[0x00, 0x28]); // 14-15 + exp.extend_from_slice(&hex!("00000000000000000000000000000001")); // 16-31 + exp.extend_from_slice(&hex!("000102030405060708090A0B0C0D0E0F")); // 32-45 + exp.extend_from_slice(&[0x00, 80, 0xff, 0xff]); // 45-49 + exp.extend_from_slice(&[0x04, 0x00, 0x01, 0x00]); // 50-53 NOOP TLV + + let mut header = Header::new( + Command::Local, + TransportProtocol::Stream, + AddressFamily::Inet, + SocketAddr::from((Ipv6Addr::LOCALHOST, 80)), + SocketAddr::from(( + Ipv6Addr::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + 65535, + )), + ); + + header.add_tlv(0x04, [0]); + + assert_eq!(header.v2_len(), 36 + 4); + assert_eq!(header.to_vec(), exp); + } + + #[test] + fn write_v2_tlv_c2c() { + let mut exp = Vec::new(); + exp.extend_from_slice(&SIGNATURE); // 0-11 + exp.extend_from_slice(&[0x21, 0x11]); // 12-13 + exp.extend_from_slice(&[0x00, 0x13]); // 14-15 + exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 1]); // 16-23 + exp.extend_from_slice(&[0x00, 80, 0x00, 80]); // 24-27 + exp.extend_from_slice(&[0x03, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // 28-35 TLV crc32c + + assert_eq!( + crc32fast::hash(&exp), + // correct checksum calculated manually + u32::from_be_bytes([0x08, 0x70, 0x17, 0x7b]), + ); + + // re-assign actual checksum to last 4 bytes of expected byte array + exp[31..35].copy_from_slice(&[0x08, 0x70, 0x17, 0x7b]); + + let mut header = Header::new( + Command::Proxy, + TransportProtocol::Stream, + AddressFamily::Inet, + SocketAddr::from(([127, 0, 0, 1], 80)), + SocketAddr::from(([127, 0, 0, 1], 80)), + ); + + assert!( + header.validate_crc32c_tlv().is_none(), + "header doesn't have CRC TLV added yet" + ); + + // add crc32c TLV to header + header.add_crc23c_checksum(); + + assert_eq!(header.v2_len(), 12 + 7); + assert_eq!(header.to_vec(), exp); + + // struct can self-validate checksum + assert_eq!(header.validate_crc32c_tlv().unwrap(), true); + + // mangle crc32c TLV and assert that validate now fails + *header.tlvs.last_mut().unwrap().1.last_mut().unwrap() = 0x00; + assert_eq!(header.validate_crc32c_tlv().unwrap(), false); + } +}