Skip to content

Commit

Permalink
supervisour routing
Browse files Browse the repository at this point in the history
  • Loading branch information
dhirenmathur committed Dec 18, 2024
1 parent 1e41250 commit 3c05de4
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 43 deletions.
166 changes: 156 additions & 10 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import AsyncGenerator, List
from typing import AsyncGenerator, Dict, Any, List, Optional, TypedDict
from langgraph.types import StreamWriter

from fastapi import HTTPException
from langgraph.graph import END, StateGraph
from langgraph.types import Command
from langchain.prompts import ChatPromptTemplate
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from uuid6 import uuid7

from app.modules.code_provider.code_provider_service import CodeProviderService
Expand All @@ -30,7 +35,10 @@
MessageResponse,
NodeContext,
)

from app.modules.intelligence.agents.agent_injector_service import AgentInjectorService
from app.modules.intelligence.agents.agents_service import AgentsService
from app.modules.intelligence.agents.agent_factory import AgentFactory
from app.modules.intelligence.agents.custom_agents.custom_agents_service import (
CustomAgentsService,
)
Expand All @@ -47,24 +55,161 @@


class ConversationServiceError(Exception):
"""Base exception class for ConversationService errors."""
pass


class ConversationNotFoundError(ConversationServiceError):
"""Raised when a conversation is not found."""
pass


class MessageNotFoundError(ConversationServiceError):
"""Raised when a message is not found."""
pass


class AccessTypeNotFoundError(ConversationServiceError):
"""Raised when an access type is not found."""
pass


class AccessTypeReadError(ConversationServiceError):
"""Raised when an access type is read-only."""
pass


from langgraph.graph import END, StateGraph
from langgraph.types import Command
from typing import AsyncGenerator, Dict, Any

class SimplifiedAgentSupervisor:
def __init__(self, db, provider_service):
self.db = db
self.provider_service = provider_service
self.agents = {}
self.classifier = None
self.agents_service = AgentsService(db)
self.agent_factory = AgentFactory(db, provider_service)

async def initialize(self, user_id: str):
# Get available agents using AgentsService
available_agents = await self.agents_service.list_available_agents(
current_user={"user_id": user_id},
list_system_agents=True
)

# Create agent instances dictionary
self.agents = {
agent.id: self.agent_factory.get_agent(agent.id, user_id)
for agent in available_agents
}

self.llm = self.provider_service.get_small_llm(user_id)

# Enhanced classifier prompt with agent descriptions
self.classifier_prompt = """
Given the user query, determine which agent should handle it based on their specialties:
Query: {query}
Available agents and their specialties:
{agent_descriptions}
Return ONLY the agent id and confidence score in format: agent_id|confidence
Example: debugging_agent|0.85
"""

# Format agent descriptions for the prompt
self.agent_descriptions = "\n".join([
f"- {agent.id}: {agent.description}"
for agent in available_agents
])
class State(TypedDict):
query: str
project_id: str
conversation_id: str
response: Optional[str]
agent_id: Optional[str]
user_id: str
node_ids: List[NodeContext]

async def classifier_node(self, state: State) -> Command:
"""Classifies the query and routes to appropriate agent"""
if not state.get("query"):
return Command(update={"response": "No query provided"}, goto=END)

# Classification using LLM with enhanced prompt
prompt = self.classifier_prompt.format(
query=state["query"],
agent_descriptions=self.agent_descriptions
)
response = await self.llm.ainvoke(prompt)

# Parse response
try:
agent_id, confidence = response.content.split("|")
confidence = float(confidence)
except (ValueError, TypeError):
return Command(
update={"response": "Error in classification format"},
goto=END
)

if confidence < 0.5 or agent_id not in self.agents:
return Command(
update={"agent_id":state["agent_id"]},
goto=state["agent_id"]
)

return Command(
update={"agent_id": agent_id},
goto=agent_id
)

async def agent_node(self, state: State, writer: StreamWriter):
"""Creates a node function for a specific agent"""
agent = self.agents[state["agent_id"]]
async for chunk in agent.run(
query=state["query"],
project_id=state["project_id"],
conversation_id=state["conversation_id"],
user_id=state["user_id"],
node_ids=state["node_ids"]
):
if isinstance(chunk, str):
writer(chunk)




def build_graph(self) -> StateGraph:
"""Builds the graph with classifier and agent nodes"""
builder = StateGraph(self.State)

# Add classifier as entry point
builder.add_node("classifier", self.classifier_node)
#builder.add_edge("classifier", END)

# # Add agent nodes
#node_func = await self.agent_node(self.State, StreamWriter)
for agent_id in self.agents:
builder.add_node(agent_id, self.agent_node)
builder.add_edge(agent_id, END)

builder.set_entry_point("classifier")
return builder.compile()

async def process_query(self, query: str, project_id: str, conversation_id: str, user_id: str, node_ids: List[NodeContext], agent_id: str) -> AsyncGenerator[Dict[str, Any], None]:
"""Main method to process queries"""
state = {
"query": query,
"project_id": project_id,
"conversation_id": conversation_id,
"response": None,
"user_id": user_id,
"node_ids": node_ids,
"agent_id": agent_id
}

graph = self.build_graph()
async for chunk in graph.astream(state, stream_mode="custom"):
yield chunk

class ConversationService:
def __init__(
Expand Down Expand Up @@ -450,7 +595,8 @@ async def _generate_and_stream_ai_response(

agent_id = conversation.agent_ids[0]
project_id = conversation.project_ids[0] if conversation.project_ids else None

supervisor = SimplifiedAgentSupervisor(self.sql_db, self.provider_service)
await supervisor.initialize(user_id)
try:
agent = self.agent_injector_service.get_agent(agent_id)

Expand All @@ -466,8 +612,8 @@ async def _generate_and_stream_ai_response(
yield response
else:
# For other agents that support streaming
async for chunk in agent.run(
query, project_id, user_id, conversation.id, node_ids
async for chunk in supervisor.process_query(
query, project_id, conversation.id, user_id, node_ids, agent_id
):
yield chunk

Expand Down
80 changes: 80 additions & 0 deletions app/modules/intelligence/agents/agent_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import List
from langchain.schema import BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field

from app.modules.conversations.message.message_schema import MessageResponse
from app.modules.intelligence.prompts.classification_prompts import ClassificationResult

class AgentClassification(BaseModel):
agent_id: str = Field(..., description="ID of the agent that should handle the query")
confidence: float = Field(..., description="Confidence score between 0 and 1")
reasoning: str = Field(..., description="Reasoning behind the agent selection")

class AgentClassifier:
def __init__(self, llm, available_agents):
self.llm = llm
self.available_agents = available_agents
self.parser = PydanticOutputParser(pydantic_object=AgentClassification)

def create_prompt(self) -> str:

return """You are an expert agent router.
User's query: {query}
Conversation history: {history}
Based on the user's query and conversation history,
select the most appropriate agent from the following options:
Available Agents:
{agents_desc}
Analyze the query and select the agent that best matches the user's needs.
Consider:
1. The specific task or question type
2. Required expertise
3. Context from conversation history
4. Any explicit agent requests
{format_instructions}
"""

async def classify(self, messages: List[MessageResponse]) -> AgentClassification:
"""Classify the conversation and determine which agent should handle it"""

if not messages:
return AgentClassification(
agent_id=self.available_agents[0].id, # Default to first agent
confidence=0.0,
reasoning="No messages to classify"
)

# Format agent descriptions
agents_desc = "\n".join([
f"{i+1}. {agent.id}: {agent.description}"
for i, agent in enumerate(self.available_agents)
])

# Get the last message and up to 10 messages of history
last_message = messages[-1].content if messages else ""
history = [msg.content for msg in messages[-10:]] if len(messages) > 1 else []

inputs = {
"query": last_message,
"history": history,
"agents_desc": agents_desc,
"format_instructions": self.parser.get_format_instructions()
}

# Rest of the classification logic...

prompt = ChatPromptTemplate.from_template(self.create_prompt())

chain = prompt | self.llm | self.parser

result = await chain.ainvoke(
inputs
)

return result
55 changes: 55 additions & 0 deletions app/modules/intelligence/agents/agent_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Dict, Any
from sqlalchemy.orm import Session

from app.modules.intelligence.provider.provider_service import AgentType, ProviderService
from app.modules.intelligence.agents.chat_agents.code_changes_chat_agent import CodeChangesChatAgent
from app.modules.intelligence.agents.chat_agents.debugging_chat_agent import DebuggingChatAgent
from app.modules.intelligence.agents.chat_agents.qna_chat_agent import QNAChatAgent
from app.modules.intelligence.agents.chat_agents.unit_test_chat_agent import UnitTestAgent
from app.modules.intelligence.agents.chat_agents.integration_test_chat_agent import IntegrationTestChatAgent
from app.modules.intelligence.agents.chat_agents.lld_chat_agent import LLDChatAgent
from app.modules.intelligence.agents.chat_agents.code_gen_chat_agent import CodeGenerationChatAgent
from app.modules.intelligence.agents.custom_agents.custom_agent import CustomAgent

class AgentFactory:
def __init__(self, db: Session, provider_service: ProviderService):
self.db = db
self.provider_service = provider_service
self._agent_cache: Dict[str, Any] = {}

def get_agent(self, agent_id: str, user_id: str) -> Any:
"""Get or create an agent instance"""
cache_key = f"{agent_id}_{user_id}"

if cache_key in self._agent_cache:
return self._agent_cache[cache_key]

mini_llm = self.provider_service.get_small_llm(agent_type=AgentType.LANGCHAIN)
reasoning_llm = self.provider_service.get_large_llm(agent_type=AgentType.LANGCHAIN)

agent = self._create_agent(agent_id, mini_llm, reasoning_llm, user_id)
self._agent_cache[cache_key] = agent
return agent

def _create_agent(self, agent_id: str, mini_llm, reasoning_llm, user_id: str) -> Any:
"""Create a new agent instance"""
agent_map = {
"debugging_agent": lambda: DebuggingChatAgent(mini_llm, reasoning_llm, self.db),
"codebase_qna_agent": lambda: QNAChatAgent(mini_llm, reasoning_llm, self.db),
"unit_test_agent": lambda: UnitTestAgent(mini_llm, reasoning_llm, self.db),
"integration_test_agent": lambda: IntegrationTestChatAgent(mini_llm, reasoning_llm, self.db),
"code_changes_agent": lambda: CodeChangesChatAgent(mini_llm, reasoning_llm, self.db),
"LLD_agent": lambda: LLDChatAgent(mini_llm, reasoning_llm, self.db),
"code_generation_agent": lambda: CodeGenerationChatAgent(mini_llm, reasoning_llm, self.db),
}

if agent_id in agent_map:
return agent_map[agent_id]()

# If not a system agent, create custom agent
return CustomAgent(
llm=reasoning_llm,
db=self.db,
agent_id=agent_id,
user_id=user_id
)
Loading

0 comments on commit 3c05de4

Please sign in to comment.