Skip to content

Commit

Permalink
Factor tool definitions out of assistant (zed-industries#21189)
Browse files Browse the repository at this point in the history
This PR factors the tool definitions out of the `assistant` crate so
that they can be shared between `assistant` and `assistant2`.

`ToolWorkingSet` now lives in `assistant_tool`. The tool definitions
themselves live in `assistant_tools`, with the exception of the
`ContextServerTool`, which has been moved to the `context_server` crate.

As part of this refactoring I needed to extract the
`ContextServerSettings` to a separate `context_server_settings` crate so
that the `extension_host`—which is referenced by the `remote_server`—can
name the `ContextServerSettings` type without pulling in some undesired
dependencies.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Nov 25, 2024
1 parent 321fd19 commit 3901d46
Show file tree
Hide file tree
Showing 35 changed files with 219 additions and 113 deletions.
42 changes: 37 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ members = [
"crates/assistant2",
"crates/assistant_slash_command",
"crates/assistant_tool",
"crates/assistant_tools",
"crates/audio",
"crates/auto_update",
"crates/auto_update_ui",
Expand All @@ -22,7 +23,8 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
"crates/context_servers",
"crates/context_server",
"crates/context_server_settings",
"crates/copilot",
"crates/db",
"crates/diagnostics",
Expand Down Expand Up @@ -191,6 +193,7 @@ assistant = { path = "crates/assistant" }
assistant2 = { path = "crates/assistant2" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_tool = { path = "crates/assistant_tool" }
assistant_tools = { path = "crates/assistant_tools" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
auto_update_ui = { path = "crates/auto_update_ui" }
Expand All @@ -205,7 +208,8 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
context_servers = { path = "crates/context_servers" }
context_server = { path = "crates/context_server" }
context_server_settings = { path = "crates/context_server_settings" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
diagnostics = { path = "crates/diagnostics" }
Expand Down
2 changes: 1 addition & 1 deletion crates/assistant/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_servers.workspace = true
context_server.workspace = true
db.workspace = true
editor.workspace = true
feature_flags.workspace = true
Expand Down
12 changes: 1 addition & 11 deletions crates/assistant/src/assistant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,12 @@ pub mod slash_command_settings;
mod slash_command_working_set;
mod streaming_diff;
mod terminal_inline_assistant;
mod tool_working_set;
mod tools;

use crate::slash_command::project_command::ProjectSlashCommandFeatureFlag;
pub use crate::slash_command_working_set::{SlashCommandId, SlashCommandWorkingSet};
pub use crate::tool_working_set::{ToolId, ToolWorkingSet};
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
use assistant_tool::ToolRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub use context::*;
Expand Down Expand Up @@ -246,7 +242,7 @@ pub fn init(
assistant_slash_command::init(cx);
assistant_tool::init(cx);
assistant_panel::init(cx);
context_servers::init(cx);
context_server::init(cx);

let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
fs: fs.clone(),
Expand All @@ -259,7 +255,6 @@ pub fn init(
.map(Arc::new)
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
register_slash_commands(Some(prompt_builder.clone()), cx);
register_tools(cx);
inline_assistant::init(
fs.clone(),
prompt_builder.clone(),
Expand Down Expand Up @@ -423,11 +418,6 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
}
}

fn register_tools(cx: &mut AppContext) {
let tool_registry = ToolRegistry::global(cx);
tool_registry.register_tool(tools::now_tool::NowTool);
}

pub fn humanize_token_count(count: usize) -> String {
match count {
0..=999 => count.to_string(),
Expand Down
4 changes: 2 additions & 2 deletions crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::slash_command::file_command::codeblock_fence_for_path;
use crate::slash_command_working_set::SlashCommandWorkingSet;
use crate::ToolWorkingSet;
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings},
humanize_token_count,
Expand All @@ -23,6 +22,7 @@ use crate::{
};
use anyhow::Result;
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use assistant_tool::ToolWorkingSet;
use client::{proto, zed_urls, Client, Status};
use collections::{hash_map, BTreeSet, HashMap, HashSet};
use editor::{
Expand Down Expand Up @@ -1316,7 +1316,7 @@ impl AssistantPanel {

fn restart_context_servers(
workspace: &mut Workspace,
_action: &context_servers::Restart,
_action: &context_server::Restart,
cx: &mut ViewContext<Workspace>,
) {
let Some(assistant_panel) = workspace.panel::<AssistantPanel>(cx) else {
Expand Down
2 changes: 1 addition & 1 deletion crates/assistant/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
mod context_tests;

use crate::slash_command_working_set::SlashCommandWorkingSet;
use crate::ToolWorkingSet;
use crate::{
prompts::PromptBuilder,
slash_command::{file_command::FileCommandMetadata, SlashCommandLine},
Expand All @@ -12,6 +11,7 @@ use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandResult,
};
use assistant_tool::ToolWorkingSet;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
Expand Down
2 changes: 1 addition & 1 deletion crates/assistant/src/context/context_tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::{AssistantEdit, MessageCacheMetadata};
use crate::slash_command_working_set::SlashCommandWorkingSet;
use crate::ToolWorkingSet;
use crate::{
assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId,
Expand All @@ -11,6 +10,7 @@ use assistant_slash_command::{
ArgumentCompletion, SlashCommand, SlashCommandContent, SlashCommandEvent, SlashCommandOutput,
SlashCommandOutputSection, SlashCommandRegistry, SlashCommandResult,
};
use assistant_tool::ToolWorkingSet;
use collections::{HashMap, HashSet};
use fs::FakeFs;
use futures::{
Expand Down
19 changes: 10 additions & 9 deletions crates/assistant/src/context_store.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::slash_command::context_server_command;
use crate::SlashCommandId;
use crate::{
prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context,
ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata,
};
use crate::{tools, SlashCommandId, ToolId, ToolWorkingSet};
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{ToolId, ToolWorkingSet};
use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
use clock::ReplicaId;
use collections::HashMap;
use context_servers::manager::ContextServerManager;
use context_servers::ContextServerFactoryRegistry;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use fs::Fs;
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
Expand Down Expand Up @@ -808,13 +809,13 @@ impl ContextStore {
fn handle_context_server_event(
&mut self,
context_server_manager: Model<ContextServerManager>,
event: &context_servers::manager::Event,
event: &context_server::manager::Event,
cx: &mut ModelContext<Self>,
) {
let slash_command_working_set = self.slash_commands.clone();
let tool_working_set = self.tools.clone();
match event {
context_servers::manager::Event::ServerStarted { server_id } => {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
Expand All @@ -825,7 +826,7 @@ impl ContextStore {
return;
};

if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
Expand Down Expand Up @@ -853,12 +854,12 @@ impl ContextStore {
}
}

if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools.tools.into_iter().map(|tool| {
log::info!("registering context server tool: {:?}", tool.name);
tool_working_set.insert(
Arc::new(tools::context_server_tool::ContextServerTool::new(
Arc::new(ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
Expand All @@ -880,7 +881,7 @@ impl ContextStore {
.detach();
}
}
context_servers::manager::Event::ServerStopped { server_id } => {
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
Expand Down
12 changes: 6 additions & 6 deletions crates/assistant/src/slash_command/context_server_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use assistant_slash_command::{
SlashCommandOutputSection, SlashCommandResult,
};
use collections::HashMap;
use context_servers::{
use context_server::{
manager::{ContextServer, ContextServerManager},
types::Prompt,
};
Expand Down Expand Up @@ -95,9 +95,9 @@ impl SlashCommand for ContextServerSlashCommand {

let completion_result = protocol
.completion(
context_servers::types::CompletionReference::Prompt(
context_servers::types::PromptReference {
r#type: context_servers::types::PromptReferenceType::Prompt,
context_server::types::CompletionReference::Prompt(
context_server::types::PromptReference {
r#type: context_server::types::PromptReferenceType::Prompt,
name: prompt_name,
},
),
Expand Down Expand Up @@ -152,7 +152,7 @@ impl SlashCommand for ContextServerSlashCommand {
if result
.messages
.iter()
.any(|msg| !matches!(msg.role, context_servers::types::Role::User))
.any(|msg| !matches!(msg.role, context_server::types::Role::User))
{
return Err(anyhow!(
"Prompt contains non-user roles, which is not supported"
Expand All @@ -164,7 +164,7 @@ impl SlashCommand for ContextServerSlashCommand {
.messages
.into_iter()
.filter_map(|msg| match msg.content {
context_servers::types::MessageContent::Text { text } => Some(text),
context_server::types::MessageContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
Expand Down
2 changes: 0 additions & 2 deletions crates/assistant/src/tools.rs

This file was deleted.

4 changes: 3 additions & 1 deletion crates/assistant_tool/src/assistant_tool.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
mod tool_registry;
mod tool_working_set;

use std::sync::Arc;

use anyhow::Result;
use gpui::{AppContext, Task, WeakView, WindowContext};
use workspace::Workspace;

pub use tool_registry::*;
pub use crate::tool_registry::*;
pub use crate::tool_working_set::*;

pub fn init(cx: &mut AppContext) {
ToolRegistry::default_global(cx);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use assistant_tool::{Tool, ToolRegistry};
use std::sync::Arc;

use collections::HashMap;
use gpui::AppContext;
use parking_lot::Mutex;
use std::sync::Arc;

use crate::{Tool, ToolRegistry};

#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
Expand Down
Loading

0 comments on commit 3901d46

Please sign in to comment.