Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
Get sync chat function working

Refactor SearchAgent

Fix up mocking and SearchAgent test

Get chat and sync chat tests passing

Add additional tests

Add agent handler unit tests

Refactor chat package/module layout
Get all tests passing
Add tests for the S3 Checkpointer

Full test coverage for metrics callback

Extract SearchWorkflow class for readability and testing

Full test coverage for EventConfig

Add unit tests and fixture for the complex real-world keyword fields

Fix up ruff check

Full test coverage for tools

Full test coverage for setup.py with a slight refactor

Full test coverage for setup.py and OpenSearchNeuralSearch

Full test coverage for WebSocket class

Don't automatically load secrets on module import

Switch test runner to pytest
Leave individual test files as unittest

Add tests for core.secrets

Add test to make sure the chat handler writes metrics

Full test coverage for s3_checkpointer.py (marked as a slow test) and more coverage for OpenSearchNeuralSearch
  • Loading branch information
charlesLoder authored and mbklein committed Dec 19, 2024
1 parent da4505b commit 7b9daf3
Show file tree
Hide file tree
Showing 60 changed files with 2,757 additions and 1,337 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ jobs:
run: ruff check .
- name: Run tests
run: |
coverage run --include='src/**/*' -m unittest
coverage run --include='src/**/*' -m pytest -m ""
coverage report
env:
AWS_REGION: us-east-1
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ test-node: deps-node
deps-python:
cd chat/src && pip install -r requirements.txt && pip install -r requirements-dev.txt
cover-python: deps-python
cd chat && coverage run --source=src -m unittest -v && coverage report --skip-empty
cd chat && coverage run --source=src -m pytest -v && coverage report --skip-empty
cover-html-python: deps-python
cd chat && coverage run --source=src -m unittest -v && coverage html --skip-empty
cd chat && coverage run --source=src -m pytest -v && coverage html --skip-empty
style-python: deps-python
cd chat && ruff check .
style-python-fix: deps-python
cd chat && ruff check --fix .
test-python: deps-python
cd chat && __SKIP_SECRETS__=true PYTHONPATH=src:test python -m unittest discover -v
cd chat && pytest
python-version:
cd chat && python --version
build: .aws-sam/build.toml
Expand Down
266 changes: 10 additions & 256 deletions chat-playground/playground.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions chat/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[pytest]
addopts = -m "not slow"
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,33 @@
from langchain_core.messages.tool import ToolMessage
import json

class MetricsHandler(BaseCallbackHandler):
class MetricsCallbackHandler(BaseCallbackHandler):
def __init__(self, *args, **kwargs):
self.accumulator = {}
self.answers = []
self.artifacts = []
super().__init__(*args, **kwargs)

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
if response is None:
return

if not response.generations or not response.generations[0]:
return

for generation in response.generations[0]:
self.answers.append(generation.text)
for k, v in generation.message.usage_metadata.items():
if k not in self.accumulator:
self.accumulator[k] = v
else:
self.accumulator[k] += v
if generation.text != "":
self.answers.append(generation.text)

if not hasattr(generation, 'message') or generation.message is None:
continue

metadata = getattr(generation.message, 'usage_metadata', None)
if metadata is None:
continue

for k, v in metadata.items():
self.accumulator[k] = self.accumulator.get(k, 0) + v

def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
match output.name:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional

from websocket import Websocket
from core.websocket import Websocket

from json.decoder import JSONDecodeError
from langchain_core.callbacks import BaseCallbackHandler
Expand All @@ -19,7 +19,7 @@ def deserialize_input(input_str):
except JSONDecodeError:
return input_str

class AgentHandler(BaseCallbackHandler):
class SocketCallbackHandler(BaseCallbackHandler):
def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict[str, Any]):
if socket is None:
raise ValueError("Socket not provided to agent callback handler")
Expand Down Expand Up @@ -56,12 +56,9 @@ def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
case "discover_fields":
pass
case "search":
try:
result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
except json.decoder.JSONDecodeError as e:
print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
case _:
print(f"Unhandled tool_end message: {output}")

Expand Down
72 changes: 38 additions & 34 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,85 @@
import os

from typing import Literal, List

from agent.s3_saver import S3Saver, delete_checkpoints
from agent.tools import aggregate, discover_fields, search
from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.language_models.chat_models import BaseModel
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.system import SystemMessage
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from core.setup import checkpoint_saver

DEFAULT_SYSTEM_MESSAGE = """
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
links using the document's canonical_link field. Do not include intermediate messages explaining your process.
"""

class SearchWorkflow:
def __init__(self, model: BaseModel, system_message: str):
self.model = model
self.system_message = system_message

def should_continue(self, state: MessagesState) -> Literal["tools", END]:
messages = state["messages"]
last_message = messages[-1]
# If the LLM makes a tool call, then we route to the "tools" node
if last_message.tool_calls:
return "tools"
# Otherwise, we stop (reply to the user)
return END

def call_model(self, state: MessagesState):
messages = [SystemMessage(content=self.system_message)] + state["messages"]
response: BaseMessage = self.model.invoke(messages)
# We return a list, because this will get added to the existing list
return {"messages": [response]}


class SearchAgent:
def __init__(
self,
model: BaseModel,
*,
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"),
system_message: str = DEFAULT_SYSTEM_MESSAGE,
**kwargs):

self.checkpoint_bucket = checkpoint_bucket

**kwargs
):
tools = [discover_fields, search, aggregate]
tool_node = ToolNode(tools)
model = ChatBedrock(**kwargs).bind_tools(tools)

# Define the function that determines whether to continue or not
def should_continue(state: MessagesState) -> Literal["tools", END]:
messages = state["messages"]
last_message = messages[-1]
# If the LLM makes a tool call, then we route to the "tools" node
if last_message.tool_calls:
return "tools"
# Otherwise, we stop (reply to the user)
return END
try:
model = model.bind_tools(tools)
except NotImplementedError:
pass


# Define the function that calls the model
def call_model(state: MessagesState):
messages = [SystemMessage(content=system_message)] + state["messages"]
response: BaseMessage = model.invoke(messages) # , model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
# We return a list, because this will get added to the existing list
# if socket is not none and the response content is not an empty string
return {"messages": [response]}
self.workflow_logic = SearchWorkflow(model=model, system_message=system_message)

# Define a new graph
workflow = StateGraph(MessagesState)

# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("agent", self.workflow_logic.call_model)
workflow.add_node("tools", tool_node)

# Set the entrypoint as `agent`
workflow.add_edge(START, "agent")

# Add a conditional edge
workflow.add_conditional_edges("agent", should_continue)
workflow.add_conditional_edges("agent", self.workflow_logic.should_continue)

# Add a normal edge from `tools` to `agent`
workflow.add_edge("tools", "agent")

checkpointer = S3Saver(bucket_name=checkpoint_bucket, compression="gzip")
self.search_agent = workflow.compile(checkpointer=checkpointer)
self.checkpointer = checkpoint_saver()
self.search_agent = workflow.compile(checkpointer=self.checkpointer)

def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs):
if forget:
delete_checkpoints(self.checkpoint_bucket, ref)

self.checkpointer.delete_checkpoints(ref)
return self.search_agent.invoke(
{"messages": [HumanMessage(content=question)]},
config={"configurable": {"thread_id": ref}, "callbacks": callbacks},
**kwargs
)
)
2 changes: 1 addition & 1 deletion chat/src/agent/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

from langchain_core.tools import tool
from setup import opensearch_vector_store
from core.setup import opensearch_vector_store

def get_keyword_fields(properties, prefix=''):
"""
Expand Down
36 changes: 0 additions & 36 deletions chat/src/content_handler.py

This file was deleted.

File renamed without changes.
File renamed without changes.
28 changes: 3 additions & 25 deletions chat/src/event_config.py → chat/src/core/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from langchain_core.prompts import ChatPromptTemplate

from helpers.apitoken import ApiToken
from helpers.prompts import prompt_template
from websocket import Websocket
from core.apitoken import ApiToken
from core.prompts import prompt_template
from core.websocket import Websocket
from uuid import uuid4

CHAIN_TYPE = "stuff"
Expand Down Expand Up @@ -92,20 +92,6 @@ def _get_temperature(self):
def _get_text_key(self):
return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

def debug_message(self):
return {
"type": "debug",
"message": {
"k": self.k,
"prompt": self.prompt_text,
"question": self.question,
"ref": self.ref,
"size": self.ref,
"temperature": self.temperature,
"text_key": self.text_key,
},
}

def setup_websocket(self, socket=None):
if socket is None:
connection_id = self.request_context.get("connectionId")
Expand All @@ -120,11 +106,3 @@ def setup_websocket(self, socket=None):
def _is_debug_mode_enabled(self):
debug = self.payload.get("debug", False)
return debug and self.api_token.is_superuser()

def _to_bool(self, val):
"""Converts a value to boolean. If the value is a string, it considers
"", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
"""
if isinstance(val, str):
return val.lower() not in ["", "no", "false", "0"]
return bool(val)
File renamed without changes.
12 changes: 3 additions & 9 deletions chat/src/secrets.py → chat/src/core/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@
import json
import os

def load_secrets():
SecretsPath = os.getenv('SECRETS_PATH')
def load_secrets(SecretsPath=os.getenv('SECRETS_PATH')):
EnvironmentMap = [
['API_TOKEN_SECRET', 'dcapi', 'api_token_secret'],
['OPENSEARCH_ENDPOINT', 'index', 'endpoint'],
['OPENSEARCH_MODEL_ID', 'index', 'embedding_model']
]

client = boto3.client("secretsmanager")
client = boto3.client("secretsmanager", region_name=os.getenv('AWS_REGION', 'us-east-1'))
response = client.batch_get_secret_value(SecretIdList=[
f'{SecretsPath}/config/dcapi',
f'{SecretsPath}/infrastructure/index',
f'{SecretsPath}/config/dcapi'
f'{SecretsPath}/infrastructure/index'
])

secrets = {
Expand All @@ -29,7 +27,3 @@ def load_secrets():
if var not in os.environ and value is not None:
os.environ[var] = value

os.environ['__SKIP_SECRETS__'] = 'true'

if not os.getenv('__SKIP_SECRETS__'):
load_secrets()
18 changes: 15 additions & 3 deletions chat/src/setup.py → chat/src/core/setup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from handlers.opensearch_neural_search import OpenSearchNeuralSearch
from persistence.s3_checkpointer import S3Checkpointer
from search.opensearch_neural_search import OpenSearchNeuralSearch
from langchain_aws import ChatBedrock
from langchain_core.language_models.base import BaseModel
from langgraph.checkpoint.base import BaseCheckpointSaver
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
from urllib.parse import urlparse
import os
import boto3

def chat_model(**kwargs) -> BaseModel:
return ChatBedrock(**kwargs)

def checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME")
return S3Checkpointer(bucket_name=checkpoint_bucket, **kwargs)

def prefix(value):
env_prefix = os.getenv("ENV_PREFIX")
Expand All @@ -21,7 +31,8 @@ def opensearch_endpoint():
return endpoint


def opensearch_client(region_name=os.getenv("AWS_REGION")):
def opensearch_client(region_name=None):
region_name = region_name or os.getenv("AWS_REGION") # Evaluate at runtime
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(
region=region_name,
Expand All @@ -38,7 +49,8 @@ def opensearch_client(region_name=os.getenv("AWS_REGION")):
)


def opensearch_vector_store(region_name=os.getenv("AWS_REGION")):
def opensearch_vector_store(region_name=None):
region_name = region_name or os.getenv("AWS_REGION") # Evaluate at runtime
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(
region=region_name,
Expand Down
Loading

0 comments on commit 7b9daf3

Please sign in to comment.