Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement websocket interaction between server and client #8

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
701 changes: 700 additions & 1 deletion Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ rust-remote-shell = { path = "../rust-remote-shell" } # cli depends on rust-remo
clap = { version = "3.2.25", features = ["derive"] }
color-eyre = "0.6.2"
shellwords = "1.1.0"
tokio = "1.28.1"
url = "2.3.1"
tracing = "0.1.37"
tracing-subscriber = "0.3.17"
39 changes: 25 additions & 14 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::ops::Deref;
use std::net::SocketAddr;

use clap::{Parser, Subcommand};

use color_eyre::Result;
use tracing::Level;
use tracing_subscriber::FmtSubscriber;

use rust_remote_shell::{device_server::DeviceServer, sender_client::SenderClient};

/// CLI for a rust remote shell
#[derive(Debug, Parser)]
Expand All @@ -14,25 +17,33 @@ struct Cli {
// these commands can be called from the CLI using lowercase Commands name
#[derive(Subcommand, Debug)]
enum Commands {
/// Execute a command
Command { cmd: String },
/// Make the device listen on a specific IP and port
Listener { addr: SocketAddr },
/// Create a client capable of sending command to a Listener
Sender { listener_addr: url::Url },
}

fn main() -> Result<()> {
#[tokio::main]
async fn main() -> Result<()> {
color_eyre::install()?;

let cli = Cli::parse();
// define a subscriber for logging purposes
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::TRACE)
.finish();

match &cli.command {
Commands::Command { cmd } => {
println!("Input command \"{}\"", cmd); // substitute with logging inside the function
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");

// parse the cmd into a slice
let cmd = shellwords::split(cmd.trim())
.map_err(|_| rust_remote_shell::ShellError::MalformedInput)?;
let cli = Cli::parse();

let cmd_out = rust_remote_shell::cmd_from_input(cmd.deref())?;
println!("Command output: {}", cmd_out);
match &cli.command {
Commands::Listener { addr } => {
let device_server = DeviceServer::new(*addr);
device_server.listen().await?
}
Commands::Sender { listener_addr } => {
let mut sender_client = SenderClient::new(listener_addr.clone());
sender_client.connect().await?;
}
}

Expand Down
6 changes: 6 additions & 0 deletions rust-remote-shell/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,10 @@ edition = "2021"
[lib]

[dependencies]
futures = "0.3.28"
shellwords = "1.1.0"
thiserror = "1.0.40"
tokio = { version = "1.28.1", features = ["full"] }
tokio-tungstenite = "0.18.0"
tracing = "0.1.37"
url = "2.3.1"
172 changes: 172 additions & 0 deletions rust-remote-shell/src/device_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use std::io::{self};
use std::net::SocketAddr;
use std::string::FromUtf8Error;
use std::sync::Arc;

use futures::{SinkExt, StreamExt, TryStreamExt};
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::error::ProtocolError;
use tokio_tungstenite::tungstenite::Error as TungsteniteError;
use tokio_tungstenite::{accept_async, tungstenite::Message};
use tracing::{error, info, instrument, warn};

use crate::shell::{CommandHandler, ShellError};

#[derive(Error, Debug)]
pub enum DeviceServerError {
#[error("Failed to bind")]
Bind(#[from] io::Error),
#[error("Connected streams should have a peer address")]
PeerAddr,
#[error("Error during the websocket handshake occurred")]
WebSocketHandshake,
#[error("Error while reading the shell command from websocket")]
ReadCommand,
#[error("Error marshaling to UTF8")]
Utf8Error(#[from] FromUtf8Error),
#[error("Trasport error from Tungstenite")]
Transport(#[from] tokio_tungstenite::tungstenite::Error),
#[error("Error while precessing the shell command")]
ShellError(#[from] ShellError),
#[error("Close websocket connection")]
CloseWebsocket,
}

type TxErrorType = tokio::sync::mpsc::Sender<DeviceServerError>;
const MAX_ERRORS_TO_HANDLE: usize = 10;

#[derive(Debug)]
pub struct DeviceServer {
addr: SocketAddr,
}

impl DeviceServer {
pub fn new(addr: SocketAddr) -> Self {
Self { addr }
}

#[instrument(skip(self))]
pub async fn listen(&self) -> Result<(), DeviceServerError> {
let socket = TcpListener::bind(self.addr)
.await
.map_err(DeviceServerError::Bind)?;

info!("Listening at {}", self.addr);

// channel tx/rx to handle error
let (tx_err, mut rx_err) =
tokio::sync::mpsc::channel::<DeviceServerError>(MAX_ERRORS_TO_HANDLE);

let handles = Arc::new(Mutex::new(Vec::new()));
let handles_clone = Arc::clone(&handles);

// accept a new connection
let handle_connections = tokio::spawn(async move {
while let Ok((stream, _)) = socket.accept().await {
let handle_single_connection =
tokio::spawn(Self::handle_connection(stream, tx_err.clone()));

handles_clone.lock().await.push(handle_single_connection);
}
});

// join connections and handle errors
if let Some(err) = rx_err.recv().await {
self.terminate(handle_connections, &handles).await?;
error!("Received error {:?}. Terminate all connections.", err);
return Err(err);
}

Ok(())
}

// terminate all connections
#[instrument(skip_all)]
async fn terminate(
&self,
handle_connections: JoinHandle<()>,
handles: &Mutex<Vec<JoinHandle<()>>>,
) -> Result<(), DeviceServerError> {
handle_connections.abort();

match handle_connections.await {
Err(err) if !err.is_cancelled() => error!("Join failed: {}", err),
_ => {}
}

for h in handles.lock().await.iter() {
h.abort();
}

Ok(())
}

#[instrument(skip_all)]
async fn handle_connection(stream: TcpStream, tx_err: TxErrorType) {
match Self::impl_handle_connection(stream).await {
Ok(_) => {}
Err(DeviceServerError::CloseWebsocket)
| Err(DeviceServerError::Transport(TungsteniteError::Protocol(
ProtocolError::ResetWithoutClosingHandshake,
))) => {
warn!("Websocket connection closed");
// TODO: check that the connection is effectively closed on the server-side (not only on the client-side)
}
Err(err) => {
error!("Fatal error occurred: {}", err);
tx_err.send(err).await.expect("Error handler failure");
}
}
}

#[instrument(skip_all)]
async fn impl_handle_connection(stream: TcpStream) -> Result<(), DeviceServerError> {
let addr = stream
.peer_addr()
.map_err(|_| DeviceServerError::PeerAddr)?;

// create a WebSocket connection
let web_socket_stream = accept_async(stream)
.await
.map_err(|_| DeviceServerError::WebSocketHandshake)?;

info!("New WebSocket connection created: {}", addr);

// separate ownership between receiving and writing part
let (write, read) = web_socket_stream.split();

// Read the received command
read.map_err(DeviceServerError::Transport)
.and_then(|msg| async move {
info!("Received command from the client");
match msg {
// convert the message from a Vec<u8> into a OsString
Message::Binary(v) => {
String::from_utf8(v).map_err(DeviceServerError::Utf8Error)
}
Message::Close(_) => Err(DeviceServerError::CloseWebsocket), // the client closed the connection
_ => Err(DeviceServerError::ReadCommand),
}
})
.and_then(|cmd| async move {
// define a command handler
let cmd_handler = CommandHandler::default();

// execute the command and eventually return the error
let cmd_out = cmd_handler.execute(cmd).await.unwrap_or_else(|err| {
warn!("Shell error: {}", err);
format!("Shell error: {}\n", err)
});

info!("Send command output to the client");
Ok(Message::Binary(cmd_out.as_bytes().to_vec()))
})
.forward(write.sink_map_err(DeviceServerError::Transport))
.await?;

Ok(())
}
}
147 changes: 147 additions & 0 deletions rust-remote-shell/src/io_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
use futures::stream::SplitSink;
use futures::SinkExt;
use tokio::io::{AsyncBufReadExt, BufReader, Stdin, Stdout};
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use tokio::sync::MutexGuard;
use tokio::{io::AsyncWriteExt, sync::Mutex};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tracing::{debug, info, instrument, warn};

use crate::sender_client::SenderClientError;

#[derive(Debug)]
pub struct IOHandler {
stdout: Stdout,
reader: BufReader<Stdin>,
write: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
tx_err: Sender<Result<(), SenderClientError>>,
buf_cmd: String,
}

impl IOHandler {
pub fn new(
write: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
tx_err: Sender<Result<(), SenderClientError>>,
) -> Self {
Self {
stdout: tokio::io::stdout(),
reader: BufReader::new(tokio::io::stdin()),
write,
tx_err,
buf_cmd: String::new(),
}
}

#[instrument(skip_all)]
pub async fn read_stdin(&mut self) -> Result<(), SenderClientError> {
self.buf_cmd.clear();

// read a shell command from the stdin and send it to the server
let byte_read = self
.reader
.read_line(&mut self.buf_cmd)
.await
.map_err(SenderClientError::IORead)?;

debug!(?byte_read);
if byte_read == 0 {
info!("EOF received");
self.exit().await
} else if self.check_exit() {
info!("exit received");
self.exit().await
} else {
Ok(())
}
}

#[instrument(skip_all)]
fn check_exit(&self) -> bool {
self.buf_cmd.starts_with("exit")
}

#[instrument(skip_all)]
async fn exit(&mut self) -> Result<(), SenderClientError> {
// check if the command is exit. Eventually, close the connection

self.write
.send(Message::Close(None))
.await
.expect("Error while closing websocket connection");
info!("Closed websocket on client side");

self.tx_err.send(Ok(())).await.expect("channel error");

Ok(()) // send Ok(()) to close the connection on client side
//break Ok(());
}

#[instrument(skip_all)]
pub async fn send_to_server(&mut self) -> Result<(), SenderClientError> {
info!("Send command to the server");
self.write
.send(Message::Binary(self.buf_cmd.as_bytes().to_vec()))
.await
.map_err(|err| SenderClientError::TungsteniteReadData { err })?;

info!("Command sent: {}", self.buf_cmd);

Ok(())
}

#[instrument(skip_all)]
pub async fn write_stdout(
&mut self,
rx: &Mutex<UnboundedReceiver<Message>>,
) -> Result<(), SenderClientError> {
// check if there are command outputs stored in the channel. Eventually, print them to the stdout
let mut channel = rx.lock().await;

// wait to receive the first command output
let msg = channel.recv().await.unwrap();

self.impl_write_stdout(msg).await?;

// if the channel still contains information, empty it before aborting the task
self.empty_buffer(channel).await?;

Ok(())
}

#[instrument(skip_all)]
async fn impl_write_stdout(&mut self, msg: Message) -> Result<(), SenderClientError> {
let data = msg.into_data();

self.stdout
.write(&data)
.await
.map_err(|err| SenderClientError::IOWrite { err })?;

self.stdout.flush().await.expect("writing stdout");

Ok(())
}

async fn empty_buffer(
harlem88 marked this conversation as resolved.
Show resolved Hide resolved
&mut self,
mut channel: MutexGuard<'_, UnboundedReceiver<Message>>,
) -> Result<(), SenderClientError> {
loop {
match channel.try_recv() {
Ok(msg) => {
self.impl_write_stdout(msg).await?;
}
Err(TryRecvError::Empty) => {
// the channel is empty but the connection is still open
break Ok(()); // TODO: check that Ok(()) is a good return value
}
Err(TryRecvError::Disconnected) => {
unreachable!("the channel should not be dropped before the task is aborted")
}
}
}
}
}
Loading