Skip to content

Commit

Permalink
assistant2: Improve tracking of pending completions (zed-industries#2…
Browse files Browse the repository at this point in the history
…1186)

This PR improves the tracking of pending completions in `assistant2`
such that we actually remove ones that have been completed.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Nov 25, 2024
1 parent 2b92508 commit cc5daa2
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions crates/assistant2/src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, StopReason,
};
use util::ResultExt as _;
use util::{post_inc, ResultExt as _};

#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
Expand All @@ -19,17 +19,24 @@ pub struct Message {
pub text: String,
}

struct PendingCompletion {
id: usize,
_task: Task<()>,
}

/// A thread of conversation with the LLM.
pub struct Thread {
messages: Vec<Message>,
pending_completion_tasks: Vec<Task<()>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
}

impl Thread {
pub fn new(_cx: &mut ModelContext<Self>) -> Self {
Self {
messages: Vec::new(),
pending_completion_tasks: Vec::new(),
completion_count: 0,
pending_completions: Vec::new(),
}
}

Expand Down Expand Up @@ -79,7 +86,9 @@ impl Thread {
model: Arc<dyn LanguageModel>,
cx: &mut ModelContext<Self>,
) {
let task = cx.spawn(|this, mut cx| async move {
let pending_completion_id = post_inc(&mut self.completion_count);

let task = cx.spawn(|thread, mut cx| async move {
let stream = model.stream_completion(request, &cx);
let stream_completion = async {
let mut events = stream.await?;
Expand All @@ -88,7 +97,7 @@ impl Thread {
while let Some(event) = events.next().await {
let event = event?;

this.update(&mut cx, |thread, cx| {
thread.update(&mut cx, |thread, cx| {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
thread.messages.push(Message {
Expand Down Expand Up @@ -116,14 +125,23 @@ impl Thread {
smol::future::yield_now().await;
}

thread.update(&mut cx, |thread, _cx| {
thread
.pending_completions
.retain(|completion| completion.id != pending_completion_id);
})?;

anyhow::Ok(stop_reason)
};

let result = stream_completion.await;
let _ = result.log_err();
});

self.pending_completion_tasks.push(task);
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
_task: task,
});
}
}

Expand Down

0 comments on commit cc5daa2

Please sign in to comment.