Commit 3c05de4 (1 parent: 1e41250)
Showing 8 changed files with 342 additions and 43 deletions.
@@ -0,0 +1,80 @@
from typing import List

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field

from app.modules.conversations.message.message_schema import MessageResponse


class AgentClassification(BaseModel):
    agent_id: str = Field(..., description="ID of the agent that should handle the query")
    confidence: float = Field(..., description="Confidence score between 0 and 1")
    reasoning: str = Field(..., description="Reasoning behind the agent selection")


class AgentClassifier:
    def __init__(self, llm, available_agents):
        self.llm = llm
        self.available_agents = available_agents
        self.parser = PydanticOutputParser(pydantic_object=AgentClassification)

    def create_prompt(self) -> str:
        """Return the routing prompt template used to pick an agent."""
        return """You are an expert agent router.

        User's query: {query}
        Conversation history: {history}

        Based on the user's query and conversation history,
        select the most appropriate agent from the following options:

        Available Agents:
        {agents_desc}

        Analyze the query and select the agent that best matches the user's needs.
        Consider:
        1. The specific task or question type
        2. Required expertise
        3. Context from conversation history
        4. Any explicit agent requests

        {format_instructions}
        """

    async def classify(self, messages: List[MessageResponse]) -> AgentClassification:
        """Classify the conversation and determine which agent should handle it."""
        if not messages:
            return AgentClassification(
                agent_id=self.available_agents[0].id,  # Default to the first agent
                confidence=0.0,
                reasoning="No messages to classify",
            )

        # Format agent descriptions for the prompt
        agents_desc = "\n".join(
            f"{i + 1}. {agent.id}: {agent.description}"
            for i, agent in enumerate(self.available_agents)
        )

        # Use the last message as the query and up to 10 messages of history
        last_message = messages[-1].content
        history = [msg.content for msg in messages[-10:]] if len(messages) > 1 else []

        inputs = {
            "query": last_message,
            "history": history,
            "agents_desc": agents_desc,
            "format_instructions": self.parser.get_format_instructions(),
        }

        # Build the prompt, run it through the LLM, and parse the structured output
        prompt = ChatPromptTemplate.from_template(self.create_prompt())
        chain = prompt | self.llm | self.parser

        result = await chain.ainvoke(inputs)

        return result
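For reference, a minimal usage sketch (not part of this commit): classify() only reads `id` and `description` from the agent records and `.content` from the messages, so the stand-in objects below are enough to exercise it, and `llm` is assumed to be any LangChain-compatible chat model supplied by the caller.

import asyncio
from types import SimpleNamespace

async def demo_classifier(llm):
    # Hypothetical agent records; classify() only reads `id` and `description`
    available_agents = [
        SimpleNamespace(id="debugging_agent", description="Diagnoses errors and stack traces"),
        SimpleNamespace(id="codebase_qna_agent", description="Answers questions about the codebase"),
    ]
    # Stand-in messages; classify() only reads `.content`
    messages = [SimpleNamespace(content="Why does my login endpoint return a 500 error?")]

    classifier = AgentClassifier(llm, available_agents)
    result = await classifier.classify(messages)
    print(result.agent_id, result.confidence, result.reasoning)

# asyncio.run(demo_classifier(llm)) with any LangChain chat model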
@@ -0,0 +1,55 @@
from typing import Dict, Any

from sqlalchemy.orm import Session

from app.modules.intelligence.provider.provider_service import AgentType, ProviderService
from app.modules.intelligence.agents.chat_agents.code_changes_chat_agent import CodeChangesChatAgent
from app.modules.intelligence.agents.chat_agents.debugging_chat_agent import DebuggingChatAgent
from app.modules.intelligence.agents.chat_agents.qna_chat_agent import QNAChatAgent
from app.modules.intelligence.agents.chat_agents.unit_test_chat_agent import UnitTestAgent
from app.modules.intelligence.agents.chat_agents.integration_test_chat_agent import IntegrationTestChatAgent
from app.modules.intelligence.agents.chat_agents.lld_chat_agent import LLDChatAgent
from app.modules.intelligence.agents.chat_agents.code_gen_chat_agent import CodeGenerationChatAgent
from app.modules.intelligence.agents.custom_agents.custom_agent import CustomAgent


class AgentFactory:
    def __init__(self, db: Session, provider_service: ProviderService):
        self.db = db
        self.provider_service = provider_service
        self._agent_cache: Dict[str, Any] = {}

    def get_agent(self, agent_id: str, user_id: str) -> Any:
        """Get or create an agent instance"""
        cache_key = f"{agent_id}_{user_id}"

        if cache_key in self._agent_cache:
            return self._agent_cache[cache_key]

        mini_llm = self.provider_service.get_small_llm(agent_type=AgentType.LANGCHAIN)
        reasoning_llm = self.provider_service.get_large_llm(agent_type=AgentType.LANGCHAIN)

        agent = self._create_agent(agent_id, mini_llm, reasoning_llm, user_id)
        self._agent_cache[cache_key] = agent
        return agent

    def _create_agent(self, agent_id: str, mini_llm, reasoning_llm, user_id: str) -> Any:
        """Create a new agent instance"""
        agent_map = {
            "debugging_agent": lambda: DebuggingChatAgent(mini_llm, reasoning_llm, self.db),
            "codebase_qna_agent": lambda: QNAChatAgent(mini_llm, reasoning_llm, self.db),
            "unit_test_agent": lambda: UnitTestAgent(mini_llm, reasoning_llm, self.db),
            "integration_test_agent": lambda: IntegrationTestChatAgent(mini_llm, reasoning_llm, self.db),
            "code_changes_agent": lambda: CodeChangesChatAgent(mini_llm, reasoning_llm, self.db),
            "LLD_agent": lambda: LLDChatAgent(mini_llm, reasoning_llm, self.db),
            "code_generation_agent": lambda: CodeGenerationChatAgent(mini_llm, reasoning_llm, self.db),
        }

        if agent_id in agent_map:
            return agent_map[agent_id]()

        # If not a system agent, create custom agent
        return CustomAgent(
            llm=reasoning_llm,
            db=self.db,
            agent_id=agent_id,
            user_id=user_id,
        )
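A similar usage sketch for the factory (again not part of this commit): `db` and `provider_service` are assumed to be an existing SQLAlchemy Session and ProviderService already wired up elsewhere in the app, and the custom agent id is hypothetical.

def demo_factory(db, provider_service, user_id: str):
    factory = AgentFactory(db, provider_service)

    # Known ids resolve to the built-in chat agents
    debugger = factory.get_agent("debugging_agent", user_id)

    # Repeated lookups for the same (agent_id, user_id) reuse the cached instance
    assert factory.get_agent("debugging_agent", user_id) is debugger

    # Any other id (hypothetical here) falls through to CustomAgent
    reviewer = factory.get_agent("my_custom_reviewer", user_id)
    return debugger, reviewer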