add support for specifying protocols on websocket handshake

This commit is contained in:
jairinhohw 2019-05-09 23:44:13 -03:00
parent a17ff492a1
commit e97a1b748b
1 changed files with 120 additions and 15 deletions

View File

@ -31,7 +31,24 @@ where
A: Actor<Context = WebsocketContext<A>> + StreamHandler<Message, ProtocolError>, A: Actor<Context = WebsocketContext<A>> + StreamHandler<Message, ProtocolError>,
T: Stream<Item = Bytes, Error = PayloadError> + 'static, T: Stream<Item = Bytes, Error = PayloadError> + 'static,
{ {
let mut res = handshake(req)?; let mut res = handshake(req, &vec![])?;
Ok(res.streaming(WebsocketContext::create(actor, stream)))
}
/// Do websocket handshake and start ws actor.
///
/// `protocols` is a sequence of known protocols.
pub fn start_with_protocols<A, T>(
actor: A,
protocols: &Vec<&str>,
req: &HttpRequest,
stream: T,
) -> Result<HttpResponse, Error>
where
A: Actor<Context = WebsocketContext<A>> + StreamHandler<Message, ProtocolError>,
T: Stream<Item = Bytes, Error = PayloadError> + 'static,
{
let mut res = handshake(req, protocols)?;
Ok(res.streaming(WebsocketContext::create(actor, stream))) Ok(res.streaming(WebsocketContext::create(actor, stream)))
} }
@ -40,10 +57,13 @@ where
/// This function returns handshake `HttpResponse`, ready to send to peer. /// This function returns handshake `HttpResponse`, ready to send to peer.
/// It does not perform any IO. /// It does not perform any IO.
/// ///
// /// `protocols` is a sequence of known protocols. On successful handshake, /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list /// the returned response headers contain the first protocol in this list
// /// which the server also knows. /// which the server also knows.
pub fn handshake(req: &HttpRequest) -> Result<HttpResponseBuilder, HandshakeError> { pub fn handshake(
req: &HttpRequest,
protocols: &Vec<&str>,
) -> Result<HttpResponseBuilder, HandshakeError> {
// WebSocket accepts only GET // WebSocket accepts only GET
if *req.method() != Method::GET { if *req.method() != Method::GET {
return Err(HandshakeError::GetMethodRequired); return Err(HandshakeError::GetMethodRequired);
@ -92,11 +112,28 @@ pub fn handshake(req: &HttpRequest) -> Result<HttpResponseBuilder, HandshakeErro
hash_key(key.as_ref()) hash_key(key.as_ref())
}; };
Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) // check requested protocols
let protocol =
req.headers()
.get(&header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
req_protocols
.split(", ")
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS)
.upgrade("websocket") .upgrade("websocket")
.header(header::TRANSFER_ENCODING, "chunked") .header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
.take()) .take();
if let Some(protocol) = protocol {
response.header(&header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
Ok(response)
} }
/// Execution context for `WebSockets` actors /// Execution context for `WebSockets` actors
@ -439,13 +476,13 @@ mod tests {
.to_http_request(); .to_http_request();
assert_eq!( assert_eq!(
HandshakeError::GetMethodRequired, HandshakeError::GetMethodRequired,
handshake(&req).err().unwrap() handshake(&req, &vec![]).err().unwrap()
); );
let req = TestRequest::default().to_http_request(); let req = TestRequest::default().to_http_request();
assert_eq!( assert_eq!(
HandshakeError::NoWebsocketUpgrade, HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap() handshake(&req, &vec![]).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -453,7 +490,7 @@ mod tests {
.to_http_request(); .to_http_request();
assert_eq!( assert_eq!(
HandshakeError::NoWebsocketUpgrade, HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap() handshake(&req, &vec![]).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -464,7 +501,7 @@ mod tests {
.to_http_request(); .to_http_request();
assert_eq!( assert_eq!(
HandshakeError::NoConnectionUpgrade, HandshakeError::NoConnectionUpgrade,
handshake(&req).err().unwrap() handshake(&req, &vec![]).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -479,7 +516,7 @@ mod tests {
.to_http_request(); .to_http_request();
assert_eq!( assert_eq!(
HandshakeError::NoVersionHeader, HandshakeError::NoVersionHeader,
handshake(&req).err().unwrap() handshake(&req, &vec![]).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -498,7 +535,7 @@ mod tests {
.to_http_request(); .to_http_request();
assert_eq!( assert_eq!(
HandshakeError::UnsupportedVersion, HandshakeError::UnsupportedVersion,
handshake(&req).err().unwrap() handshake(&req, &vec![]).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -517,7 +554,7 @@ mod tests {
.to_http_request(); .to_http_request();
assert_eq!( assert_eq!(
HandshakeError::BadWebsocketKey, HandshakeError::BadWebsocketKey,
handshake(&req).err().unwrap() handshake(&req, &vec![]).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -541,7 +578,75 @@ mod tests {
assert_eq!( assert_eq!(
StatusCode::SWITCHING_PROTOCOLS, StatusCode::SWITCHING_PROTOCOLS,
handshake(&req).unwrap().finish().status() handshake(&req, &vec![]).unwrap().finish().status()
);
let req = TestRequest::default()
.header(
header::UPGRADE,
header::HeaderValue::from_static("websocket"),
)
.header(
header::CONNECTION,
header::HeaderValue::from_static("upgrade"),
)
.header(
header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("13"),
)
.header(
header::SEC_WEBSOCKET_KEY,
header::HeaderValue::from_static("13"),
)
.header(
header::SEC_WEBSOCKET_PROTOCOL,
header::HeaderValue::from_static("graphql"),
)
.to_http_request();
let protocols = vec!["graphql"];
assert_eq!(
StatusCode::SWITCHING_PROTOCOLS,
handshake(&req, &protocols).unwrap().finish().status()
);
assert_eq!(
Some(&header::HeaderValue::from_static("graphql")),
handshake(&req, &protocols).unwrap().finish().headers().get(&header::SEC_WEBSOCKET_PROTOCOL)
);
let req = TestRequest::default()
.header(
header::UPGRADE,
header::HeaderValue::from_static("websocket"),
)
.header(
header::CONNECTION,
header::HeaderValue::from_static("upgrade"),
)
.header(
header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("13"),
)
.header(
header::SEC_WEBSOCKET_KEY,
header::HeaderValue::from_static("13"),
)
.header(
header::SEC_WEBSOCKET_PROTOCOL,
header::HeaderValue::from_static("p1, p2, p3"),
)
.to_http_request();
let protocols = vec!["p3", "p2"];
assert_eq!(
StatusCode::SWITCHING_PROTOCOLS,
handshake(&req, &protocols).unwrap().finish().status()
);
assert_eq!(
Some(&header::HeaderValue::from_static("p2")),
handshake(&req, &protocols).unwrap().finish().headers().get(&header::SEC_WEBSOCKET_PROTOCOL)
); );
} }
} }