add connection upgrade test

This commit is contained in:
Rob Ede 2020-12-22 23:25:30 +00:00
parent 4e988b9d2f
commit 52cbcf5245
No known key found for this signature in database
GPG Key ID: C2A3B36E841A91E6
2 changed files with 49 additions and 12 deletions

View File

@ -1,8 +1,11 @@
use std::collections::VecDeque; use std::{
use std::future::Future; collections::VecDeque,
use std::pin::Pin; fmt,
use std::task::{Context, Poll}; future::Future,
use std::{fmt, io, net}; io, mem, net,
pin::Pin,
task::{Context, Poll},
};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
use actix_rt::time::{delay_until, Delay, Instant}; use actix_rt::time::{delay_until, Delay, Instant};
@ -800,10 +803,10 @@ where
let inner_p = inner.as_mut().project(); let inner_p = inner.as_mut().project();
let mut parts = FramedParts::with_read_buf( let mut parts = FramedParts::with_read_buf(
inner_p.io.take().unwrap(), inner_p.io.take().unwrap(),
std::mem::take(inner_p.codec), mem::take(inner_p.codec),
std::mem::take(inner_p.read_buf), mem::take(inner_p.read_buf),
); );
parts.write_buf = std::mem::take(inner_p.write_buf); parts.write_buf = mem::take(inner_p.write_buf);
let framed = Framed::from_parts(parts); let framed = Framed::from_parts(parts);
let upgrade = let upgrade =
inner_p.upgrade.take().unwrap().call((req, framed)); inner_p.upgrade.take().unwrap().call((req, framed));
@ -937,7 +940,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::str; use std::{marker::PhantomData, str};
use actix_service::fn_service; use actix_service::fn_service;
use futures_util::future::{lazy, ready}; use futures_util::future::{lazy, ready};
@ -1263,4 +1266,38 @@ mod tests {
}) })
.await; .await;
} }
#[actix_rt::test]
async fn test_upgrade() {
lazy(|cx| {
let mut buf = TestSeqBuffer::empty();
let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None);
let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new(
buf.clone(),
cfg,
CloneableService::new(ok_service()),
CloneableService::new(ExpectHandler),
Some(CloneableService::new(UpgradeHandler(PhantomData))),
None,
Extensions::new(),
None,
);
buf.extend_read_buf(
"\
GET /ws HTTP/1.1\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n\
",
);
assert!(Pin::new(&mut h1).poll(cx).is_ready());
assert!(matches!(&h1.inner, DispatcherState::Upgrade(_)));
// polls: manual shutdown
assert_eq!(h1.poll_count, 2);
})
.await;
}
} }

View File

@ -3,13 +3,13 @@ use std::task::{Context, Poll};
use actix_codec::Framed; use actix_codec::Framed;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::Ready; use futures_util::future::{ready, Ready};
use crate::error::Error; use crate::error::Error;
use crate::h1::Codec; use crate::h1::Codec;
use crate::request::Request; use crate::request::Request;
pub struct UpgradeHandler<T>(PhantomData<T>); pub struct UpgradeHandler<T>(pub(crate) PhantomData<T>);
impl<T> ServiceFactory for UpgradeHandler<T> { impl<T> ServiceFactory for UpgradeHandler<T> {
type Config = (); type Config = ();
@ -36,6 +36,6 @@ impl<T> Service for UpgradeHandler<T> {
} }
fn call(&mut self, _: Self::Request) -> Self::Future { fn call(&mut self, _: Self::Request) -> Self::Future {
unimplemented!() ready(Ok(()))
} }
} }