Skip to content

Commit

Permalink
Merge pull request #254 from resoluteCoder/add-cancel-to-streaming-ge…
Browse files Browse the repository at this point in the history
…neration

added ability to cancel streaming generation when client disconnects
  • Loading branch information
gkumbhat authored Nov 20, 2024
2 parents 774cdaa + d6d36ff commit a4964b1
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions src/orchestrator/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use axum::http::HeaderMap;
use futures::{future::try_join_all, Stream, StreamExt, TryStreamExt};
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tracing::{debug, error, info, instrument};
use tracing::{debug, error, info, instrument, warn};

use super::{get_chunker_ids, Context, Error, Orchestrator, StreamingClassificationWithGenTask};
use crate::{
Expand Down Expand Up @@ -192,7 +192,12 @@ impl Orchestrator {
match result {
Some(result) => {
debug!(%trace_id, ?result, "sending result to client");
let _ = response_tx.send(result).await;
if (response_tx.send(result).await).is_err() {
warn!(%trace_id, "response channel closed (client disconnected), terminating task");
// Broadcast cancellation signal to tasks
let _ = error_tx.send(Error::Cancelled);
return;
}
},
None => {
info!(%trace_id, "task completed: stream closed");
Expand All @@ -209,7 +214,10 @@ impl Orchestrator {
tokio::spawn(async move {
while let Some(result) = generation_stream.next().await {
debug!(%trace_id, ?result, "sending result to client");
let _ = response_tx.send(result).await;
if (response_tx.send(result).await).is_err() {
warn!(%trace_id, "response channel closed (client disconnected), terminating task");
return;
}
}
debug!(%trace_id, "task completed: stream closed");
});
Expand Down Expand Up @@ -337,7 +345,10 @@ async fn generation_broadcast_task(
let mut error_rx = error_tx.subscribe();
loop {
tokio::select! {
_ = error_rx.recv() => { break },
_ = error_rx.recv() => {
warn!("cancellation signal received, terminating task");
break
},
result = generation_stream.next() => {
match result {
Some(Ok(generation)) => {
Expand Down Expand Up @@ -380,7 +391,10 @@ async fn detection_task(

loop {
tokio::select! {
_ = error_rx.recv() => { break },
_ = error_rx.recv() => {
warn!("cancellation signal received, terminating task");
break
},
result = chunk_rx.recv() => {
match result {
Ok(chunk) => {
Expand Down Expand Up @@ -508,7 +522,10 @@ async fn chunk_broadcast_task(
async move {
loop {
tokio::select! {
_ = error_rx.recv() => { break },
_ = error_rx.recv() => {
warn!("cancellation signal received, terminating task");
break
},
result = output_stream.next() => {
match result {
Some(Ok(chunk)) => {
Expand Down

0 comments on commit a4964b1

Please sign in to comment.