Skip to content

Commit

Permalink
Updated heartbeat handler to correctly format packets
Browse files Browse the repository at this point in the history
  • Loading branch information
ajmcquilkin committed Mar 15, 2024
1 parent 56cdb2e commit c2b4622
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 35 deletions.
50 changes: 26 additions & 24 deletions src/connections/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use std::sync::Arc;

use crate::errors_internal::{Error, InternalChannelError, InternalStreamError};
use crate::protobufs;
use crate::types::EncodedToRadioPacketWithHeader;
use crate::utils::format_data_packet;
use log::{debug, error, trace};
use prost::Message;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::spawn;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

Expand Down Expand Up @@ -90,7 +88,7 @@ where

pub fn spawn_write_handler<W>(
cancellation_token: CancellationToken,
write_stream: Arc<Mutex<W>>,
write_stream: W,
write_input_rx: tokio::sync::mpsc::UnboundedReceiver<EncodedToRadioPacketWithHeader>,
) -> JoinHandle<Result<(), Error>>
where
Expand All @@ -116,7 +114,7 @@ where

async fn start_write_handler<W>(
_cancellation_token: CancellationToken,
write_stream: Arc<Mutex<W>>,
mut write_stream: W,
mut write_input_rx: tokio::sync::mpsc::UnboundedReceiver<EncodedToRadioPacketWithHeader>,
) -> Result<(), Error>
where
Expand All @@ -127,8 +125,6 @@ where
while let Some(message) = write_input_rx.recv().await {
trace!("Writing packet data: {:?}", message);

let mut write_stream = write_stream.lock().await;

if let Err(e) = write_stream.write(message.data()).await {
error!("Error writing to stream: {:?}", e);
return Err(Error::InternalStreamError(
Expand Down Expand Up @@ -181,14 +177,11 @@ async fn start_processing_handler(
trace!("Processing read_output_rx channel closed");
}

pub fn spawn_heartbeat_handler<W>(
pub fn spawn_heartbeat_handler(
cancellation_token: CancellationToken,
write_stream: Arc<Mutex<W>>,
) -> JoinHandle<Result<(), Error>>
where
W: AsyncWriteExt + Send + Unpin + 'static,
{
let handle = start_heartbeat_handler(cancellation_token.clone(), write_stream);
write_input_tx: UnboundedSender<EncodedToRadioPacketWithHeader>,
) -> JoinHandle<Result<(), Error>> {
let handle = start_heartbeat_handler(cancellation_token.clone(), write_input_tx);

spawn(async move {
tokio::select! {
Expand All @@ -206,21 +199,20 @@ where
})
}

async fn start_heartbeat_handler<W>(
async fn start_heartbeat_handler(
_cancellation_token: CancellationToken,
write_stream: Arc<Mutex<W>>,
) -> Result<(), Error>
where
W: AsyncWriteExt + Send + Unpin + 'static,
{
write_input_tx: UnboundedSender<EncodedToRadioPacketWithHeader>,
) -> Result<(), Error> {
debug!("Started heartbeat handler");

loop {
tokio::time::sleep(std::time::Duration::from_secs(CLIENT_HEARTBEAT_INTERVAL)).await;

let mut write_stream = write_stream.lock().await;

let heartbeat_packet = protobufs::Heartbeat::default();
let heartbeat_packet = protobufs::ToRadio {
payload_variant: Some(protobufs::to_radio::PayloadVariant::Heartbeat(
protobufs::Heartbeat::default(),
)),
};

let mut buffer = Vec::new();
match heartbeat_packet.encode(&mut buffer) {
Expand All @@ -231,7 +223,17 @@ where
}
};

if let Err(e) = write_stream.write(&buffer).await {
let packet_with_header = match format_data_packet(buffer.into()) {
Ok(p) => p,
Err(e) => {
error!("Error formatting heartbeat packet: {:?}", e);
continue;
}
};

trace!("Sending heartbeat packet");

if let Err(e) = write_input_tx.send(packet_with_header) {
error!("Error writing heartbeat packet to stream: {:?}", e);
return Err(Error::InternalStreamError(
InternalStreamError::StreamWriteError {
Expand Down
15 changes: 5 additions & 10 deletions src/connections/stream_api.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use futures_util::future::join3;
use log::trace;
use prost::Message;
use std::{fmt::Display, marker::PhantomData, sync::Arc};
use std::{fmt::Display, marker::PhantomData};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::{mpsc::UnboundedSender, Mutex},
sync::mpsc::UnboundedSender,
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
Expand Down Expand Up @@ -438,16 +438,11 @@ impl StreamApi {
let (read_stream, write_stream) = tokio::io::split(stream_handle.stream);
let cancellation_token = CancellationToken::new();

let write_stream_mutex = Arc::new(Mutex::new(write_stream));

let read_handle =
handlers::spawn_read_handler(cancellation_token.clone(), read_stream, read_output_tx);

let write_handle = handlers::spawn_write_handler(
cancellation_token.clone(),
write_stream_mutex.clone(),
write_input_rx,
);
let write_handle =
handlers::spawn_write_handler(cancellation_token.clone(), write_stream, write_input_rx);

let processing_handle = handlers::spawn_processing_handler(
cancellation_token.clone(),
Expand All @@ -456,7 +451,7 @@ impl StreamApi {
);

let heartbeat_handle =
handlers::spawn_heartbeat_handler(cancellation_token.clone(), write_stream_mutex);
handlers::spawn_heartbeat_handler(cancellation_token.clone(), write_input_tx.clone());

// Persist channels and kill switch to struct

Expand Down
2 changes: 1 addition & 1 deletion src/protobufs

0 comments on commit c2b4622

Please sign in to comment.