diff --git a/Cargo.toml b/Cargo.toml index fac484459..6d04e1da8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -134,7 +134,8 @@ syn = "1.0" quote = "1.0.3" num-bigint = "0.4" ratchet = { package = "ratchet_rs", version = "1.0" } -ratchet_fixture = "1.0" +ratchet_core = { version = "1.0" } +ratchet_fixture = { version = "1.0" } flate2 = "1.0.22" bitflags = "2.5" rocksdb = "0.22" diff --git a/example_apps/console/src/runtime/dummy_server.rs b/example_apps/console/src/runtime/dummy_server.rs index e9e77fd68..62a4453cd 100644 --- a/example_apps/console/src/runtime/dummy_server.rs +++ b/example_apps/console/src/runtime/dummy_server.rs @@ -27,7 +27,7 @@ use futures::{ }; use parking_lot::RwLock; use ratchet::{ - CloseCode, CloseReason, NoExtDecoder, NoExtEncoder, NoExtProvider, ProtocolRegistry, + CloseCode, CloseReason, NoExtDecoder, NoExtEncoder, NoExtProvider, SubprotocolRegistry, WebSocketConfig, }; use swimos_agent_protocol::MapMessage; @@ -216,7 +216,7 @@ impl DummyServer { }; match event { Event::NewConnection(stream) => { - let subprotocols = ProtocolRegistry::new(vec!["warp0"]).unwrap(); + let subprotocols = SubprotocolRegistry::new(vec!["warp0"]).unwrap(); let upgrader = ratchet::accept_with( stream, WebSocketConfig::default(), diff --git a/example_apps/console/src/runtime/mod.rs b/example_apps/console/src/runtime/mod.rs index 2b1dada79..86430956b 100644 --- a/example_apps/console/src/runtime/mod.rs +++ b/example_apps/console/src/runtime/mod.rs @@ -25,7 +25,7 @@ use futures::{ use parking_lot::RwLock; use ratchet::{ CloseCode, CloseReason, ErrorKind, Message, NoExt, NoExtDecoder, NoExtEncoder, NoExtProvider, - ProtocolRegistry, WebSocket, WebSocketConfig, + SubprotocolRegistry, WebSocket, WebSocketConfig, }; use swimos_messages::warp::{peel_envelope_header_str, RawEnvelope}; use swimos_model::Value; @@ -516,7 +516,7 @@ fn into_stream(remote: Host, rx: Rx) -> impl Stream Result, ratchet::Error> { let socket = TcpStream::connect(&host.host_only()).await?; - let subprotocols = ProtocolRegistry::new(vec!["warp0"]).unwrap(); + let subprotocols = SubprotocolRegistry::new(vec!["warp0"]).unwrap(); let r = ratchet::subscribe_with( WebSocketConfig::default(), socket, diff --git a/runtime/swimos_http/Cargo.toml b/runtime/swimos_http/Cargo.toml index 2fe72ed31..6ebea5f96 100644 --- a/runtime/swimos_http/Cargo.toml +++ b/runtime/swimos_http/Cargo.toml @@ -11,6 +11,7 @@ homepage.workspace = true [dependencies] futures = { workspace = true } ratchet = { workspace = true } +ratchet_core = { workspace = true } hyper = { workspace = true } http = { workspace = true } httparse = { workspace = true } diff --git a/runtime/swimos_http/src/lib.rs b/runtime/swimos_http/src/lib.rs index f5737c53c..1f5cf4efc 100644 --- a/runtime/swimos_http/src/lib.rs +++ b/runtime/swimos_http/src/lib.rs @@ -20,6 +20,5 @@ mod websocket; pub use websocket::{ - fail_upgrade, negotiate_upgrade, upgrade, Negotiated, NoUnwrap, SockUnwrap, UpgradeError, - UpgradeFuture, + fail_upgrade, negotiate_upgrade, upgrade, NoUnwrap, SockUnwrap, UpgradeFuture, UpgradeStatus, }; diff --git a/runtime/swimos_http/src/websocket.rs b/runtime/swimos_http/src/websocket.rs index f7052e6b2..d6e6a0919 100644 --- a/runtime/swimos_http/src/websocket.rs +++ b/runtime/swimos_http/src/websocket.rs @@ -13,40 +13,35 @@ // limitations under the License. use std::{ - collections::HashSet, pin::Pin, task::{Context, Poll}, }; -use base64::{engine::general_purpose::STANDARD, Engine}; use bytes::{Bytes, BytesMut}; use futures::{ready, Future, FutureExt}; use http::{header::HeaderName, HeaderMap, HeaderValue, Method}; use http_body_util::Full; -use httparse::Header; -use hyper::body::Incoming; use hyper::{ upgrade::{OnUpgrade, Upgraded}, Request, Response, }; use hyper_util::rt::TokioIo; use ratchet::{ - Extension, ExtensionProvider, NegotiatedExtension, Role, WebSocket, WebSocketConfig, + Extension, ExtensionProvider, Role, SubprotocolRegistry, WebSocket, WebSocketConfig, }; -use sha1::{Digest, Sha1}; -use thiserror::Error; +use ratchet_core::server::{build_response, parse_request, UpgradeRequest}; const UPGRADE_STR: &str = "Upgrade"; const WEBSOCKET_STR: &str = "websocket"; -const WEBSOCKET_VERSION_STR: &str = "13"; -const ACCEPT_KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; const FAILED_RESPONSE: &str = "Building response should not fail."; -/// Result of a successful websocket negotiation. -pub struct Negotiated<'a, Ext> { - pub protocol: Option<&'a str>, - pub extension: Option<(Ext, HeaderValue)>, - pub key: Bytes, +pub enum UpgradeStatus { + Upgradeable { + result: Result, ratchet::Error>, + }, + NotRequested { + request: Request, + }, } /// Attempt to negotiate a websocket upgrade on a hyper request. If [`Ok(None)`] is returned, @@ -57,10 +52,10 @@ pub struct Negotiated<'a, Ext> { /// * `protocols` - The supported protocols for the negotiation. /// * `extension_provider` - The extension provider (for example compression support). pub fn negotiate_upgrade<'a, T, E>( - request: &Request, - protocols: &'a HashSet<&str>, + request: Request, + registry: &SubprotocolRegistry, extension_provider: &E, -) -> Result>, UpgradeError> +) -> UpgradeStatus where E: ExtensionProvider, { @@ -69,48 +64,16 @@ where let has_upgrade = headers_contains(headers, http::header::UPGRADE, WEBSOCKET_STR); if request.method() == Method::GET && has_conn && has_upgrade { - if !headers_contains( - headers, - http::header::SEC_WEBSOCKET_VERSION, - WEBSOCKET_VERSION_STR, - ) { - return Err(UpgradeError::InvalidWebsocketVersion); + UpgradeStatus::Upgradeable { + result: parse_request(request, extension_provider, registry), } - - let key = if let Some(key) = headers - .get(http::header::SEC_WEBSOCKET_KEY) - .map(|v| Bytes::from(trim(v.as_bytes()).to_vec())) - { - key - } else { - return Err(UpgradeError::NoKey); - }; - - let protocol = headers - .get_all(http::header::SEC_WEBSOCKET_PROTOCOL) - .iter() - .flat_map(|h| h.as_bytes().split(|c| *c == b' ' || *c == b',')) - .map(trim) - .filter_map(|b| std::str::from_utf8(b).ok()) - .find_map(|p| protocols.get(p).copied()); - - let ext_headers = extension_headers(headers); - - let extension = extension_provider.negotiate_server(&ext_headers)?; - Ok(Some(Negotiated { - protocol, - extension, - key, - })) } else { - Ok(None) + UpgradeStatus::NotRequested { request } } } /// Produce a bad request response for a bad websocket upgrade request. -pub fn fail_upgrade( - error: UpgradeError, -) -> Response> { +pub fn fail_upgrade(error: ratchet::Error) -> Response> { Response::builder() .status(http::StatusCode::BAD_REQUEST) .body(Full::from(error.to_string())) @@ -120,65 +83,36 @@ pub fn fail_upgrade( /// Upgrade a hyper request to a websocket, based on a successful negotiation. /// /// # Arguments -/// * `request` - The hyper HTTP request. -/// * `negotiated` - Negotiated parameters for the websocket connection. +/// * `request` - The upgrade request request. /// * `config` - Websocket configuration parameters. /// * `unwrap_fn` - Used to unwrap the underlying socket type from the opaque [`Upgraded`] socket /// provided by hyper. -pub fn upgrade( - request: Request, - negotiated: Negotiated<'_, Ext>, +pub fn upgrade( + request: UpgradeRequest, config: Option, unwrap_fn: U, -) -> (Response>, UpgradeFuture) +) -> Result<(Response>, UpgradeFuture), ratchet::Error> where Ext: Extension + Send, { - let Negotiated { - protocol, - extension, + let UpgradeRequest { key, - } = negotiated; - let mut digest = Sha1::new(); - Digest::update(&mut digest, key); - Digest::update(&mut digest, ACCEPT_KEY); - - let sec_websocket_accept = STANDARD.encode(digest.finalize()); - let mut builder = Response::builder() - .status(http::StatusCode::SWITCHING_PROTOCOLS) - .header(http::header::SEC_WEBSOCKET_ACCEPT, sec_websocket_accept) - .header(http::header::CONNECTION, UPGRADE_STR) - .header(http::header::UPGRADE, WEBSOCKET_STR); - - if let Some(protocol) = protocol { - builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol); - } - let ext = match extension { - Some((ext, header)) => { - builder = builder.header(http::header::SEC_WEBSOCKET_EXTENSIONS, header); - Some(ext) - } - None => None, - }; + subprotocol, + extension, + extension_header, + request, + .. + } = request; + let response = build_response(key, subprotocol, extension_header)?; + let (parts, _body) = response.into_parts(); let fut = UpgradeFuture { upgrade: hyper::upgrade::on(request), config: config.unwrap_or_default(), - extension: ext, + extension, unwrap_fn, }; - let response = builder.body(Full::default()).expect(FAILED_RESPONSE); - (response, fut) -} - -fn extension_headers(headers: &HeaderMap) -> Vec> { - headers - .iter() - .map(|(name, value)| Header { - name: name.as_str(), - value: value.as_bytes(), - }) - .collect() + Ok((Response::from_parts(parts, Full::default()), fut)) } fn headers_contains(headers: &HeaderMap, name: HeaderName, value: &str) -> bool { @@ -186,44 +120,15 @@ fn headers_contains(headers: &HeaderMap, name: HeaderName, value: &str) -> bool } fn header_contains(content: &str) -> impl Fn(&HeaderValue) -> bool + '_ { - |header| { + move |header| { header .as_bytes() - .split(|c| *c == b' ' || *c == b',') - .map(trim) - .any(|s| s.eq_ignore_ascii_case(content.as_bytes())) + .split(|&c| c == b' ' || c == b',') + .map(|s| std::str::from_utf8(s).unwrap_or("").trim()) + .any(|s| s.eq_ignore_ascii_case(content)) } } -fn trim(bytes: &[u8]) -> &[u8] { - let not_ws = |b: &u8| !b.is_ascii_whitespace(); - let start = bytes.iter().position(not_ws); - let end = bytes.iter().rposition(not_ws); - match (start, end) { - (Some(s), Some(e)) => &bytes[s..e + 1], - _ => &[], - } -} - -/// Reasons that a websocket upgrade request could fail. -#[derive(Debug, Error, Clone, Copy)] -pub enum UpgradeError { - /// An invalid websocket version was specified. - #[error("Invalid websocket version specified.")] - InvalidWebsocketVersion, - /// No websocket key was provided. - #[error("No websocket key provided.")] - NoKey, - /// The headers provided for the websocket extension were not valid. - #[error("Invalid extension headers: {0}")] - ExtensionError(ExtErr), -} - -impl From for UpgradeError { - fn from(err: ExtErr) -> Self { - UpgradeError::ExtensionError(err) - } -} /// Trait for unwrapping the concrete type of an upgraded socket. /// Upon a connection upgrade, hyper returns the upgraded socket indirected through a trait object. /// The caller will generally know the real underlying type and this allows for that type to be @@ -275,7 +180,7 @@ where Poll::Ready(Ok(WebSocket::from_upgraded( std::mem::take(config), upgraded, - NegotiatedExtension::from(extension.take()), + extension.take(), prefix, Role::Server, ))) diff --git a/runtime/swimos_http/tests/wsserver.rs b/runtime/swimos_http/tests/wsserver.rs index a7420f912..f161ea72a 100644 --- a/runtime/swimos_http/tests/wsserver.rs +++ b/runtime/swimos_http/tests/wsserver.rs @@ -22,7 +22,7 @@ use std::{ sync::Arc, time::Duration, }; -use swimos_http::NoUnwrap; +use swimos_http::{NoUnwrap, UpgradeStatus}; use futures::{ channel::oneshot, @@ -34,7 +34,10 @@ use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; use hyper_util::server::conn::auto::Builder; use hyper_util::server::graceful::GracefulShutdown; -use ratchet::{CloseCode, CloseReason, Message, NoExt, NoExtProvider, PayloadType, WebSocket}; +use ratchet::{ + CloseCode, CloseReason, Message, NoExt, NoExtProvider, PayloadType, SubprotocolRegistry, + WebSocket, +}; use thiserror::Error; use tokio::net::TcpListener; use tokio::{net::TcpSocket, sync::Notify}; @@ -49,8 +52,15 @@ async fn run_server( let (io, _) = listener.accept().await?; let builder = Builder::new(TokioExecutor::new()); - let connection = - builder.serve_connection_with_upgrades(TokioIo::new(io), service_fn(upgrade_server)); + + let connection = builder.serve_connection_with_upgrades( + TokioIo::new(io), + service_fn(move |req| { + let registry = + SubprotocolRegistry::new(["warp0"]).expect("Failed to build subprotocol registry"); + async move { upgrade_server(req, ®istry).await } + }), + ); let shutdown = GracefulShutdown::new(); let server = pin!(shutdown.watch(connection)); @@ -70,16 +80,24 @@ async fn run_server( async fn upgrade_server( request: Request, -) -> Result>, hyper::http::Error> { - let protocols = ["warp0"].into_iter().collect(); - match swimos_http::negotiate_upgrade(&request, &protocols, &NoExtProvider) { - Ok(Some(negotiated)) => { - let (response, upgraded) = swimos_http::upgrade(request, negotiated, None, NoUnwrap); + registry: &SubprotocolRegistry, +) -> Result>, ratchet::Error> { + match swimos_http::negotiate_upgrade(request, registry, &NoExtProvider) { + UpgradeStatus::Upgradeable { result: Ok(result) } => { + let (response, upgraded) = swimos_http::upgrade(result, None, NoUnwrap)?; tokio::spawn(run_websocket(upgraded)); Ok(response) } - Ok(None) => Response::builder().body(Full::from("Success")), - Err(err) => Ok(swimos_http::fail_upgrade(err)), + UpgradeStatus::Upgradeable { result: Err(err) } => { + if err.is_io() { + Err(err) + } else { + Ok(swimos_http::fail_upgrade(err)) + } + } + UpgradeStatus::NotRequested { request: _ } => Response::builder() + .body(Full::from("Success")) + .map_err(Into::into), } } diff --git a/runtime/swimos_remote/src/task/tests.rs b/runtime/swimos_remote/src/task/tests.rs index 7dc35aaef..7da57472f 100644 --- a/runtime/swimos_remote/src/task/tests.rs +++ b/runtime/swimos_remote/src/task/tests.rs @@ -20,8 +20,8 @@ use futures::{ Future, SinkExt, StreamExt, }; use ratchet::{ - CloseCode, CloseReason, Message, NegotiatedExtension, NoExt, NoExtDecoder, Receiver, Role, - WebSocket, WebSocketConfig, + CloseCode, CloseReason, Message, NoExt, NoExtDecoder, Receiver, Role, WebSocket, + WebSocketConfig, }; use swimos_api::address::RelativeAddress; use swimos_messages::{ @@ -607,20 +607,8 @@ fn make_fake_ws() -> ( let (server, client) = duplex(BUFFER_SIZE.get()); let config = WebSocketConfig::default(); - let server = WebSocket::from_upgraded( - config, - server, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Server, - ); - let client = WebSocket::from_upgraded( - config, - client, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Client, - ); + let server = WebSocket::from_upgraded(config, server, None, BytesMut::new(), Role::Server); + let client = WebSocket::from_upgraded(config, client, None, BytesMut::new(), Role::Client); (server, client) } @@ -766,20 +754,10 @@ where let (server, client) = duplex(BUFFER_SIZE.get()); let config = WebSocketConfig::default(); - let server = WebSocket::from_upgraded( - config, - server, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Server, - ); - let client = WebSocket::from_upgraded( - config, - client, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Client, - ); + let server = + WebSocket::from_upgraded(config, server, Some(NoExt), BytesMut::new(), Role::Server); + let client = + WebSocket::from_upgraded(config, client, Some(NoExt), BytesMut::new(), Role::Client); let (mut server_tx, server_rx) = server.split().expect("Split failed."); @@ -1010,20 +988,10 @@ where let (server, client) = duplex(BUFFER_SIZE.get()); let config = WebSocketConfig::default(); - let server = WebSocket::from_upgraded( - config, - server, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Server, - ); - let client = WebSocket::from_upgraded( - config, - client, - NegotiatedExtension::from(NoExt), - BytesMut::new(), - Role::Client, - ); + let server = + WebSocket::from_upgraded(config, server, Some(NoExt), BytesMut::new(), Role::Server); + let client = + WebSocket::from_upgraded(config, client, Some(NoExt), BytesMut::new(), Role::Client); let context = CombinedTestContext { stop_tx: Some(stop_tx), diff --git a/runtime/swimos_remote/src/ws/mod.rs b/runtime/swimos_remote/src/ws/mod.rs index b872e257d..1eb77bba0 100644 --- a/runtime/swimos_remote/src/ws/mod.rs +++ b/runtime/swimos_remote/src/ws/mod.rs @@ -20,7 +20,9 @@ use futures::Stream; use swimos_messages::remote_protocol::FindNode; use swimos_utilities::errors::Recoverable; -use ratchet::{ExtensionProvider, ProtocolRegistry, WebSocket, WebSocketConfig, WebSocketStream}; +use ratchet::{ + ExtensionProvider, SubprotocolRegistry, WebSocket, WebSocketConfig, WebSocketStream, +}; use thiserror::Error; use tokio::sync::mpsc; @@ -111,7 +113,7 @@ impl WebsocketClient for RatchetClient { { let config = self.0; Box::pin(async move { - let subprotocols = ProtocolRegistry::new([WARP])?; + let subprotocols = SubprotocolRegistry::new([WARP])?; let socket = ratchet::subscribe_with(config, socket, addr, provider, subprotocols) .await? .into_websocket(); diff --git a/server/swimos_server_app/Cargo.toml b/server/swimos_server_app/Cargo.toml index 51aabd2fc..05a812816 100644 --- a/server/swimos_server_app/Cargo.toml +++ b/server/swimos_server_app/Cargo.toml @@ -19,6 +19,7 @@ aws_lc_rs_provider = ["swimos_remote/aws_lc_rs_provider"] [dependencies] futures = { workspace = true } ratchet = { workspace = true, features = ["deflate", "split"] } +ratchet_core = { workspace = true } swimos_utilities = { workspace = true, features = ["io", "trigger", "text", "time"] } swimos_runtime = { workspace = true } swimos_messages = { workspace = true } diff --git a/server/swimos_server_app/src/server/http/mod.rs b/server/swimos_server_app/src/server/http/mod.rs index 0384cfd65..ec61a6df5 100644 --- a/server/swimos_server_app/src/server/http/mod.rs +++ b/server/swimos_server_app/src/server/http/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use self::resolver::Resolver; +use crate::config::HttpConfig; use bytes::{Bytes, BytesMut}; use futures::{ future::BoxFuture, @@ -32,22 +34,23 @@ use hyper::{ use hyper_util::rt::TokioIo; use pin_project::pin_project; use ratchet::{ - Extension, ExtensionProvider, ProtocolRegistry, WebSocket, WebSocketConfig, WebSocketStream, + Error, Extension, ExtensionProvider, SubprotocolRegistry, WebSocket, WebSocketConfig, + WebSocketStream, }; +use ratchet_core::server::UpgradeRequest; use std::{ - collections::HashSet, marker::PhantomData, net::SocketAddr, pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, - Arc, OnceLock, + Arc, }, task::{Context, Poll}, time::{Duration, Instant}, }; use swimos_api::{agent::HttpLaneRequest, http::HttpRequest}; -use swimos_http::{Negotiated, SockUnwrap, UpgradeError, UpgradeFuture}; +use swimos_http::{SockUnwrap, UpgradeFuture, UpgradeStatus}; use swimos_messages::remote_protocol::{AgentResolutionError, FindNode, NoSuchAgent}; use swimos_remote::{ websocket::{RatchetError, WebsocketClient, WebsocketServer, WsOpenFuture, WARP}, @@ -59,10 +62,6 @@ use tokio::{ time::{sleep, Sleep}, }; -use crate::config::HttpConfig; - -use self::resolver::Resolver; - mod resolver; #[cfg(test)] mod tests; @@ -373,31 +372,24 @@ where type BytesHyperResult = Result>, hyper::Error>; /// Perform the websocket negotiation and assign the upgrade future to the target parameter. -fn perform_upgrade( - request: Request, +fn perform_upgrade( config: WebSocketConfig, - result: Result, UpgradeError>, + result: Result, Error>, scheme: Scheme, addr: SocketAddr, ) -> (BytesHyperResult, Option>) where Sock: Send + 'static, Ext: Extension + Send, - Err: std::error::Error + Send, { + let result = result.and_then(|result| { + swimos_http::upgrade(result, Some(config), ReclaimSock::::default()) + }); match result { - Ok(negotiated) => { - let (response, upgrade_fut) = swimos_http::upgrade( - request, - negotiated, - Some(config), - ReclaimSock::::default(), - ); - ( - Ok(response), - Some(UpgradeFutureWithSock::new(upgrade_fut, scheme, addr)), - ) - } + Ok((response, upgrade_fut)) => ( + Ok(response), + Some(UpgradeFutureWithSock::new(upgrade_fut, scheme, addr)), + ), Err(err) => (Ok(swimos_http::fail_upgrade(err)), None), } } @@ -410,6 +402,7 @@ struct Upgrader { config: WebSocketConfig, request_timeout: Duration, upgrade_tx: mpsc::Sender>, + subprotocol_registry: SubprotocolRegistry, } impl Upgrader @@ -430,6 +423,8 @@ where config, request_timeout, upgrade_tx, + subprotocol_registry: SubprotocolRegistry::new(["warp0"]) + .expect("Failed to build Ratchet Subprotocol Registry"), } } @@ -444,6 +439,7 @@ where config, request_timeout, upgrade_tx, + subprotocol_registry, } = self; UpgradeService::new( extension_provider.clone(), @@ -453,6 +449,7 @@ where addr, *request_timeout, upgrade_tx.clone(), + subprotocol_registry.clone(), ) } } @@ -468,6 +465,7 @@ struct UpgradeService { addr: SocketAddr, request_timeout: Duration, did_upgrade: AtomicBool, + subprotocol_registry: SubprotocolRegistry, } impl UpgradeService @@ -482,6 +480,7 @@ where addr: SocketAddr, request_timeout: Duration, upgrade_tx: mpsc::Sender>, + subprotocol_registry: SubprotocolRegistry, ) -> Self { UpgradeService { extension_provider, @@ -492,20 +491,11 @@ where addr, request_timeout, did_upgrade: AtomicBool::new(false), + subprotocol_registry, } } } -static PROTOCOLS: OnceLock> = OnceLock::new(); - -fn warp_protocol() -> &'static HashSet<&'static str> { - PROTOCOLS.get_or_init(|| { - let mut s = HashSet::new(); - s.insert(WARP); - s - }) -} - impl<'a, Ext, Sock> Service> for &'a UpgradeService where Sock: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -528,31 +518,36 @@ where resolver, request_timeout, did_upgrade, + subprotocol_registry, } = *self; - let result = - swimos_http::negotiate_upgrade(&request, warp_protocol(), extension_provider.as_ref()) - .transpose(); + // If the request in a websocket upgrade, perform the upgrade, otherwise attempt to delegate // the request to an HTTP lane on an agent. - if let Some(result) = result { - let (upgrade_result, maybe_fut) = - perform_upgrade(request, *config, result, *scheme, *addr); - did_upgrade.store(true, Ordering::Release); - if let Some(upgrade_fut) = maybe_fut { - let tx = upgrade_tx.clone(); - async move { - // This can only fail if the server is no longer running, in which case it is irrelevant. - let _ = tx.send(upgrade_fut).await; - upgrade_result + match swimos_http::negotiate_upgrade( + request, + subprotocol_registry, + extension_provider.as_ref(), + ) { + UpgradeStatus::Upgradeable { result } => { + let (upgrade_result, maybe_fut) = perform_upgrade(*config, result, *scheme, *addr); + did_upgrade.store(true, Ordering::Release); + if let Some(upgrade_fut) = maybe_fut { + let tx = upgrade_tx.clone(); + async move { + // This can only fail if the server is no longer running, in which case it is irrelevant. + let _ = tx.send(upgrade_fut).await; + upgrade_result + } + .boxed() + } else { + async move { upgrade_result }.boxed() } - .boxed() - } else { - async move { upgrade_result }.boxed() } - } else { - serve_request(request, *request_timeout, resolver.clone()) - .map(Ok) - .boxed() + UpgradeStatus::NotRequested { request } => { + serve_request(request, *request_timeout, resolver.clone()) + .map(Ok) + .boxed() + } } } } @@ -674,7 +669,7 @@ impl WebsocketClient for HyperWebsockets { let config = *config; Box::pin(async move { - let subprotocols = ProtocolRegistry::new([WARP])?; + let subprotocols = SubprotocolRegistry::new([WARP])?; let socket = ratchet::subscribe_with(config.websockets, socket, addr, provider, subprotocols) .await? diff --git a/server/swimos_server_app/src/server/runtime/tests/connections.rs b/server/swimos_server_app/src/server/runtime/tests/connections.rs index 2aebbe1e0..84eddc514 100644 --- a/server/swimos_server_app/src/server/runtime/tests/connections.rs +++ b/server/swimos_server_app/src/server/runtime/tests/connections.rs @@ -21,9 +21,7 @@ use bytes::BytesMut; use futures::future::ready; use futures::stream::BoxStream; use futures::{future::BoxFuture, FutureExt, Stream, StreamExt}; -use ratchet::{ - ExtensionProvider, NegotiatedExtension, Role, WebSocket, WebSocketConfig, WebSocketStream, -}; +use ratchet::{ExtensionProvider, Role, WebSocket, WebSocketConfig, WebSocketStream}; use swimos_messages::remote_protocol::FindNode; use swimos_remote::dns::{DnsFut, DnsResolver}; use swimos_remote::websocket::{RatchetError, WebsocketClient, WebsocketServer, WsOpenFuture}; @@ -78,7 +76,7 @@ impl WebsocketClient for TestWs { ready(Ok(WebSocket::from_upgraded( self.config, socket, - NegotiatedExtension::from(None), + None, BytesMut::new(), Role::Client, ))) @@ -108,13 +106,7 @@ impl WebsocketServer for TestWs { .map(move |result| { result.map(|(sock, _, addr)| { ( - WebSocket::from_upgraded( - config, - sock, - NegotiatedExtension::from(None), - BytesMut::new(), - Role::Server, - ), + WebSocket::from_upgraded(config, sock, None, BytesMut::new(), Role::Server), addr, ) }) diff --git a/server/swimos_server_app/src/server/runtime/tests/mod.rs b/server/swimos_server_app/src/server/runtime/tests/mod.rs index de515e463..9b8bcc8dd 100644 --- a/server/swimos_server_app/src/server/runtime/tests/mod.rs +++ b/server/swimos_server_app/src/server/runtime/tests/mod.rs @@ -24,9 +24,7 @@ use futures::{ future::{join, join3}, Future, }; -use ratchet::{ - Message, NegotiatedExtension, NoExt, NoExtProvider, Role, WebSocket, WebSocketConfig, -}; +use ratchet::{Message, NoExt, NoExtProvider, Role, WebSocket, WebSocketConfig}; use swimos_api::{address::RelativeAddress, persistence::StoreDisabled}; use swimos_form::write::StructuralWritable; use swimos_recon::print_recon_compact; @@ -206,7 +204,7 @@ impl TestClient { ws: WebSocket::from_upgraded( WebSocketConfig::default(), stream, - NegotiatedExtension::from(None), + None, BytesMut::new(), Role::Client, ), diff --git a/swimos_client/src/commander.rs b/swimos_client/src/commander.rs index 2edd3045a..d766dd08f 100644 --- a/swimos_client/src/commander.rs +++ b/swimos_client/src/commander.rs @@ -18,7 +18,7 @@ use bytes::BytesMut; use futures::stream::FuturesUnordered; use futures::StreamExt; use ratchet::{ - CloseCode, CloseReason, NoExt, NoExtProvider, ProtocolRegistry, WebSocket, WebSocketConfig, + CloseCode, CloseReason, NoExt, NoExtProvider, SubprotocolRegistry, WebSocket, WebSocketConfig, }; use swimos_form::write::StructuralWritable; use swimos_recon::print_recon_compact; @@ -121,7 +121,7 @@ async fn open_connection(shp: SchemeHostPort) -> Result Ok(WebSocket::from_upgraded( WebSocketConfig::default(), socket, - NegotiatedExtension::from(None), + None, BytesMut::default(), Role::Client, )), @@ -477,7 +477,7 @@ impl Server { transport: WebSocket::from_upgraded( WebSocketConfig::default(), transport, - NegotiatedExtension::from(NoExt), + Some(NoExt), BytesMut::default(), Role::Server, ), @@ -549,7 +549,7 @@ async fn transport_opens_connection_ok() { let mut ws_server = WebSocket::from_upgraded( WebSocketConfig::default(), server, - NegotiatedExtension::from(NoExt), + Some(NoExt), buf, Role::Server, );