From cc5daa22bdf8e549becd5d33e4eb29b72f149c04 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 25 Nov 2024 17:07:55 -0500 Subject: [PATCH] assistant2: Improve tracking of pending completions (#21186) This PR improves the tracking of pending completions in `assistant2` such that we actually remove ones that have been completed. Release Notes: - N/A --- crates/assistant2/src/thread.rs | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index a433c10267ffe..c1df6c76d35d5 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -6,7 +6,7 @@ use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, StopReason, }; -use util::ResultExt as _; +use util::{post_inc, ResultExt as _}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -19,17 +19,24 @@ pub struct Message { pub text: String, } +struct PendingCompletion { + id: usize, + _task: Task<()>, +} + /// A thread of conversation with the LLM. pub struct Thread { messages: Vec, - pending_completion_tasks: Vec>, + completion_count: usize, + pending_completions: Vec, } impl Thread { pub fn new(_cx: &mut ModelContext) -> Self { Self { messages: Vec::new(), - pending_completion_tasks: Vec::new(), + completion_count: 0, + pending_completions: Vec::new(), } } @@ -79,7 +86,9 @@ impl Thread { model: Arc, cx: &mut ModelContext, ) { - let task = cx.spawn(|this, mut cx| async move { + let pending_completion_id = post_inc(&mut self.completion_count); + + let task = cx.spawn(|thread, mut cx| async move { let stream = model.stream_completion(request, &cx); let stream_completion = async { let mut events = stream.await?; @@ -88,7 +97,7 @@ impl Thread { while let Some(event) = events.next().await { let event = event?; - this.update(&mut cx, |thread, cx| { + thread.update(&mut cx, |thread, cx| { match event { LanguageModelCompletionEvent::StartMessage { .. } => { thread.messages.push(Message { @@ -116,6 +125,12 @@ impl Thread { smol::future::yield_now().await; } + thread.update(&mut cx, |thread, _cx| { + thread + .pending_completions + .retain(|completion| completion.id != pending_completion_id); + })?; + anyhow::Ok(stop_reason) }; @@ -123,7 +138,10 @@ impl Thread { let _ = result.log_err(); }); - self.pending_completion_tasks.push(task); + self.pending_completions.push(PendingCompletion { + id: pending_completion_id, + _task: task, + }); } }