Skip to content

Commit

Permalink
assistant2: Add support for using tools (zed-industries#21190)
Browse files Browse the repository at this point in the history
This PR adds rudimentary support for using tools to `assistant2`. There
are currently no visual affordances for tool use.

This is gated behind the `assistant-tool-use` feature flag.

<img width="1079" alt="Screenshot 2024-11-25 at 7 21 31 PM"
src="https://github.com/user-attachments/assets/64d6ca29-c592-4474-8e9d-c344f855bc63">

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Nov 26, 2024
1 parent 3901d46 commit f059b6a
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 37 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 1 addition & 11 deletions crates/assistant/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use assistant_tool::ToolWorkingSet;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use fs::{Fs, RemoveOptions};
use futures::{future::Shared, FutureExt, StreamExt};
use gpui::{
Expand Down Expand Up @@ -3201,16 +3201,6 @@ pub enum PendingSlashCommandStatus {
Error(String),
}

pub(crate) struct ToolUseFeatureFlag;

impl FeatureFlag for ToolUseFeatureFlag {
const NAME: &'static str = "assistant-tool-use";

fn enabled_for_staff() -> bool {
false
}
}

#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: Arc<str>,
Expand Down
3 changes: 3 additions & 0 deletions crates/assistant2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ doctest = false

[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
editor.workspace = true
feature_flags.workspace = true
Expand All @@ -23,6 +25,7 @@ language_model.workspace = true
language_model_selector.workspace = true
proto.workspace = true
settings.workspace = true
serde_json.workspace = true
smol.workspace = true
theme.workspace = true
ui.workspace = true
Expand Down
61 changes: 54 additions & 7 deletions crates/assistant2/src/assistant_panel.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::sync::Arc;

use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use gpui::{
prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
Expand All @@ -10,7 +13,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace;

use crate::message_editor::MessageEditor;
use crate::thread::Thread;
use crate::thread::{Thread, ThreadEvent};
use crate::{NewThread, ToggleFocus, ToggleModelSelector};

pub fn init(cx: &mut AppContext) {
Expand All @@ -25,8 +28,10 @@ pub fn init(cx: &mut AppContext) {
}

pub struct AssistantPanel {
workspace: WeakView<Workspace>,
thread: Model<Thread>,
message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>,
_subscriptions: Vec<Subscription>,
}

Expand All @@ -36,33 +41,75 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
let tools = Arc::new(ToolWorkingSet::default());
workspace.update(&mut cx, |workspace, cx| {
cx.new_view(|cx| Self::new(workspace, cx))
cx.new_view(|cx| Self::new(workspace, tools, cx))
})
})
}

fn new(_workspace: &Workspace, cx: &mut ViewContext<Self>) -> Self {
let thread = cx.new_model(Thread::new);
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];

Self {
workspace: workspace.weak_handle(),
thread: thread.clone(),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
tools,
_subscriptions: subscriptions,
}
}

fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
let thread = cx.new_model(Thread::new);
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
let tools = self.thread.read(cx).tools().clone();
let thread = cx.new_model(|cx| Thread::new(tools, cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];

self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
self.thread = thread;
self._subscriptions = subscriptions;

self.message_editor.focus_handle(cx).focus(cx);
}

fn handle_thread_event(
&mut self,
_: Model<Thread>,
event: &ThreadEvent,
cx: &mut ViewContext<Self>,
) {
match event {
ThreadEvent::StreamedCompletion => {}
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
.read(cx)
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();

for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), cx);

self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx);
});
}
}
}
ThreadEvent::ToolFinished { .. } => {}
}
}
}

impl FocusableView for AssistantPanel {
Expand Down
19 changes: 17 additions & 2 deletions crates/assistant2/src/message_editor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use editor::{Editor, EditorElement, EditorStyle};
use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use gpui::{AppContext, FocusableView, Model, TextStyle, View};
use language_model::LanguageModelRegistry;
use language_model::{LanguageModelRegistry, LanguageModelRequestTool};
use settings::Settings;
use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
Expand Down Expand Up @@ -55,7 +56,21 @@ impl MessageEditor {

self.thread.update(cx, |thread, cx| {
thread.insert_user_message(user_message);
let request = thread.to_completion_request(request_kind, cx);
let mut request = thread.to_completion_request(request_kind, cx);

if cx.has_flag::<ToolUseFeatureFlag>() {
request.tools = thread
.tools()
.tools(cx)
.into_iter()
.map(|tool| LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(),
})
.collect();
}

thread.stream_completion(request, model, cx)
});

Expand Down
Loading

0 comments on commit f059b6a

Please sign in to comment.