diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index c6bc440c8..a8846ea47 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -31,7 +31,7 @@ where A: Actor> + StreamHandler, T: Stream + 'static, { - let mut res = handshake(req, &vec![])?; + let mut res = handshake(req)?; Ok(res.streaming(WebsocketContext::create(actor, stream))) } @@ -40,7 +40,7 @@ where /// `protocols` is a sequence of known protocols. pub fn start_with_protocols( actor: A, - protocols: &Vec<&str>, + protocols: &[&str], req: &HttpRequest, stream: T, ) -> Result @@ -48,10 +48,18 @@ where A: Actor> + StreamHandler, T: Stream + 'static, { - let mut res = handshake(req, protocols)?; + let mut res = handshake_with_protocols(req, protocols)?; Ok(res.streaming(WebsocketContext::create(actor, stream))) } +/// Prepare `WebSocket` handshake response. +/// +/// This function returns handshake `HttpResponse`, ready to send to peer. +/// It does not perform any IO. +pub fn handshake(req: &HttpRequest) -> Result { + handshake_with_protocols(req, &[]) +} + /// Prepare `WebSocket` handshake response. /// /// This function returns handshake `HttpResponse`, ready to send to peer. @@ -60,9 +68,9 @@ where /// `protocols` is a sequence of known protocols. On successful handshake, /// the returned response headers contain the first protocol in this list /// which the server also knows. -pub fn handshake( +pub fn handshake_with_protocols( req: &HttpRequest, - protocols: &Vec<&str>, + protocols: &[&str], ) -> Result { // WebSocket accepts only GET if *req.method() != Method::GET { @@ -476,13 +484,13 @@ mod tests { .to_http_request(); assert_eq!( HandshakeError::GetMethodRequired, - handshake(&req, &vec![]).err().unwrap() + handshake(&req).err().unwrap() ); let req = TestRequest::default().to_http_request(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - handshake(&req, &vec![]).err().unwrap() + handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -490,7 +498,7 @@ mod tests { .to_http_request(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - handshake(&req, &vec![]).err().unwrap() + handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -501,7 +509,7 @@ mod tests { .to_http_request(); assert_eq!( HandshakeError::NoConnectionUpgrade, - handshake(&req, &vec![]).err().unwrap() + handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -516,7 +524,7 @@ mod tests { .to_http_request(); assert_eq!( HandshakeError::NoVersionHeader, - handshake(&req, &vec![]).err().unwrap() + handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -535,7 +543,7 @@ mod tests { .to_http_request(); assert_eq!( HandshakeError::UnsupportedVersion, - handshake(&req, &vec![]).err().unwrap() + handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -554,7 +562,7 @@ mod tests { .to_http_request(); assert_eq!( HandshakeError::BadWebsocketKey, - handshake(&req, &vec![]).err().unwrap() + handshake(&req).err().unwrap() ); let req = TestRequest::default() @@ -578,7 +586,7 @@ mod tests { assert_eq!( StatusCode::SWITCHING_PROTOCOLS, - handshake(&req, &vec![]).unwrap().finish().status() + handshake(&req).unwrap().finish().status() ); let req = TestRequest::default() @@ -604,15 +612,22 @@ mod tests { ) .to_http_request(); - let protocols = vec!["graphql"]; + let protocols = ["graphql"]; assert_eq!( StatusCode::SWITCHING_PROTOCOLS, - handshake(&req, &protocols).unwrap().finish().status() + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .status() ); assert_eq!( Some(&header::HeaderValue::from_static("graphql")), - handshake(&req, &protocols).unwrap().finish().headers().get(&header::SEC_WEBSOCKET_PROTOCOL) + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) ); let req = TestRequest::default() @@ -642,11 +657,18 @@ mod tests { assert_eq!( StatusCode::SWITCHING_PROTOCOLS, - handshake(&req, &protocols).unwrap().finish().status() + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .status() ); assert_eq!( Some(&header::HeaderValue::from_static("p2")), - handshake(&req, &protocols).unwrap().finish().headers().get(&header::SEC_WEBSOCKET_PROTOCOL) + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) ); } }