Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove websocket dependency and replace it with async bridging code to tokio-tungstenite #146

Merged
merged 4 commits into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
498 changes: 130 additions & 368 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion engineio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ reqwest = { version = "0.11.8", features = ["blocking", "native-tls"] }
adler32 = "1.2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
websocket = { version = "0.26.2", features = ["sync-ssl"], default-features = false }
http = "0.2.5"
tokio-tungstenite = { version = "0.16.1", features = ["native-tls"] }
tungstenite = "0.16.0"
tokio = "1.0.0"
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
async-trait = "0.1.51"
thiserror = "1.0"
native-tls = "0.2.7"
url = "2.2.2"
Expand Down
95 changes: 95 additions & 0 deletions engineio/src/async_transports/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::{borrow::Cow, str::from_utf8, sync::Arc};

use crate::{error::Result, Error, Packet, PacketId};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use tokio::{net::TcpStream, sync::Mutex};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tungstenite::Message;

pub(crate) mod transport;
pub(crate) mod websocket;
pub(crate) mod websocket_secure;

/// A general purpose asynchronous websocket transport type. Holds
/// the sender and receiver stream of a websocket connection
/// and implements the common methods `update`, `poll` and `emit`.
pub(crate) struct AsyncWebsocketGeneralTransport {
sender: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
receiver: Arc<Mutex<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>,
}

impl AsyncWebsocketGeneralTransport {
pub(crate) async fn new(
sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) -> Self {
AsyncWebsocketGeneralTransport {
sender: Arc::new(Mutex::new(sender)),
receiver: Arc::new(Mutex::new(receiver)),
}
}

/// Sends probe packet to ensure connection is valid, then sends upgrade
/// request
pub(crate) async fn upgrade(&self) -> Result<()> {
let mut receiver = self.receiver.lock().await;
let mut sender = self.sender.lock().await;

sender
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
Packet::new(PacketId::Ping, Bytes::from("probe")),
))?)))
.await?;

let msg = receiver
.next()
.await
.ok_or(Error::IllegalWebsocketUpgrade())??;

if msg.into_data() != Bytes::from(Packet::new(PacketId::Pong, Bytes::from("probe"))) {
return Err(Error::InvalidPacket());
}

sender
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
Packet::new(PacketId::Upgrade, Bytes::from("")),
))?)))
.await?;

Ok(())
}

async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
let mut sender = self.sender.lock().await;

let message = if is_binary_att {
Message::binary(Cow::Borrowed(data.as_ref()))
} else {
Message::text(Cow::Borrowed(std::str::from_utf8(data.as_ref())?))
};

sender.send(message).await?;

Ok(())
}

async fn poll(&self) -> Result<Bytes> {
let mut receiver = self.receiver.lock().await;

let message = receiver.next().await.ok_or(Error::IncompletePacket())??;
if message.is_binary() {
let data = message.into_data();
let mut msg = BytesMut::with_capacity(data.len() + 1);
msg.put_u8(PacketId::Message as u8);
msg.put(data.as_ref());

Ok(msg.freeze())
} else {
Ok(Bytes::from(message.into_data()))
}
}
}
34 changes: 34 additions & 0 deletions engineio/src/async_transports/transport.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use crate::error::Result;
use adler32::adler32;
use async_trait::async_trait;
use bytes::Bytes;
use std::time::SystemTime;
use url::Url;

#[async_trait]
pub(crate) trait AsyncTransport {
/// Sends a packet to the server. This optionally handles sending of a
/// socketio binary attachment via the boolean attribute `is_binary_att`.
async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()>;

/// Performs the server long polling procedure as long as the client is
/// connected. This should run separately at all time to ensure proper
/// response handling from the server.
async fn poll(&self) -> Result<Bytes>;

/// Returns start of the url. ex. http://localhost:2998/engine.io/?EIO=4&transport=polling
/// Must have EIO and transport already set.
async fn base_url(&self) -> Result<Url>;

/// Used to update the base path, like when adding the sid.
async fn set_base_url(&self, base_url: Url) -> Result<()>;

/// Full query address
async fn address(&self) -> Result<Url> {
let reader = format!("{:#?}", SystemTime::now());
let hash = adler32(reader.as_bytes()).unwrap();
let mut url = self.base_url().await?;
url.query_pairs_mut().append_pair("t", &hash.to_string());
Ok(url)
}
}
68 changes: 68 additions & 0 deletions engineio/src/async_transports/websocket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use std::sync::Arc;

use crate::error::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use tokio::sync::RwLock;
use tokio_tungstenite::connect_async;
use url::Url;

use super::transport::AsyncTransport;
use super::AsyncWebsocketGeneralTransport;

/// An asynchronous websocket transport type.
/// This type only allows for plain websocket
/// connections ("ws://").
pub(crate) struct AsyncWebsocketTransport {
inner: AsyncWebsocketGeneralTransport,
base_url: Arc<RwLock<Url>>,
}

impl AsyncWebsocketTransport {
/// Creates a new instance over a request that might hold additional headers and an URL.
pub async fn new(request: http::request::Request<()>, url: Url) -> Result<Self> {
let (ws_stream, _) = connect_async(request).await?;
let (sen, rec) = ws_stream.split();

let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
Ok(AsyncWebsocketTransport {
inner,
base_url: Arc::new(RwLock::new(url)),
})
}

/// Sends probe packet to ensure connection is valid, then sends upgrade
/// request
pub(crate) async fn upgrade(&self) -> Result<()> {
self.inner.upgrade().await
}
}

#[async_trait]
impl AsyncTransport for AsyncWebsocketTransport {
async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
self.inner.emit(data, is_binary_att).await
}

async fn poll(&self) -> Result<Bytes> {
self.inner.poll().await
}

async fn base_url(&self) -> Result<Url> {
Ok(self.base_url.read().await.clone())
}

async fn set_base_url(&self, base_url: Url) -> Result<()> {
let mut url = base_url;
if !url
.query_pairs()
.any(|(k, v)| k == "transport" && v == "websocket")
{
url.query_pairs_mut().append_pair("transport", "websocket");
}
url.set_scheme("ws").unwrap();
*self.base_url.write().await = url;
Ok(())
}
}
78 changes: 78 additions & 0 deletions engineio/src/async_transports/websocket_secure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use std::sync::Arc;

use crate::error::Result;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::StreamExt;
use native_tls::TlsConnector;
use tokio::sync::RwLock;
use tokio_tungstenite::connect_async_tls_with_config;
use tokio_tungstenite::Connector;
use url::Url;

use super::transport::AsyncTransport;
use super::AsyncWebsocketGeneralTransport;

/// An asynchronous websocket transport type.
/// This type only allows for secure websocket
/// connections ("wss://").
pub(crate) struct AsyncWebsocketSecureTransport {
inner: AsyncWebsocketGeneralTransport,
base_url: Arc<RwLock<url::Url>>,
}

impl AsyncWebsocketSecureTransport {
/// Creates a new instance over a request that might hold additional headers, a possible
/// Tls connector and an URL.
pub(crate) async fn new(
request: http::request::Request<()>,
base_url: url::Url,
tls_config: Option<TlsConnector>,
) -> Result<Self> {
let (ws_stream, _) =
connect_async_tls_with_config(request, None, tls_config.map(Connector::NativeTls))
.await?;

let (sen, rec) = ws_stream.split();
let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;

Ok(AsyncWebsocketSecureTransport {
inner,
base_url: Arc::new(RwLock::new(base_url)),
})
}

/// Sends probe packet to ensure connection is valid, then sends upgrade
/// request
pub(crate) async fn upgrade(&self) -> Result<()> {
self.inner.upgrade().await
}
}

#[async_trait]
impl AsyncTransport for AsyncWebsocketSecureTransport {
async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
self.inner.emit(data, is_binary_att).await
}

async fn poll(&self) -> Result<Bytes> {
self.inner.poll().await
}

async fn base_url(&self) -> Result<Url> {
Ok(self.base_url.read().await.clone())
}

async fn set_base_url(&self, base_url: Url) -> Result<()> {
let mut url = base_url;
if !url
.query_pairs()
.any(|(k, v)| k == "transport" && v == "websocket")
{
url.query_pairs_mut().append_pair("transport", "websocket");
}
url.set_scheme("wss").unwrap();
*self.base_url.write().await = url;
Ok(())
}
}
22 changes: 10 additions & 12 deletions engineio/src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,17 @@ impl ClientBuilder {
/// Build socket with only a websocket transport
pub fn build_websocket(mut self) -> Result<Client> {
// SAFETY: Already a Url
let url = websocket::client::Url::parse(&self.url.to_string())?;
let url = url::Url::parse(&self.url.to_string())?;

let headers: Option<http::HeaderMap> = if let Some(map) = self.headers.clone() {
Some(map.try_into()?)
} else {
None
};

match url.scheme() {
"http" | "ws" => {
let transport = WebsocketTransport::new(
url,
self.headers
.clone()
.map(|headers| headers.try_into().unwrap()),
)?;
let transport = WebsocketTransport::new(url, headers)?;
if self.handshake.is_some() {
transport.upgrade()?;
} else {
Expand All @@ -228,11 +229,8 @@ impl ClientBuilder {
})
}
"https" | "wss" => {
let transport = WebsocketSecureTransport::new(
url,
self.tls_config.clone(),
self.headers.clone().map(|v| v.try_into().unwrap()),
)?;
let transport =
WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?;
if self.handshake.is_some() {
transport.upgrade()?;
} else {
Expand Down
10 changes: 5 additions & 5 deletions engineio/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use serde_json::Error as JsonError;
use std::io::Error as IoError;
use std::str::Utf8Error;
use thiserror::Error;
use tungstenite::Error as TungsteniteError;
use url::ParseError as UrlParseError;
use websocket::{client::ParseError, WebSocketError};

/// Enumeration of all possible errors in the `socket.io` context.
#[derive(Error, Debug)]
Expand All @@ -30,20 +30,20 @@ pub enum Error {
InvalidUrlScheme(String),
#[error("Error during connection via http: {0}")]
IncompleteResponseFromReqwest(#[from] ReqwestError),
#[error("Error with websocket connection: {0}")]
WebsocketError(#[from] TungsteniteError),
#[error("Network request returned with status code: {0}")]
IncompleteHttp(u16),
#[error("Got illegal handshake response: {0}")]
InvalidHandshake(String),
#[error("Called an action before the connection was established")]
IllegalActionBeforeOpen(),
#[error("Error setting up the http request: {0}")]
InvalidHttpConfiguration(#[from] http::Error),
#[error("string is not json serializable: {0}")]
InvalidJson(#[from] JsonError),
#[error("A lock was poisoned")]
InvalidPoisonedLock(),
#[error("Got a websocket error: {0}")]
IncompleteResponseFromWebsocket(#[from] WebSocketError),
#[error("Error while parsing the url for the websocket connection: {0}")]
InvalidWebsocketURL(#[from] ParseError),
#[error("Got an IO-Error: {0}")]
IncompleteIo(#[from] IoError),
#[error("Server did not allow upgrading to websockets")]
Expand Down
Loading