-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Get sync chat function working Refactor SearchAgent Fix up mocking and SearchAgent test Get chat and sync chat tests passing Add additional tests Add agent handler unit tests Refactor chat package/module layout Get all tests passing Add tests for the S3 Checkpointer Full test coverage for metrics callback Extract SearchWorkflow class for readability and testing Full test coverage for EventConfig Add unit tests and fixture for the complex real-world keyword fields Fix up ruff check Full test coverage for tools Full test coverage for setup.py with a slight refactor Full test coverage for setup.py and OpenSearchNeuralSearch Full test coverage for WebSocket class Don't automatically load secrets on module import Switch test runner to pytest Leave individual test files as unittest Add tests for core.secrets Add test to make sure the chat handler writes metrics Full test coverage for s3_checkpointer.py (marked as a slow test) and more coverage for OpenSearchNeuralSearch
- Loading branch information
1 parent
da4505b
commit 7b9daf3
Showing
60 changed files
with
2,757 additions
and
1,337 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[pytest] | ||
addopts = -m "not slow" | ||
markers = | ||
slow: marks tests as slow (deselect with '-m "not slow"') |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,85 @@ | ||
import os | ||
|
||
from typing import Literal, List | ||
|
||
from agent.s3_saver import S3Saver, delete_checkpoints | ||
from agent.tools import aggregate, discover_fields, search | ||
from langchain_aws import ChatBedrock | ||
from langchain_core.messages import HumanMessage | ||
from langchain_core.messages.base import BaseMessage | ||
from langchain_core.language_models.chat_models import BaseModel | ||
from langchain_core.callbacks import BaseCallbackHandler | ||
from langchain_core.messages.system import SystemMessage | ||
from langgraph.graph import END, START, StateGraph, MessagesState | ||
from langgraph.prebuilt import ToolNode | ||
from core.setup import checkpoint_saver | ||
|
||
DEFAULT_SYSTEM_MESSAGE = """ | ||
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that | ||
support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown | ||
links using the document's canonical_link field. Do not include intermediate messages explaining your process. | ||
""" | ||
|
||
class SearchWorkflow: | ||
def __init__(self, model: BaseModel, system_message: str): | ||
self.model = model | ||
self.system_message = system_message | ||
|
||
def should_continue(self, state: MessagesState) -> Literal["tools", END]: | ||
messages = state["messages"] | ||
last_message = messages[-1] | ||
# If the LLM makes a tool call, then we route to the "tools" node | ||
if last_message.tool_calls: | ||
return "tools" | ||
# Otherwise, we stop (reply to the user) | ||
return END | ||
|
||
def call_model(self, state: MessagesState): | ||
messages = [SystemMessage(content=self.system_message)] + state["messages"] | ||
response: BaseMessage = self.model.invoke(messages) | ||
# We return a list, because this will get added to the existing list | ||
return {"messages": [response]} | ||
|
||
|
||
class SearchAgent: | ||
def __init__( | ||
self, | ||
model: BaseModel, | ||
*, | ||
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"), | ||
system_message: str = DEFAULT_SYSTEM_MESSAGE, | ||
**kwargs): | ||
|
||
self.checkpoint_bucket = checkpoint_bucket | ||
|
||
**kwargs | ||
): | ||
tools = [discover_fields, search, aggregate] | ||
tool_node = ToolNode(tools) | ||
model = ChatBedrock(**kwargs).bind_tools(tools) | ||
|
||
# Define the function that determines whether to continue or not | ||
def should_continue(state: MessagesState) -> Literal["tools", END]: | ||
messages = state["messages"] | ||
last_message = messages[-1] | ||
# If the LLM makes a tool call, then we route to the "tools" node | ||
if last_message.tool_calls: | ||
return "tools" | ||
# Otherwise, we stop (reply to the user) | ||
return END | ||
try: | ||
model = model.bind_tools(tools) | ||
except NotImplementedError: | ||
pass | ||
|
||
|
||
# Define the function that calls the model | ||
def call_model(state: MessagesState): | ||
messages = [SystemMessage(content=system_message)] + state["messages"] | ||
response: BaseMessage = model.invoke(messages) # , model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") | ||
# We return a list, because this will get added to the existing list | ||
# if socket is not none and the response content is not an empty string | ||
return {"messages": [response]} | ||
self.workflow_logic = SearchWorkflow(model=model, system_message=system_message) | ||
|
||
# Define a new graph | ||
workflow = StateGraph(MessagesState) | ||
|
||
# Define the two nodes we will cycle between | ||
workflow.add_node("agent", call_model) | ||
workflow.add_node("agent", self.workflow_logic.call_model) | ||
workflow.add_node("tools", tool_node) | ||
|
||
# Set the entrypoint as `agent` | ||
workflow.add_edge(START, "agent") | ||
|
||
# Add a conditional edge | ||
workflow.add_conditional_edges("agent", should_continue) | ||
workflow.add_conditional_edges("agent", self.workflow_logic.should_continue) | ||
|
||
# Add a normal edge from `tools` to `agent` | ||
workflow.add_edge("tools", "agent") | ||
|
||
checkpointer = S3Saver(bucket_name=checkpoint_bucket, compression="gzip") | ||
self.search_agent = workflow.compile(checkpointer=checkpointer) | ||
self.checkpointer = checkpoint_saver() | ||
self.search_agent = workflow.compile(checkpointer=self.checkpointer) | ||
|
||
def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs): | ||
if forget: | ||
delete_checkpoints(self.checkpoint_bucket, ref) | ||
|
||
self.checkpointer.delete_checkpoints(ref) | ||
return self.search_agent.invoke( | ||
{"messages": [HumanMessage(content=question)]}, | ||
config={"configurable": {"thread_id": ref}, "callbacks": callbacks}, | ||
**kwargs | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.