From 45a48f411c34032065fb49e29b04dee185203dcb Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Thu, 2 Jan 2025 15:20:34 -0800 Subject: [PATCH] Add TokenCounter to Agent struct to only instantiate once --- crates/goose/src/agent.rs | 42 ++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/crates/goose/src/agent.rs b/crates/goose/src/agent.rs index 5c02bf3be..986793c76 100644 --- a/crates/goose/src/agent.rs +++ b/crates/goose/src/agent.rs @@ -55,6 +55,7 @@ pub struct Agent { systems: Vec>, provider: Box, provider_usage: Mutex>, + token_counter: TokenCounter, } #[allow(dead_code)] @@ -65,6 +66,7 @@ impl Agent { systems: Vec::new(), provider, provider_usage: Mutex::new(Vec::new()), + token_counter: TokenCounter::new(), } } @@ -170,6 +172,7 @@ impl Agent { messages: &[Message], pending: &Vec, target_limit: usize, + token_counter: &TokenCounter, ) -> AgentResult> { // Prepares the inference by managing context window and token budget. // This function: @@ -191,7 +194,6 @@ impl Agent { // Returns: // * `AgentResult>` - Updated message history with status appended - let token_counter = TokenCounter::new(); let resource_content = self.get_systems_resources().await?; // Flatten all resource content into a vector of strings @@ -209,13 +211,12 @@ impl Agent { &resources, Some(&self.provider.get_model_config().model_name), ); - let mut status_content: Vec = Vec::new(); if approx_count > target_limit { println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. Removing context", approx_count, approx_count - target_limit); - // Get token counts for each resourcee + // Get token counts for each resource let mut system_token_counts = HashMap::new(); // Iterate through each system and its resources @@ -340,6 +341,7 @@ impl Agent { &messages, &Vec::new(), estimated_limit, + &self.token_counter, ) .await?; @@ -399,7 +401,7 @@ impl Agent { messages.pop(); let pending = vec![response, message_tool_response]; - messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit).await?; + messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &self.token_counter).await?; } })) } @@ -687,13 +689,20 @@ mod tests { let messages = vec![Message::user().with_text("Hi there")]; let tools = vec![]; let pending = vec![]; - + let token_counter = TokenCounter::new(); // Approx count is 40, so target limit of 35 will force trimming let target_limit = 35; // Call prepare_inference let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) + .prepare_inference( + system_prompt, + &tools, + &messages, + &pending, + target_limit, + &token_counter, + ) .await?; // Get the last message which should be the tool response containing status @@ -710,10 +719,18 @@ mod tests { // Now test with a target limit that allows both resources (no trimming) let target_limit = 100; + let token_counter = TokenCounter::new(); // Call prepare_inference let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) + .prepare_inference( + system_prompt, + &tools, + &messages, + &pending, + target_limit, + &token_counter, + ) .await?; // Get the last message which should be the tool response containing status @@ -755,14 +772,21 @@ mod tests { let messages = vec![Message::user().with_text("Hi there")]; let tools = vec![]; let pending = vec![]; - + let token_counter = TokenCounter::new(); // Use the context limit from the model config let target_limit = agent.get_context_limit(); assert_eq!(target_limit, 20, "Context limit should be 20"); // Call prepare_inference let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) + .prepare_inference( + system_prompt, + &tools, + &messages, + &pending, + target_limit, + &token_counter, + ) .await?; // Get the last message which should be the tool response containing status