Skip to content

Commit

Permalink
set results to "" empty string if the response from genie is None
Browse files Browse the repository at this point in the history
  • Loading branch information
Sri Tikkireddy committed Nov 13, 2024
1 parent 83522f5 commit c3dbe76
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 8 deletions.
17 changes: 9 additions & 8 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,20 @@ def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
return {"messages": [AIMessage(content="")]}


class GenieToolInput(BaseModel):
question: str = Field(description="question to ask the agent")
summarized_chat_history: str = Field(
description="summarized chat history to provide the agent context of what may have been talked about. "
"Say 'No history' if there is no history to provide."
)


def GenieTool(genie_space_id: str, genie_agent_name: str, genie_space_description: str):
from langchain_core.tools import BaseTool
from langchain_core.callbacks.manager import CallbackManagerForToolRun

genie = Genie(genie_space_id)

class GenieToolInput(BaseModel):
question: str = Field(description="question to ask the agent")
summarized_chat_history: str = Field(
description="summarized chat history to provide the agent context of what may have been talked about. "
"Say 'No history' if there is no history to provide."
)

class GenieQuestionToolWithTrace(BaseTool):
name: str = f"{genie_agent_name}_details"
description: str = genie_space_description
Expand All @@ -70,7 +71,7 @@ def _run(
response = genie.ask_question_with_details(message)
if response:
return response.response, response
return "no results from room", None
return "", None

tool_with_details = GenieQuestionToolWithTrace()

Expand Down
43 changes: 43 additions & 0 deletions integrations/langchain/tests/test_genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from langchain_core.messages import AIMessage

from databricks_ai_bridge.genie import GenieResult
from databricks_langchain.genie import (
GenieAgent,
_concat_messages_array,
_query_genie_as_agent,
GenieTool,
GenieToolInput,
)


Expand Down Expand Up @@ -69,3 +72,43 @@ def test_create_genie_agent(MockRunnableLambda):

# Check that the partial function is created with the correct arguments
MockRunnableLambda.assert_called()


@patch("databricks_langchain.genie.Genie")
def test_create_genie_tool(MockGenie):
mock_genie = MockGenie.return_value
mock_genie.ask_question_with_details.return_value = GenieResult(
description=None, sql_query=None, response="It is sunny."
)

agent = GenieTool("space-id", "Genie", "Description")

assert agent.name == "Genie"
assert agent.args_schema == GenieToolInput
assert agent.description == "Description"
assert (
agent.invoke({"question": "What is the weather?", "summarized_chat_history": "No history"})
== "It is sunny."
)

assert mock_genie.ask_question_with_details.call_count == 1


@patch("databricks_langchain.genie.Genie")
def test_create_genie_tool_no_response(MockGenie):
mock_genie = MockGenie.return_value
mock_genie.ask_question_with_details.return_value = None

agent = GenieTool("space-id", "Genie", "Description")

assert agent.name == "Genie"
assert agent.args_schema == GenieToolInput
assert agent.description == "Description"
assert (
agent.invoke({"question": "What is the weather?", "summarized_chat_history": "No history"})
== ""
)

assert mock_genie.ask_question_with_details.call_count == 1


0 comments on commit c3dbe76

Please sign in to comment.