Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Latency issue fix with prepare_inference #535

Merged
merged 1 commit into from
Jan 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions crates/goose/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub struct Agent {
systems: Vec<Box<dyn System>>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
token_counter: TokenCounter,
}

#[allow(dead_code)]
Expand All @@ -65,6 +66,7 @@ impl Agent {
systems: Vec::new(),
provider,
provider_usage: Mutex::new(Vec::new()),
token_counter: TokenCounter::new(),
}
}

Expand Down Expand Up @@ -170,6 +172,7 @@ impl Agent {
messages: &[Message],
pending: &Vec<Message>,
target_limit: usize,
token_counter: &TokenCounter,
) -> AgentResult<Vec<Message>> {
// Prepares the inference by managing context window and token budget.
// This function:
Expand All @@ -191,7 +194,6 @@ impl Agent {
// Returns:
// * `AgentResult<Vec<Message>>` - Updated message history with status appended

let token_counter = TokenCounter::new();
let resource_content = self.get_systems_resources().await?;

// Flatten all resource content into a vector of strings
Expand All @@ -209,13 +211,12 @@ impl Agent {
&resources,
Some(&self.provider.get_model_config().model_name),
);

let mut status_content: Vec<String> = Vec::new();

if approx_count > target_limit {
println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over budget. Removing context", approx_count, approx_count - target_limit);

// Get token counts for each resourcee
// Get token counts for each resource
let mut system_token_counts = HashMap::new();

// Iterate through each system and its resources
Expand Down Expand Up @@ -340,6 +341,7 @@ impl Agent {
&messages,
&Vec::new(),
estimated_limit,
&self.token_counter,
)
.await?;

Expand Down Expand Up @@ -399,7 +401,7 @@ impl Agent {
messages.pop();

let pending = vec![response, message_tool_response];
messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit).await?;
messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &self.token_counter).await?;
}
}))
}
Expand Down Expand Up @@ -687,13 +689,20 @@ mod tests {
let messages = vec![Message::user().with_text("Hi there")];
let tools = vec![];
let pending = vec![];

let token_counter = TokenCounter::new();
// Approx count is 40, so target limit of 35 will force trimming
let target_limit = 35;

// Call prepare_inference
let result = agent
.prepare_inference(system_prompt, &tools, &messages, &pending, target_limit)
.prepare_inference(
system_prompt,
&tools,
&messages,
&pending,
target_limit,
&token_counter,
)
.await?;

// Get the last message which should be the tool response containing status
Expand All @@ -710,10 +719,18 @@ mod tests {

// Now test with a target limit that allows both resources (no trimming)
let target_limit = 100;
let token_counter = TokenCounter::new();

// Call prepare_inference
let result = agent
.prepare_inference(system_prompt, &tools, &messages, &pending, target_limit)
.prepare_inference(
system_prompt,
&tools,
&messages,
&pending,
target_limit,
&token_counter,
)
.await?;

// Get the last message which should be the tool response containing status
Expand Down Expand Up @@ -755,14 +772,21 @@ mod tests {
let messages = vec![Message::user().with_text("Hi there")];
let tools = vec![];
let pending = vec![];

let token_counter = TokenCounter::new();
// Use the context limit from the model config
let target_limit = agent.get_context_limit();
assert_eq!(target_limit, 20, "Context limit should be 20");

// Call prepare_inference
let result = agent
.prepare_inference(system_prompt, &tools, &messages, &pending, target_limit)
.prepare_inference(
system_prompt,
&tools,
&messages,
&pending,
target_limit,
&token_counter,
)
.await?;

// Get the last message which should be the tool response containing status
Expand Down
Loading