From f12b6132116a4488794a50a288e68411f10959fa Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Thu, 8 Mar 2018 20:39:05 -0800
Subject: [PATCH] more ws optimizations

---
 src/payload.rs  |  15 +++++
 src/ws/frame.rs | 144 +++++++++++++++++++++++++++++++++++++-----------
 src/ws/mask.rs  |  22 ++++----
 3 files changed, 138 insertions(+), 43 deletions(-)

diff --git a/src/payload.rs b/src/payload.rs
index bfa4dc81..512d56f1 100644
--- a/src/payload.rs
+++ b/src/payload.rs
@@ -297,6 +297,21 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
         }
     }
 
+    /// Borrow the front buffered chunk without consuming it. When nothing is
+    /// buffered the stream is polled once first; a `Ready(false)` result from
+    /// `poll_stream` is reported as `Ready(None)`.
+    #[inline]
+    pub fn get_chunk(&mut self) -> Poll<Option<&[u8]>, PayloadError> {
+        if self.items.is_empty() {
+            match self.poll_stream()? {
+                Async::Ready(false) => return Ok(Async::Ready(None)),
+                Async::NotReady => return Ok(Async::NotReady),
+                Async::Ready(true) => (),
+            }
+        }
+        Ok(self.items.front().map_or(Async::NotReady, |c| Async::Ready(Some(c.as_ref()))))
+    }
+
     #[inline]
     pub fn readexactly(&mut self, size: usize) -> Poll<Option<Bytes>, PayloadError> {
         if size <= self.len {
diff --git a/src/ws/frame.rs b/src/ws/frame.rs
index 1d758298..52a20e50 100644
--- a/src/ws/frame.rs
+++ b/src/ws/frame.rs
@@ -1,4 +1,4 @@
-use std::{fmt, mem};
+use std::{fmt, mem, ptr};
 use std::iter::FromIterator;
 use bytes::{Bytes, BytesMut, BufMut};
 use byteorder::{ByteOrder, BigEndian, NetworkEndian};
@@ -17,9 +17,6 @@ use ws::mask::apply_mask;
 #[derive(Debug)]
 pub(crate) struct Frame {
     finished: bool,
-    rsv1: bool,
-    rsv2: bool,
-    rsv3: bool,
     opcode: OpCode,
     payload: Binary,
 }
@@ -51,9 +48,9 @@ impl Frame {
         Frame::message(payload, OpCode::Close, true, genmask)
     }
 
-    /// Parse the input stream into a frame.
-    pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
-                    -> Poll<Option<Frame>, ProtocolError>
+    fn read_copy_md<S>(
+        pl: &mut PayloadHelper<S>, server: bool, max_size: usize
+    ) -> Poll<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError>
         where S: Stream<Item=Bytes, Error=PayloadError>
     {
         let mut idx = 2;
@@ -74,12 +71,14 @@ impl Frame {
             return Err(ProtocolError::MaskedFrame)
         }
 
-        let rsv1 = first & 0x40 != 0;
-        let rsv2 = first & 0x20 != 0;
-        let rsv3 = first & 0x10 != 0;
+        // Op code
         let opcode = OpCode::from(first & 0x0F);
-        let len = second & 0x7F;
 
+        if let OpCode::Bad = opcode {
+            return Err(ProtocolError::InvalidOpcode(first & 0x0F))
+        }
+
+        let len = second & 0x7F;
         let length = if len == 126 {
             let buf = match pl.copy(4)? {
                 Async::Ready(Some(buf)) => buf,
@@ -114,14 +113,106 @@ impl Frame {
                 Async::NotReady => return Ok(Async::NotReady),
             };
 
-            let mut mask_bytes = [0u8; 4];
-            mask_bytes.copy_from_slice(&buf[idx..idx+4]);
+            let mask: &[u8] = &buf[idx..idx+4];
+            let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)};
             idx += 4;
-            Some(mask_bytes)
+            Some(mask_u32)
         } else {
             None
         };
 
+        Ok(Async::Ready(Some((idx, finished, opcode, length, mask))))
+    }
+
+    /// Decode frame metadata straight from the borrowed `chunk`, returning
+    /// `(header_len, finished, opcode, payload_len, mask)`, or `NotReady` when
+    /// `chunk` is too short to hold the whole header. `server` selects which
+    /// masking state is rejected; a length above `max_size` is an `Overflow`.
+    fn read_chunk_md(chunk: &[u8], server: bool, max_size: usize)
+                     -> Poll<(usize, bool, OpCode, usize, Option<u32>), ProtocolError>
+    {
+        let chunk_len = chunk.len();
+        if chunk_len < 2 {
+            return Ok(Async::NotReady)
+        }
+
+        let (first, second) = (chunk[0], chunk[1]);
+        let finished = first & 0x80 != 0;
+
+        // check masking
+        let masked = second & 0x80 != 0;
+        if !masked && server {
+            return Err(ProtocolError::UnmaskedFrame)
+        } else if masked && !server {
+            return Err(ProtocolError::MaskedFrame)
+        }
+
+        // Op code
+        let opcode = OpCode::from(first & 0x0F);
+        if let OpCode::Bad = opcode {
+            return Err(ProtocolError::InvalidOpcode(first & 0x0F))
+        }
+
+        let mut idx = 2;
+        let len = second & 0x7F;
+        let length = if len == 126 {
+            if chunk_len < 4 {
+                return Ok(Async::NotReady)
+            }
+            idx += 2;
+            NetworkEndian::read_uint(&chunk[2..], 2)
+        } else if len == 127 {
+            if chunk_len < 10 {
+                return Ok(Async::NotReady)
+            }
+            idx += 8;
+            NetworkEndian::read_uint(&chunk[2..], 8)
+        } else {
+            u64::from(len)
+        };
+
+        // check for max allowed size while the length is still u64, so a 64-bit
+        // value is never truncated by the usize cast on narrower targets
+        if length > max_size as u64 {
+            return Err(ProtocolError::Overflow)
+        }
+        let length = length as usize;
+
+        let mask = if server {
+            if chunk_len < idx + 4 {
+                return Ok(Async::NotReady)
+            }
+            // SAFETY: 4 bytes are in bounds at `idx`; no alignment is required.
+            let mask = unsafe { ptr::read_unaligned(chunk[idx..idx + 4].as_ptr() as *const u32) };
+            idx += 4;
+            Some(mask)
+        } else {
+            None
+        };
+
+        Ok(Async::Ready((idx, finished, opcode, length, mask)))
+    }
+
+    /// Parse the input stream into a frame.
+    pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
+                    -> Poll<Option<Frame>, ProtocolError>
+        where S: Stream<Item=Bytes, Error=PayloadError>
+    {
+        let result = match pl.get_chunk()? {
+            Async::NotReady => return Ok(Async::NotReady),
+            Async::Ready(Some(chunk)) => Frame::read_chunk_md(chunk, server, max_size)?,
+            Async::Ready(None) => return Ok(Async::Ready(None)),
+        };
+
+        let (idx, finished, opcode, length, mask) = match result {
+            Async::NotReady => match Frame::read_copy_md(pl, server, max_size)? {
+                Async::NotReady => return Ok(Async::NotReady),
+                Async::Ready(Some(item)) => item,
+                Async::Ready(None) => return Ok(Async::Ready(None)),
+            },
+            Async::Ready(item) => item,
+        };
+
         match pl.can_read(idx + length)? {
             Async::Ready(Some(true)) => (),
             Async::Ready(None) => return Ok(Async::Ready(None)),
@@ -134,7 +225,7 @@ impl Frame {
         // get body
         if length == 0 {
             return Ok(Async::Ready(Some(Frame {
-                finished, rsv1, rsv2, rsv3, opcode, payload: Binary::from("") })));
+                finished, opcode, payload: Binary::from("") })));
         }
 
         let data = match pl.readexactly(length)? {
@@ -143,11 +234,6 @@ impl Frame {
             Async::NotReady => panic!(),
         };
 
-        // Disallow bad opcode
-        if let OpCode::Bad = opcode {
-            return Err(ProtocolError::InvalidOpcode(first & 0x0F))
-        }
-
         // control frames must have length <= 125
         match opcode {
             OpCode::Ping | OpCode::Pong if length > 125 => {
@@ -161,14 +247,14 @@ impl Frame {
         }
 
         // unmask
-        if let Some(ref mask) = mask {
+        if let Some(mask) = mask {
             #[allow(mutable_transmutes)]
             let p: &mut [u8] = unsafe{let ptr: &[u8] = &data; mem::transmute(ptr)};
             apply_mask(p, mask);
         }
 
         Ok(Async::Ready(Some(Frame {
-            finished, rsv1, rsv2, rsv3, opcode, payload: data.into() })))
+            finished, opcode, payload: data.into() })))
     }
 
     /// Generate binary representation
@@ -213,13 +299,13 @@ impl Frame {
         };
 
         if genmask {
-            let mask: [u8; 4] = rand::random();
+            let mask = rand::random::<u32>();
             unsafe {
                 {
                     let buf_mut = buf.bytes_mut();
-                    buf_mut[..4].copy_from_slice(&mask);
+                    *(buf_mut as *mut _ as *mut u32) = mask;
                     buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref());
-                    apply_mask(&mut buf_mut[4..], &mask);
+                    apply_mask(&mut buf_mut[4..], mask);
                 }
                 buf.advance_mut(payload_len + 4);
             }
@@ -235,9 +321,6 @@ impl Default for Frame {
     fn default() -> Frame {
         Frame {
             finished: true,
-            rsv1: false,
-            rsv2: false,
-            rsv3: false,
             opcode: OpCode::Close,
             payload: Binary::from(&b""[..]),
         }
@@ -250,15 +333,11 @@ impl fmt::Display for Frame {
             "
 <FRAME>
     final: {}
-    reserved: {} {} {}
     opcode: {}
     payload length: {}
     payload: 0x{}
 </FRAME>",
                self.finished,
-               self.rsv1,
-               self.rsv2,
-               self.rsv3,
                self.opcode,
                self.payload.len(),
                self.payload.as_ref().iter().map(
@@ -296,7 +375,6 @@ mod tests {
         let mut buf = PayloadHelper::new(once(Ok(buf.freeze())));
 
         let frame = extract(Frame::parse(&mut buf, false, 1024));
-        println!("FRAME: {}", frame);
         assert!(!frame.finished);
         assert_eq!(frame.opcode, OpCode::Text);
         assert_eq!(frame.payload.as_ref(), &b"1"[..]);
diff --git a/src/ws/mask.rs b/src/ws/mask.rs
index 2e5a2960..e29eefd9 100644
--- a/src/ws/mask.rs
+++ b/src/ws/mask.rs
@@ -2,11 +2,10 @@
 use std::cmp::min;
 use std::mem::uninitialized;
 use std::ptr::copy_nonoverlapping;
-use std::ptr;
 
 /// Mask/unmask a frame.
 #[inline]
-pub fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
+pub fn apply_mask(buf: &mut [u8], mask: u32) {
     apply_mask_fast32(buf, mask)
 }
 
@@ -21,9 +20,7 @@ fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) {
 
 /// Faster version of `apply_mask()` which operates on 8-byte blocks.
 #[inline]
-fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) {
-    let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)};
-
+fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) {
     let mut ptr = buf.as_mut_ptr();
     let mut len = buf.len();
 
@@ -35,12 +32,14 @@ fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) {
             ptr = ptr.offset(head as isize);
         }
         len -= head;
-        let mask_u32 = if cfg!(target_endian = "big") {
+        //let mask_u32 =
+        if cfg!(target_endian = "big") {
             mask_u32.rotate_left(8 * head as u32)
         } else {
             mask_u32.rotate_right(8 * head as u32)
-        };
+        }//;
 
+        /*
         let head = min(len, (4 - (ptr as usize & 3)) & 3);
         if head > 0 {
             unsafe {
@@ -55,7 +54,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) {
             }
         } else {
             mask_u32
-        }
+        }*/
     } else {
         mask_u32
     };
@@ -106,6 +105,7 @@ unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
 
 #[cfg(test)]
 mod tests {
+    use std::ptr;
  use super::{apply_mask_fallback, apply_mask_fast32};
 
     #[test]
@@ -113,6 +113,8 @@ mod tests {
         let mask = [
             0x6d, 0xb6, 0xb2, 0x80,
         ];
+        let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)};
+
         let unmasked = vec![
             0xf3, 0x00, 0x01, 0x02,  0x03, 0x80, 0x81, 0x82,
             0xff, 0xfe, 0x00, 0x17,  0x74, 0xf9, 0x12, 0x03,
@@ -124,7 +126,7 @@ mod tests {
             apply_mask_fallback(&mut masked, &mask);
 
             let mut masked_fast = unmasked.clone();
-            apply_mask_fast32(&mut masked_fast, &mask);
+            apply_mask_fast32(&mut masked_fast, mask_u32);
 
             assert_eq!(masked, masked_fast);
         }
@@ -135,7 +137,7 @@ mod tests {
             apply_mask_fallback(&mut masked[1..], &mask);
 
             let mut masked_fast = unmasked.clone();
-            apply_mask_fast32(&mut masked_fast[1..], &mask);
+            apply_mask_fast32(&mut masked_fast[1..], mask_u32);
 
             assert_eq!(masked, masked_fast);
         }