From 3c9a4834cfd0b518e721b758eb4bd256ffc050dc Mon Sep 17 00:00:00 2001 From: resoluteCoder Date: Thu, 14 Nov 2024 16:11:36 -0600 Subject: [PATCH 1/3] added ability to cancel streaming generation when client disconnects Signed-off-by: resoluteCoder --- src/orchestrator/streaming.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index 861ab354..197878d8 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -209,7 +209,15 @@ impl Orchestrator { tokio::spawn(async move { while let Some(result) = generation_stream.next().await { debug!(%trace_id, ?result, "sending result to client"); - let _ = response_tx.send(result).await; + let response = response_tx.send(result).await; + match response { + Err(e) => { + debug!(%trace_id, "could not send to client: {e}"); + let _ = response_tx.send(Err(Error::Cancelled)); + return; + } + _ => {} + } } debug!(%trace_id, "task completed: stream closed"); }); From 67852d8e2b43f004e0d2aab94abd526cfd817837 Mon Sep 17 00:00:00 2001 From: resoluteCoder Date: Mon, 18 Nov 2024 14:37:49 -0600 Subject: [PATCH 2/3] lint and used if let syntax Signed-off-by: resoluteCoder --- src/orchestrator/streaming.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index 197878d8..0e5b67be 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -210,13 +210,10 @@ impl Orchestrator { while let Some(result) = generation_stream.next().await { debug!(%trace_id, ?result, "sending result to client"); let response = response_tx.send(result).await; - match response { - Err(e) => { - debug!(%trace_id, "could not send to client: {e}"); - let _ = response_tx.send(Err(Error::Cancelled)); - return; - } - _ => {} + if let Err(e) = response { + debug!(%trace_id, "could not send to client: {e}"); + let _ = response_tx.send(Err(Error::Cancelled)).await; + return; } } debug!(%trace_id, "task completed: stream closed"); From d6d36ffd48cba4d1c94e3de62ba0ce958385859c Mon Sep 17 00:00:00 2001 From: resoluteCoder Date: Tue, 19 Nov 2024 13:38:55 -0600 Subject: [PATCH 3/3] added cancellation for output detections as well Signed-off-by: resoluteCoder --- src/orchestrator/streaming.rs | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs index 0e5b67be..3d5507d1 100644 --- a/src/orchestrator/streaming.rs +++ b/src/orchestrator/streaming.rs @@ -24,7 +24,7 @@ use axum::http::HeaderMap; use futures::{future::try_join_all, Stream, StreamExt, TryStreamExt}; use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; -use tracing::{debug, error, info, instrument}; +use tracing::{debug, error, info, instrument, warn}; use super::{get_chunker_ids, Context, Error, Orchestrator, StreamingClassificationWithGenTask}; use crate::{ @@ -192,7 +192,12 @@ impl Orchestrator { match result { Some(result) => { debug!(%trace_id, ?result, "sending result to client"); - let _ = response_tx.send(result).await; + if (response_tx.send(result).await).is_err() { + warn!(%trace_id, "response channel closed (client disconnected), terminating task"); + // Broadcast cancellation signal to tasks + let _ = error_tx.send(Error::Cancelled); + return; + } }, None => { info!(%trace_id, "task completed: stream closed"); @@ -209,10 +214,8 @@ impl Orchestrator { tokio::spawn(async move { while let Some(result) = generation_stream.next().await { debug!(%trace_id, ?result, "sending result to client"); - let response = response_tx.send(result).await; - if let Err(e) = response { - debug!(%trace_id, "could not send to client: {e}"); - let _ = response_tx.send(Err(Error::Cancelled)).await; + if (response_tx.send(result).await).is_err() { + warn!(%trace_id, "response channel closed (client disconnected), terminating task"); return; } } @@ -342,7 +345,10 @@ async fn generation_broadcast_task( let mut error_rx = error_tx.subscribe(); loop { tokio::select! { - _ = error_rx.recv() => { break }, + _ = error_rx.recv() => { + warn!("cancellation signal received, terminating task"); + break + }, result = generation_stream.next() => { match result { Some(Ok(generation)) => { @@ -385,7 +391,10 @@ async fn detection_task( loop { tokio::select! { - _ = error_rx.recv() => { break }, + _ = error_rx.recv() => { + warn!("cancellation signal received, terminating task"); + break + }, result = chunk_rx.recv() => { match result { Ok(chunk) => { @@ -513,7 +522,10 @@ async fn chunk_broadcast_task( async move { loop { tokio::select! { - _ = error_rx.recv() => { break }, + _ = error_rx.recv() => { + warn!("cancellation signal received, terminating task"); + break + }, result = output_stream.next() => { match result { Some(Ok(chunk)) => {