From f277b128b65c874b46ff897e4b770f8a713b88a4 Mon Sep 17 00:00:00 2001
From: fakeshadow <24548779@qq.com>
Date: Thu, 13 May 2021 19:24:32 +0800
Subject: [PATCH] cleanup ws test (#2213)

---
 actix-http/tests/test_ws.rs | 182 +++++++++++++++---------------------
 1 file changed, 76 insertions(+), 106 deletions(-)

diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs
index bf1ca938..b17d4211 100644
--- a/actix-http/tests/test_ws.rs
+++ b/actix-http/tests/test_ws.rs
@@ -1,193 +1,163 @@
-use std::cell::Cell;
-use std::future::Future;
-use std::marker::PhantomData;
-use std::pin::Pin;
-use std::sync::{Arc, Mutex};
-use std::task::{Context, Poll};
+use std::{
+    cell::Cell,
+    task::{Context, Poll},
+};
 
 use actix_codec::{AsyncRead, AsyncWrite, Framed};
-use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
+use actix_http::{
+    body::BodySize,
+    h1,
+    ws::{self, CloseCode, Frame, Item, Message},
+    Error, HttpService, Request, Response,
+};
 use actix_http_test::test_server;
 use actix_service::{fn_factory, Service};
-use actix_utils::future;
 use bytes::Bytes;
+use futures_core::future::LocalBoxFuture;
 use futures_util::{SinkExt as _, StreamExt as _};
 
-use crate::ws::Dispatcher;
+#[derive(Clone)]
+struct WsService(Cell<bool>);
 
-struct WsService<T>(Arc<Mutex<(PhantomData<T>, Cell<bool>)>>);
-
-impl<T> WsService<T> {
+impl WsService {
     fn new() -> Self {
-        WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false)))))
+        WsService(Cell::new(false))
     }
 
     fn set_polled(&self) {
-        *self.0.lock().unwrap().1.get_mut() = true;
+        self.0.set(true);
     }
 
     fn was_polled(&self) -> bool {
-        self.0.lock().unwrap().1.get()
+        self.0.get()
     }
 }
 
-impl<T> Clone for WsService<T> {
-    fn clone(&self) -> Self {
-        WsService(self.0.clone())
-    }
-}
-
-impl<T> Service<(Request, Framed<T, h1::Codec>)> for WsService<T>
+impl<T> Service<(Request, Framed<T, h1::Codec>)> for WsService
 where
     T: AsyncRead + AsyncWrite + Unpin + 'static,
 {
     type Response = ();
     type Error = Error;
-    type Future = Pin<Box<dyn Future<Output = Result<(), Error>>>>;
+    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
 
-    fn poll_ready(&self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         self.set_polled();
         Poll::Ready(Ok(()))
     }
 
     fn call(&self, (req, mut framed): (Request, Framed<T, h1::Codec>)) -> Self::Future {
-        let fut = async move {
-            let res = ws::handshake(req.head()).unwrap().message_body(()).unwrap();
+        assert!(self.was_polled());
 
-            framed
-                .send((res, body::BodySize::None).into())
-                .await
-                .unwrap();
+        Box::pin(async move {
+            let res = ws::handshake(req.head())?.message_body(())?;
 
-            Dispatcher::with(framed.replace_codec(ws::Codec::new()), service)
-                .await
-                .map_err(|_| panic!())
-        };
+            framed.send((res, BodySize::None).into()).await?;
 
-        Box::pin(fut)
+            let framed = framed.replace_codec(ws::Codec::new());
+
+            ws::Dispatcher::with(framed, service).await?;
+
+            Ok(())
+        })
     }
 }
 
-async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
+async fn service(msg: Frame) -> Result<Message, Error> {
     let msg = match msg {
-        ws::Frame::Ping(msg) => ws::Message::Pong(msg),
-        ws::Frame::Text(text) => {
-            ws::Message::Text(String::from_utf8_lossy(&text).into_owned().into())
+        Frame::Ping(msg) => Message::Pong(msg),
+        Frame::Text(text) => {
+            Message::Text(String::from_utf8_lossy(&text).into_owned().into())
         }
-        ws::Frame::Binary(bin) => ws::Message::Binary(bin),
-        ws::Frame::Continuation(item) => ws::Message::Continuation(item),
-        ws::Frame::Close(reason) => ws::Message::Close(reason),
-        _ => panic!(),
+        Frame::Binary(bin) => Message::Binary(bin),
+        Frame::Continuation(item) => Message::Continuation(item),
+        Frame::Close(reason) => Message::Close(reason),
+        _ => return Err(Error::from(ws::ProtocolError::BadOpCode)),
     };
+
     Ok(msg)
 }
 
 #[actix_rt::test]
 async fn test_simple() {
-    let ws_service = WsService::new();
-    let mut srv = test_server({
-        let ws_service = ws_service.clone();
-        move || {
-            let ws_service = ws_service.clone();
-            HttpService::build()
-                .upgrade(fn_factory(move || future::ok::<_, ()>(ws_service.clone())))
-                .finish(|_| future::ok::<_, ()>(Response::not_found()))
-                .tcp()
-        }
+    let mut srv = test_server(|| {
+        HttpService::build()
+            .upgrade(fn_factory(|| async { Ok::<_, ()>(WsService::new()) }))
+            .finish(|_| async { Ok::<_, ()>(Response::not_found()) })
+            .tcp()
     })
     .await;
 
     // client service
     let mut framed = srv.ws().await.unwrap();
-    framed.send(ws::Message::Text("text".into())).await.unwrap();
-    let (item, mut framed) = framed.into_future().await;
-    assert_eq!(
-        item.unwrap().unwrap(),
-        ws::Frame::Text(Bytes::from_static(b"text"))
-    );
+    framed.send(Message::Text("text".into())).await.unwrap();
+
+    let item = framed.next().await.unwrap().unwrap();
+    assert_eq!(item, Frame::Text(Bytes::from_static(b"text")));
+
+    framed.send(Message::Binary("text".into())).await.unwrap();
+
+    let item = framed.next().await.unwrap().unwrap();
+    assert_eq!(item, Frame::Binary(Bytes::from_static(&b"text"[..])));
+
+    framed.send(Message::Ping("text".into())).await.unwrap();
+    let item = framed.next().await.unwrap().unwrap();
+    assert_eq!(item, Frame::Pong("text".to_string().into()));
 
     framed
-        .send(ws::Message::Binary("text".into()))
+        .send(Message::Continuation(Item::FirstText("text".into())))
         .await
         .unwrap();
-    let (item, mut framed) = framed.into_future().await;
+    let item = framed.next().await.unwrap().unwrap();
     assert_eq!(
-        item.unwrap().unwrap(),
-        ws::Frame::Binary(Bytes::from_static(&b"text"[..]))
-    );
-
-    framed.send(ws::Message::Ping("text".into())).await.unwrap();
-    let (item, mut framed) = framed.into_future().await;
-    assert_eq!(
-        item.unwrap().unwrap(),
-        ws::Frame::Pong("text".to_string().into())
-    );
-
-    framed
-        .send(ws::Message::Continuation(ws::Item::FirstText(
-            "text".into(),
-        )))
-        .await
-        .unwrap();
-    let (item, mut framed) = framed.into_future().await;
-    assert_eq!(
-        item.unwrap().unwrap(),
-        ws::Frame::Continuation(ws::Item::FirstText(Bytes::from_static(b"text")))
+        item,
+        Frame::Continuation(Item::FirstText(Bytes::from_static(b"text")))
     );
 
     assert!(framed
-        .send(ws::Message::Continuation(ws::Item::FirstText(
-            "text".into()
-        )))
+        .send(Message::Continuation(Item::FirstText("text".into())))
         .await
         .is_err());
     assert!(framed
-        .send(ws::Message::Continuation(ws::Item::FirstBinary(
-            "text".into()
-        )))
+        .send(Message::Continuation(Item::FirstBinary("text".into())))
         .await
         .is_err());
 
     framed
-        .send(ws::Message::Continuation(ws::Item::Continue("text".into())))
+        .send(Message::Continuation(Item::Continue("text".into())))
         .await
         .unwrap();
-    let (item, mut framed) = framed.into_future().await;
+    let item = framed.next().await.unwrap().unwrap();
     assert_eq!(
-        item.unwrap().unwrap(),
-        ws::Frame::Continuation(ws::Item::Continue(Bytes::from_static(b"text")))
+        item,
+        Frame::Continuation(Item::Continue(Bytes::from_static(b"text")))
     );
 
     framed
-        .send(ws::Message::Continuation(ws::Item::Last("text".into())))
+        .send(Message::Continuation(Item::Last("text".into())))
         .await
         .unwrap();
-    let (item, mut framed) = framed.into_future().await;
+    let item = framed.next().await.unwrap().unwrap();
     assert_eq!(
-        item.unwrap().unwrap(),
-        ws::Frame::Continuation(ws::Item::Last(Bytes::from_static(b"text")))
+        item,
+        Frame::Continuation(Item::Last(Bytes::from_static(b"text")))
     );
 
     assert!(framed
-        .send(ws::Message::Continuation(ws::Item::Continue("text".into())))
+        .send(Message::Continuation(Item::Continue("text".into())))
         .await
         .is_err());
 
     assert!(framed
-        .send(ws::Message::Continuation(ws::Item::Last("text".into())))
+        .send(Message::Continuation(Item::Last("text".into())))
         .await
         .is_err());
 
     framed
-        .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
+        .send(Message::Close(Some(CloseCode::Normal.into())))
         .await
         .unwrap();
 
-    let (item, _framed) = framed.into_future().await;
-    assert_eq!(
-        item.unwrap().unwrap(),
-        ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
-    );
-
-    assert!(ws_service.was_polled());
+    let item = framed.next().await.unwrap().unwrap();
+    assert_eq!(item, Frame::Close(Some(CloseCode::Normal.into())));
 }