Skip to content

Commit

Permalink
Upgrades to new ratchet version
Browse files Browse the repository at this point in the history
  • Loading branch information
SirCipher committed Sep 25, 2024
1 parent c76810b commit 5fb30f0
Show file tree
Hide file tree
Showing 15 changed files with 152 additions and 272 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ syn = "1.0"
quote = "1.0.3"
num-bigint = "0.4"
ratchet = { package = "ratchet_rs", version = "1.0" }
ratchet_fixture = "1.0"
ratchet_core = { version = "1.0" }
ratchet_fixture = { version = "1.0" }
flate2 = "1.0.22"
bitflags = "2.5"
rocksdb = "0.22"
Expand Down
4 changes: 2 additions & 2 deletions example_apps/console/src/runtime/dummy_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use futures::{
};
use parking_lot::RwLock;
use ratchet::{
CloseCode, CloseReason, NoExtDecoder, NoExtEncoder, NoExtProvider, ProtocolRegistry,
CloseCode, CloseReason, NoExtDecoder, NoExtEncoder, NoExtProvider, SubprotocolRegistry,
WebSocketConfig,
};
use swimos_agent_protocol::MapMessage;
Expand Down Expand Up @@ -216,7 +216,7 @@ impl DummyServer {
};
match event {
Event::NewConnection(stream) => {
let subprotocols = ProtocolRegistry::new(vec!["warp0"]).unwrap();
let subprotocols = SubprotocolRegistry::new(vec!["warp0"]).unwrap();
let upgrader = ratchet::accept_with(
stream,
WebSocketConfig::default(),
Expand Down
4 changes: 2 additions & 2 deletions example_apps/console/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use futures::{
use parking_lot::RwLock;
use ratchet::{
CloseCode, CloseReason, ErrorKind, Message, NoExt, NoExtDecoder, NoExtEncoder, NoExtProvider,
ProtocolRegistry, WebSocket, WebSocketConfig,
SubprotocolRegistry, WebSocket, WebSocketConfig,
};
use swimos_messages::warp::{peel_envelope_header_str, RawEnvelope};
use swimos_model::Value;
Expand Down Expand Up @@ -516,7 +516,7 @@ fn into_stream(remote: Host, rx: Rx) -> impl Stream<Item = Result<(Host, String)

async fn open_connection(host: &Host) -> Result<WebSocket<TcpStream, NoExt>, ratchet::Error> {
let socket = TcpStream::connect(&host.host_only()).await?;
let subprotocols = ProtocolRegistry::new(vec!["warp0"]).unwrap();
let subprotocols = SubprotocolRegistry::new(vec!["warp0"]).unwrap();
let r = ratchet::subscribe_with(
WebSocketConfig::default(),
socket,
Expand Down
1 change: 1 addition & 0 deletions runtime/swimos_http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ homepage.workspace = true
[dependencies]
futures = { workspace = true }
ratchet = { workspace = true }
ratchet_core = { workspace = true }
hyper = { workspace = true }
http = { workspace = true }
httparse = { workspace = true }
Expand Down
3 changes: 1 addition & 2 deletions runtime/swimos_http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,5 @@
mod websocket;

pub use websocket::{
fail_upgrade, negotiate_upgrade, upgrade, Negotiated, NoUnwrap, SockUnwrap, UpgradeError,
UpgradeFuture,
fail_upgrade, negotiate_upgrade, upgrade, NoUnwrap, SockUnwrap, UpgradeFuture, UpgradeStatus,
};
167 changes: 36 additions & 131 deletions runtime/swimos_http/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,35 @@
// limitations under the License.

use std::{
collections::HashSet,
pin::Pin,
task::{Context, Poll},
};

use base64::{engine::general_purpose::STANDARD, Engine};
use bytes::{Bytes, BytesMut};
use futures::{ready, Future, FutureExt};
use http::{header::HeaderName, HeaderMap, HeaderValue, Method};
use http_body_util::Full;
use httparse::Header;
use hyper::body::Incoming;
use hyper::{
upgrade::{OnUpgrade, Upgraded},
Request, Response,
};
use hyper_util::rt::TokioIo;
use ratchet::{
Extension, ExtensionProvider, NegotiatedExtension, Role, WebSocket, WebSocketConfig,
Extension, ExtensionProvider, Role, SubprotocolRegistry, WebSocket, WebSocketConfig,
};
use sha1::{Digest, Sha1};
use thiserror::Error;
use ratchet_core::server::{build_response, parse_request, UpgradeRequest};

const UPGRADE_STR: &str = "Upgrade";
const WEBSOCKET_STR: &str = "websocket";
const WEBSOCKET_VERSION_STR: &str = "13";
const ACCEPT_KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
const FAILED_RESPONSE: &str = "Building response should not fail.";

/// Result of a successful websocket negotiation.
pub struct Negotiated<'a, Ext> {
pub protocol: Option<&'a str>,
pub extension: Option<(Ext, HeaderValue)>,
pub key: Bytes,
pub enum UpgradeStatus<E, T> {
Upgradeable {
result: Result<UpgradeRequest<E, T>, ratchet::Error>,
},
NotRequested {
request: Request<T>,
},
}

/// Attempt to negotiate a websocket upgrade on a hyper request. If [`Ok(None)`] is returned,
Expand All @@ -57,10 +52,10 @@ pub struct Negotiated<'a, Ext> {
/// * `protocols` - The supported protocols for the negotiation.
/// * `extension_provider` - The extension provider (for example compression support).
pub fn negotiate_upgrade<'a, T, E>(
request: &Request<T>,
protocols: &'a HashSet<&str>,
request: Request<T>,
registry: &SubprotocolRegistry,
extension_provider: &E,
) -> Result<Option<Negotiated<'a, E::Extension>>, UpgradeError<E::Error>>
) -> UpgradeStatus<E::Extension, T>
where
E: ExtensionProvider,
{
Expand All @@ -69,48 +64,16 @@ where
let has_upgrade = headers_contains(headers, http::header::UPGRADE, WEBSOCKET_STR);

if request.method() == Method::GET && has_conn && has_upgrade {
if !headers_contains(
headers,
http::header::SEC_WEBSOCKET_VERSION,
WEBSOCKET_VERSION_STR,
) {
return Err(UpgradeError::InvalidWebsocketVersion);
UpgradeStatus::Upgradeable {
result: parse_request(request, extension_provider, registry),
}

let key = if let Some(key) = headers
.get(http::header::SEC_WEBSOCKET_KEY)
.map(|v| Bytes::from(trim(v.as_bytes()).to_vec()))
{
key
} else {
return Err(UpgradeError::NoKey);
};

let protocol = headers
.get_all(http::header::SEC_WEBSOCKET_PROTOCOL)
.iter()
.flat_map(|h| h.as_bytes().split(|c| *c == b' ' || *c == b','))
.map(trim)
.filter_map(|b| std::str::from_utf8(b).ok())
.find_map(|p| protocols.get(p).copied());

let ext_headers = extension_headers(headers);

let extension = extension_provider.negotiate_server(&ext_headers)?;
Ok(Some(Negotiated {
protocol,
extension,
key,
}))
} else {
Ok(None)
UpgradeStatus::NotRequested { request }
}
}

/// Produce a bad request response for a bad websocket upgrade request.
pub fn fail_upgrade<ExtErr: std::error::Error>(
error: UpgradeError<ExtErr>,
) -> Response<Full<Bytes>> {
pub fn fail_upgrade(error: ratchet::Error) -> Response<Full<Bytes>> {
Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.body(Full::from(error.to_string()))
Expand All @@ -120,110 +83,52 @@ pub fn fail_upgrade<ExtErr: std::error::Error>(
/// Upgrade a hyper request to a websocket, based on a successful negotiation.
///
/// # Arguments
/// * `request` - The hyper HTTP request.
/// * `negotiated` - Negotiated parameters for the websocket connection.
/// * `request` - The upgrade request request.
/// * `config` - Websocket configuration parameters.
/// * `unwrap_fn` - Used to unwrap the underlying socket type from the opaque [`Upgraded`] socket
/// provided by hyper.
pub fn upgrade<Ext, U>(
request: Request<Incoming>,
negotiated: Negotiated<'_, Ext>,
pub fn upgrade<Ext, U, B>(
request: UpgradeRequest<Ext, B>,
config: Option<WebSocketConfig>,
unwrap_fn: U,
) -> (Response<Full<Bytes>>, UpgradeFuture<Ext, U>)
) -> Result<(Response<Full<Bytes>>, UpgradeFuture<Ext, U>), ratchet::Error>
where
Ext: Extension + Send,
{
let Negotiated {
protocol,
extension,
let UpgradeRequest {
key,
} = negotiated;
let mut digest = Sha1::new();
Digest::update(&mut digest, key);
Digest::update(&mut digest, ACCEPT_KEY);

let sec_websocket_accept = STANDARD.encode(digest.finalize());
let mut builder = Response::builder()
.status(http::StatusCode::SWITCHING_PROTOCOLS)
.header(http::header::SEC_WEBSOCKET_ACCEPT, sec_websocket_accept)
.header(http::header::CONNECTION, UPGRADE_STR)
.header(http::header::UPGRADE, WEBSOCKET_STR);

if let Some(protocol) = protocol {
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
let ext = match extension {
Some((ext, header)) => {
builder = builder.header(http::header::SEC_WEBSOCKET_EXTENSIONS, header);
Some(ext)
}
None => None,
};
subprotocol,
extension,
extension_header,
request,
..
} = request;
let response = build_response(key, subprotocol, extension_header)?;
let (parts, _body) = response.into_parts();
let fut = UpgradeFuture {
upgrade: hyper::upgrade::on(request),
config: config.unwrap_or_default(),
extension: ext,
extension,
unwrap_fn,
};

let response = builder.body(Full::default()).expect(FAILED_RESPONSE);
(response, fut)
}

fn extension_headers(headers: &HeaderMap) -> Vec<Header<'_>> {
headers
.iter()
.map(|(name, value)| Header {
name: name.as_str(),
value: value.as_bytes(),
})
.collect()
Ok((Response::from_parts(parts, Full::default()), fut))
}

fn headers_contains(headers: &HeaderMap, name: HeaderName, value: &str) -> bool {
headers.get_all(name).iter().any(header_contains(value))
}

fn header_contains(content: &str) -> impl Fn(&HeaderValue) -> bool + '_ {
|header| {
move |header| {
header
.as_bytes()
.split(|c| *c == b' ' || *c == b',')
.map(trim)
.any(|s| s.eq_ignore_ascii_case(content.as_bytes()))
.split(|&c| c == b' ' || c == b',')
.map(|s| std::str::from_utf8(s).unwrap_or("").trim())
.any(|s| s.eq_ignore_ascii_case(content))
}
}

fn trim(bytes: &[u8]) -> &[u8] {
let not_ws = |b: &u8| !b.is_ascii_whitespace();
let start = bytes.iter().position(not_ws);
let end = bytes.iter().rposition(not_ws);
match (start, end) {
(Some(s), Some(e)) => &bytes[s..e + 1],
_ => &[],
}
}

/// Reasons that a websocket upgrade request could fail.
#[derive(Debug, Error, Clone, Copy)]
pub enum UpgradeError<ExtErr: std::error::Error> {
/// An invalid websocket version was specified.
#[error("Invalid websocket version specified.")]
InvalidWebsocketVersion,
/// No websocket key was provided.
#[error("No websocket key provided.")]
NoKey,
/// The headers provided for the websocket extension were not valid.
#[error("Invalid extension headers: {0}")]
ExtensionError(ExtErr),
}

impl<ExtErr: std::error::Error> From<ExtErr> for UpgradeError<ExtErr> {
fn from(err: ExtErr) -> Self {
UpgradeError::ExtensionError(err)
}
}
/// Trait for unwrapping the concrete type of an upgraded socket.
/// Upon a connection upgrade, hyper returns the upgraded socket indirected through a trait object.
/// The caller will generally know the real underlying type and this allows for that type to be
Expand Down Expand Up @@ -275,7 +180,7 @@ where
Poll::Ready(Ok(WebSocket::from_upgraded(
std::mem::take(config),
upgraded,
NegotiatedExtension::from(extension.take()),
extension.take(),
prefix,
Role::Server,
)))
Expand Down
40 changes: 29 additions & 11 deletions runtime/swimos_http/tests/wsserver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::{
sync::Arc,
time::Duration,
};
use swimos_http::NoUnwrap;
use swimos_http::{NoUnwrap, UpgradeStatus};

use futures::{
channel::oneshot,
Expand All @@ -34,7 +34,10 @@ use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use hyper_util::server::graceful::GracefulShutdown;
use ratchet::{CloseCode, CloseReason, Message, NoExt, NoExtProvider, PayloadType, WebSocket};
use ratchet::{
CloseCode, CloseReason, Message, NoExt, NoExtProvider, PayloadType, SubprotocolRegistry,
WebSocket,
};
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::{net::TcpSocket, sync::Notify};
Expand All @@ -49,8 +52,15 @@ async fn run_server(

let (io, _) = listener.accept().await?;
let builder = Builder::new(TokioExecutor::new());
let connection =
builder.serve_connection_with_upgrades(TokioIo::new(io), service_fn(upgrade_server));

let connection = builder.serve_connection_with_upgrades(
TokioIo::new(io),
service_fn(move |req| {
let registry =
SubprotocolRegistry::new(["warp0"]).expect("Failed to build subprotocol registry");
async move { upgrade_server(req, &registry).await }
}),
);
let shutdown = GracefulShutdown::new();

let server = pin!(shutdown.watch(connection));
Expand All @@ -70,16 +80,24 @@ async fn run_server(

async fn upgrade_server(
request: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, hyper::http::Error> {
let protocols = ["warp0"].into_iter().collect();
match swimos_http::negotiate_upgrade(&request, &protocols, &NoExtProvider) {
Ok(Some(negotiated)) => {
let (response, upgraded) = swimos_http::upgrade(request, negotiated, None, NoUnwrap);
registry: &SubprotocolRegistry,
) -> Result<Response<Full<Bytes>>, ratchet::Error> {
match swimos_http::negotiate_upgrade(request, registry, &NoExtProvider) {
UpgradeStatus::Upgradeable { result: Ok(result) } => {
let (response, upgraded) = swimos_http::upgrade(result, None, NoUnwrap)?;
tokio::spawn(run_websocket(upgraded));
Ok(response)
}
Ok(None) => Response::builder().body(Full::from("Success")),
Err(err) => Ok(swimos_http::fail_upgrade(err)),
UpgradeStatus::Upgradeable { result: Err(err) } => {
if err.is_io() {
Err(err)
} else {
Ok(swimos_http::fail_upgrade(err))
}
}
UpgradeStatus::NotRequested { request: _ } => Response::builder()
.body(Full::from("Success"))
.map_err(Into::into),
}
}

Expand Down
Loading

0 comments on commit 5fb30f0

Please sign in to comment.