Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Genie API wrapper #2

Merged
merged 12 commits into from
Oct 24, 2024
116 changes: 116 additions & 0 deletions src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
import time
from datetime import datetime
from typing import Union

import pandas as pd
from databricks.sdk import WorkspaceClient


def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
    """Render a Genie statement response as a printable table.

    Args:
        resp: Mapping holding the statement response. Rows are read from
            ``resp["result"]["data_typed_array"]`` and column names/types from
            ``resp["manifest"]["schema"]["columns"]``.

    Returns:
        The literal string ``"EMPTY"`` when the response carries no result,
        otherwise the ``DataFrame.to_string()`` rendering of the typed rows.
    """
    # Check for a result first: an empty response needs no manifest, so a
    # missing manifest must not turn "no rows" into a KeyError.
    output = resp.get("result")
    if not output:
        return "EMPTY"

    columns = resp["manifest"]["schema"]["columns"]
    header = [str(col["name"]) for col in columns]

    rows = []
    # A result object without the array key is treated as having zero rows.
    for item in output.get("data_typed_array") or []:
        row = []
        for column, value in zip(columns, item["values"]):
            type_name = column["type_name"]
            str_value = value.get("str")
            if str_value is None:
                # SQL NULL: keep it as None regardless of the column type.
                row.append(None)
            elif type_name in ("INT", "LONG", "SHORT", "BYTE"):
                row.append(int(str_value))
            elif type_name in ("FLOAT", "DOUBLE", "DECIMAL"):
                row.append(float(str_value))
            elif type_name == "BOOLEAN":
                row.append(str_value.lower() == "true")
            elif type_name in ("DATE", "TIMESTAMP"):
                # Only the leading YYYY-MM-DD is parsed, so TIMESTAMP values
                # are reduced to their calendar date.
                # NOTE(review): this drops the time of day - confirm intended.
                row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
            elif type_name == "BINARY":
                row.append(bytes(str_value, "utf-8"))
            else:
                row.append(str_value)
        rows.append(row)

    return pd.DataFrame(rows, columns=header).to_string()


class Genie:
    """Small client for the Genie conversation REST endpoints of one space.

    Every request goes through the ``_api.do`` helper of
    ``WorkspaceClient().genie`` with JSON ``Accept``/``Content-Type`` headers.
    """

    def __init__(self, space_id):
        """Bind the client to ``space_id`` and create the workspace client."""
        self.space_id = space_id
        workspace_client = WorkspaceClient()
        self.genie = workspace_client.genie
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def start_conversation(self, content):
        """POST ``content`` as the first message of a new conversation.

        Returns:
            The response of the start-conversation call, unmodified.
        """
        # Review note: the polling mechanism of the SDK's start_conversation
        # API reportedly does not work; the Genie team recommended using _api
        # directly for now. Revisit once that is fixed.
        resp = self.genie._api.do(
            "POST",
            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
            body={"content": content},
            headers=self.headers,
        )
        return resp

    def create_message(self, conversation_id, content):
        """POST ``content`` as a follow-up message in ``conversation_id``.

        Returns:
            The response of the create-message call, unmodified.
        """
        resp = self.genie._api.do(
            "POST",
            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
            body={"content": content},
            headers=self.headers,
        )
        return resp

    def poll_for_result(self, conversation_id, message_id):
        """Poll a message until Genie has an answer or gives up.

        Returns:
            The content of the text attachment when the message completes, the
            rendered query result when Genie executed a query, or ``None``
            when the message or its query ends without a result.
        """

        def poll_result():
            # Poll the message itself until it completes, starts executing a
            # query, or reaches a terminal failure status.
            while True:
                resp = self.genie._api.do(
                    "GET",
                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
                    headers=self.headers,
                )
                status = resp["status"]
                if status == "EXECUTING_QUERY":
                    # Review note: the SQL is only logged for now; it should
                    # become part of the trace once extra traces are
                    # supported. A missing query attachment must not abort
                    # polling, since the SQL is purely informational here.
                    attachments = resp.get("attachments") or []
                    query_attachment = next((a for a in attachments if "query" in a), None)
                    if query_attachment is not None:
                        sql = query_attachment["query"]["query"]
                        logging.debug(f"SQL: {sql}")
                    return poll_query_results()
                elif status == "COMPLETED":
                    # Review note: Genie currently returns a single text
                    # attachment, so taking the first one is sufficient.
                    return next(r for r in resp["attachments"] if "text" in r)["text"]["content"]
                elif status in ("FAILED", "CANCELLED", "CANCELED", "QUERY_RESULT_EXPIRED"):
                    # Statuses that will not progress any further: stop
                    # instead of sleeping forever.
                    # NOTE(review): confirm this list against the Genie
                    # message status documentation.
                    logging.debug(f"No result: {status}")
                    return None
                else:
                    logging.debug(f"Waiting...: {status}")
                    time.sleep(5)

        def poll_query_results():
            # Poll the query-result endpoint until the statement finishes.
            while True:
                resp = self.genie._api.do(
                    "GET",
                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
                    headers=self.headers,
                )["statement_response"]
                state = resp["status"]["state"]
                if state == "SUCCEEDED":
                    return _parse_query_result(resp)
                elif state == "RUNNING" or state == "PENDING":
                    logging.debug("Waiting for query result...")
                    time.sleep(5)
                else:
                    # The state lives under resp["status"]["state"]; reading
                    # resp["state"] here would raise KeyError instead of
                    # reporting "no result".
                    logging.debug(f"No query result: {state}")
                    return None

        return poll_result()

    def ask_question(self, question):
        """Start a conversation with ``question`` and wait for its result.

        Returns:
            Whatever ``poll_for_result`` returns for the first message.
        """
        # Review note: a client-side timeout was discussed and not added; the
        # polling loops rely on Genie reporting a terminal status.
        resp = self.start_conversation(question)
        # TODO (prithvi): return the query and the result
        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
141 changes: 141 additions & 0 deletions tests/databricks_ai_bridge/test_genie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from datetime import datetime
from unittest.mock import patch

import pandas as pd
import pytest

from databricks_ai_bridge.genie import Genie, _parse_query_result


@pytest.fixture
def mock_workspace_client():
    """Yield the mocked ``WorkspaceClient`` instance used by the genie module."""
    workspace_client_patch = patch("databricks_ai_bridge.genie.WorkspaceClient")
    with workspace_client_patch as workspace_client_cls:
        # The instance the module receives is the class mock's return value.
        yield workspace_client_cls.return_value


@pytest.fixture
def genie(mock_workspace_client):
    """Build a ``Genie`` for a fixed space id while ``WorkspaceClient`` is patched."""
    client = Genie(space_id="test_space_id")
    return client


def test_start_conversation(genie, mock_workspace_client):
    """``start_conversation`` POSTs the prompt and returns the raw response."""
    api_do = mock_workspace_client.genie._api.do
    api_do.return_value = {"conversation_id": "123"}

    assert genie.start_conversation("Hello") == {"conversation_id": "123"}

    expected_path = "/api/2.0/genie/spaces/test_space_id/start-conversation"
    api_do.assert_called_once_with(
        "POST",
        expected_path,
        body={"content": "Hello"},
        headers=genie.headers,
    )


def test_create_message(genie, mock_workspace_client):
    """``create_message`` POSTs to the conversation and returns the raw response."""
    api_do = mock_workspace_client.genie._api.do
    api_do.return_value = {"message_id": "456"}

    assert genie.create_message("123", "Hello again") == {"message_id": "456"}

    expected_path = "/api/2.0/genie/spaces/test_space_id/conversations/123/messages"
    api_do.assert_called_once_with(
        "POST",
        expected_path,
        body={"content": "Hello again"},
        headers=genie.headers,
    )


def test_poll_for_result_completed(genie, mock_workspace_client):
    """A message that is already COMPLETED yields its text attachment content."""
    completed_message = {
        "status": "COMPLETED",
        "attachments": [{"text": {"content": "Result"}}],
    }
    mock_workspace_client.genie._api.do.side_effect = [completed_message]

    assert genie.poll_for_result("123", "456") == "Result"


def test_poll_for_result_executing_query(genie, mock_workspace_client):
    """An executing query is followed through to its (empty) rendered result."""
    executing_message = {
        "status": "EXECUTING_QUERY",
        "attachments": [{"query": {"query": "SELECT *"}}],
    }
    succeeded_statement = {
        "status": {"state": "SUCCEEDED"},
        "manifest": {"schema": {"columns": []}},
        "result": {"data_typed_array": []},
    }
    # First call returns the message, second call the query result.
    mock_workspace_client.genie._api.do.side_effect = [
        executing_message,
        {"statement_response": succeeded_statement},
    ]

    rendered = genie.poll_for_result("123", "456")

    assert rendered == pd.DataFrame().to_string()


def test_ask_question(genie, mock_workspace_client):
    """``ask_question`` starts a conversation and returns the polled answer."""
    started = {"conversation_id": "123", "message_id": "456"}
    completed = {"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]}
    mock_workspace_client.genie._api.do.side_effect = [started, completed]

    answer = genie.ask_question("What is the meaning of life?")

    assert answer == "Answer"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The answer can't be that simple 🤣



def test_parse_query_result_empty():
    """A response whose result is ``None`` is reported as ``"EMPTY"``."""
    empty_manifest = {"schema": {"columns": []}}

    assert _parse_query_result({"manifest": empty_manifest, "result": None}) == "EMPTY"


def test_parse_query_result_with_data():
    """Typed rows are converted per column type and rendered as a table."""
    schema_columns = [
        {"name": "id", "type_name": "INT"},
        {"name": "name", "type_name": "STRING"},
        {"name": "created_at", "type_name": "TIMESTAMP"},
    ]
    raw_rows = [
        ("1", "Alice", "2023-10-01T00:00:00Z"),
        ("2", "Bob", "2023-10-02T00:00:00Z"),
    ]
    # Wrap each cell the way the API does: {"str": <value>} inside "values".
    typed_array = [{"values": [{"str": cell} for cell in raw_row]} for raw_row in raw_rows]
    resp = {
        "manifest": {"schema": {"columns": schema_columns}},
        "result": {"data_typed_array": typed_array},
    }

    rendered = _parse_query_result(resp)

    expected_df = pd.DataFrame(
        {
            "id": [1, 2],
            "name": ["Alice", "Bob"],
            "created_at": [datetime(2023, 10, 1).date(), datetime(2023, 10, 2).date()],
        }
    )
    assert rendered == expected_df.to_string()


def test_parse_query_result_with_null_values():
    """Cells whose ``str`` is ``None`` stay ``None`` whatever the column type."""
    schema_columns = [
        {"name": "id", "type_name": "INT"},
        {"name": "name", "type_name": "STRING"},
        {"name": "created_at", "type_name": "TIMESTAMP"},
    ]
    raw_rows = [
        ("1", None, "2023-10-01T00:00:00Z"),
        ("2", "Bob", None),
    ]
    # Wrap each cell the way the API does: {"str": <value>} inside "values".
    typed_array = [{"values": [{"str": cell} for cell in raw_row]} for raw_row in raw_rows]
    resp = {
        "manifest": {"schema": {"columns": schema_columns}},
        "result": {"data_typed_array": typed_array},
    }

    rendered = _parse_query_result(resp)

    expected_df = pd.DataFrame(
        {
            "id": [1, 2],
            "name": [None, "Bob"],
            "created_at": [datetime(2023, 10, 1).date(), None],
        }
    )
    assert rendered == expected_df.to_string()
Loading