-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create Genie API wrapper #2
Changes from all commits
66ce723
ac93eae
188ac83
ded9ec2
0fa7b70
f6d3126
8457766
96ab57a
1b17d84
4cadcbc
0e58f54
86a74fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import logging | ||
import time | ||
from datetime import datetime | ||
from typing import Union | ||
|
||
import pandas as pd | ||
from databricks.sdk import WorkspaceClient | ||
|
||
|
||
def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
    """Render a statement response as a plain-text table.

    Args:
        resp: Mapping with ``manifest.schema.columns`` (each column carrying
            ``name`` and ``type_name``) and ``result``, whose
            ``data_typed_array`` holds one ``{"values": [{"str": ...}, ...]}``
            entry per row.

    Returns:
        ``"EMPTY"`` when ``resp["result"]`` is falsy; otherwise the
        ``DataFrame.to_string()`` rendering of the decoded rows. Despite the
        annotation, a DataFrame object itself is never returned.
    """
    columns = resp["manifest"]["schema"]["columns"]
    header = [str(col["name"]) for col in columns]

    output = resp["result"]
    if not output:
        return "EMPTY"

    integer_types = ("INT", "LONG", "SHORT", "BYTE")
    float_types = ("FLOAT", "DOUBLE", "DECIMAL")

    rows = []
    # A non-empty result without a row array is treated as zero rows instead
    # of raising KeyError.
    for item in output.get("data_typed_array") or []:
        row = []
        for column, value in zip(columns, item["values"]):
            type_name = column["type_name"]
            str_value = value.get("str", None)
            if str_value is None:
                # Missing or null cell: keep it as None whatever the type.
                row.append(None)
            elif type_name in integer_types:
                row.append(int(str_value))
            elif type_name in float_types:
                row.append(float(str_value))
            elif type_name == "BOOLEAN":
                row.append(str_value.lower() == "true")
            elif type_name in ("DATE", "TIMESTAMP"):
                # Only the leading YYYY-MM-DD is parsed, so a TIMESTAMP loses
                # its time-of-day and becomes a plain date.
                row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
            elif type_name == "BINARY":
                row.append(bytes(str_value, "utf-8"))
            else:
                # Unrecognised types are passed through as the raw string.
                row.append(str_value)
        rows.append(row)

    return pd.DataFrame(rows, columns=header).to_string()
|
||
|
||
class Genie:
    """Thin wrapper over the Genie conversation REST endpoints of one space.

    Requests are sent through the SDK's low-level ``_api.do`` client on
    purpose: polling through the typed ``start_conversation`` SDK method was
    reported not to work, so the raw client is used until that is fixed.
    """

    def __init__(self, space_id):
        """Bind the wrapper to ``space_id`` using a default ``WorkspaceClient``."""
        self.space_id = space_id
        workspace_client = WorkspaceClient()
        self.genie = workspace_client.genie
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def start_conversation(self, content):
        """POST ``content`` as the first message of a new conversation.

        Returns:
            The raw response of the ``start-conversation`` endpoint.
        """
        resp = self.genie._api.do(
            "POST",
            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
            body={"content": content},
            headers=self.headers,
        )
        return resp

    def create_message(self, conversation_id, content):
        """POST ``content`` as a follow-up message in ``conversation_id``.

        Returns:
            The raw response of the ``messages`` endpoint.
        """
        resp = self.genie._api.do(
            "POST",
            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
            body={"content": content},
            headers=self.headers,
        )
        return resp

    def poll_for_result(self, conversation_id, message_id):
        """Block until Genie has finished handling a message.

        The message is polled every 5 seconds. There is no client-side
        timeout: the loop ends only when Genie reports a terminal status.

        Returns:
            The text attachment's content when the message completes, the
            rendered query result when Genie ran a query, or ``None`` when the
            message or its query ends in a failure state.
        """
        message_path = (
            f"/api/2.0/genie/spaces/{self.space_id}"
            f"/conversations/{conversation_id}/messages/{message_id}"
        )
        # Message statuses after which no answer will ever arrive. Without
        # this check the loop below would keep polling them forever.
        failed_statuses = ("FAILED", "CANCELLED", "QUERY_RESULT_EXPIRED")

        def poll_result():
            while True:
                resp = self.genie._api.do(
                    "GET",
                    message_path,
                    headers=self.headers,
                )
                status = resp["status"]
                if status == "EXECUTING_QUERY":
                    # The SQL is only logged for now; it is meant to become
                    # part of the trace later. A missing query attachment must
                    # not abort polling just because of a debug message.
                    query_attachment = next(
                        (r for r in resp.get("attachments") or [] if "query" in r),
                        None,
                    )
                    if query_attachment is not None:
                        logging.debug("SQL: %s", query_attachment["query"]["query"])
                    return poll_query_results()
                if status == "COMPLETED":
                    # Genie currently returns a single text attachment, so the
                    # first match is the answer.
                    return next(r for r in resp["attachments"] if "text" in r)["text"]["content"]
                if status in failed_statuses:
                    logging.debug("Genie message ended without a result: %s", status)
                    return None
                logging.debug("Waiting...: %s", status)
                time.sleep(5)

        def poll_query_results():
            while True:
                resp = self.genie._api.do(
                    "GET",
                    f"{message_path}/query-result",
                    headers=self.headers,
                )["statement_response"]
                state = resp["status"]["state"]
                if state == "SUCCEEDED":
                    return _parse_query_result(resp)
                if state in ("RUNNING", "PENDING"):
                    logging.debug("Waiting for query result...")
                    time.sleep(5)
                else:
                    # The state lives under resp["status"]; reading
                    # resp["state"] here used to raise KeyError.
                    logging.debug("No query result: %s", state)
                    return None

        return poll_result()

    def ask_question(self, question):
        """Start a conversation with ``question`` and wait for its answer.

        Returns:
            Whatever ``poll_for_result`` returns for the first message.
        """
        resp = self.start_conversation(question)
        # TODO (prithvi): return the query and the result
        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
from datetime import datetime | ||
from unittest.mock import patch | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from databricks_ai_bridge.genie import Genie, _parse_query_result | ||
|
||
|
||
@pytest.fixture
def mock_workspace_client():
    """Yield the stand-in client instance used while ``WorkspaceClient`` is patched."""
    target = "databricks_ai_bridge.genie.WorkspaceClient"
    with patch(target) as patched_cls:
        # return_value is what ``WorkspaceClient()`` hands back inside Genie.
        yield patched_cls.return_value
|
||
|
||
@pytest.fixture
def genie(mock_workspace_client):
    """Build a ``Genie`` for space ``test_space_id`` on top of the patched client."""
    wrapper = Genie(space_id="test_space_id")
    return wrapper
|
||
|
||
def test_start_conversation(genie, mock_workspace_client):
    """``start_conversation`` POSTs the prompt once and returns the raw response."""
    api_do = mock_workspace_client.genie._api.do
    api_do.return_value = {"conversation_id": "123"}

    assert genie.start_conversation("Hello") == {"conversation_id": "123"}

    api_do.assert_called_once_with(
        "POST",
        "/api/2.0/genie/spaces/test_space_id/start-conversation",
        body={"content": "Hello"},
        headers=genie.headers,
    )
|
||
|
||
def test_create_message(genie, mock_workspace_client):
    """``create_message`` POSTs to the conversation's messages endpoint once."""
    api_do = mock_workspace_client.genie._api.do
    api_do.return_value = {"message_id": "456"}

    assert genie.create_message("123", "Hello again") == {"message_id": "456"}

    api_do.assert_called_once_with(
        "POST",
        "/api/2.0/genie/spaces/test_space_id/conversations/123/messages",
        body={"content": "Hello again"},
        headers=genie.headers,
    )
|
||
|
||
def test_poll_for_result_completed(genie, mock_workspace_client):
    """A COMPLETED message yields the content of its text attachment."""
    completed = {"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]}
    mock_workspace_client.genie._api.do.side_effect = [completed]

    assert genie.poll_for_result("123", "456") == "Result"
|
||
|
||
def test_poll_for_result_executing_query(genie, mock_workspace_client):
    """An EXECUTING_QUERY message is followed through to its query result."""
    message = {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]}
    statement = {
        "status": {"state": "SUCCEEDED"},
        "manifest": {"schema": {"columns": []}},
        "result": {
            "data_typed_array": [],
        },
    }
    # First call returns the message, second call returns the query result.
    mock_workspace_client.genie._api.do.side_effect = [
        message,
        {"statement_response": statement},
    ]

    rendered = genie.poll_for_result("123", "456")

    assert rendered == pd.DataFrame().to_string()
|
||
|
||
def test_ask_question(genie, mock_workspace_client):
    """``ask_question`` starts a conversation, then polls it to completion."""
    started = {"conversation_id": "123", "message_id": "456"}
    completed = {"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]}
    mock_workspace_client.genie._api.do.side_effect = [started, completed]

    assert genie.ask_question("What is the meaning of life?") == "Answer"
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The answer can't be that simple 🤣 |
||
|
||
|
||
def test_parse_query_result_empty():
    """A response whose ``result`` is ``None`` is reported as ``"EMPTY"``."""
    empty_resp = {"manifest": {"schema": {"columns": []}}, "result": None}

    assert _parse_query_result(empty_resp) == "EMPTY"
|
||
|
||
def test_parse_query_result_with_data():
    """INT, STRING and TIMESTAMP cells are decoded and rendered as a table."""
    schema_columns = [
        {"name": "id", "type_name": "INT"},
        {"name": "name", "type_name": "STRING"},
        {"name": "created_at", "type_name": "TIMESTAMP"},
    ]
    typed_rows = [
        {"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]},
        {"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]},
    ]
    resp = {
        "manifest": {"schema": {"columns": schema_columns}},
        "result": {"data_typed_array": typed_rows},
    }

    rendered = _parse_query_result(resp)

    # TIMESTAMP values are expected back as plain dates.
    expected_df = pd.DataFrame(
        {
            "id": [1, 2],
            "name": ["Alice", "Bob"],
            "created_at": [datetime(2023, 10, 1).date(), datetime(2023, 10, 2).date()],
        }
    )
    assert rendered == expected_df.to_string()
|
||
|
||
def test_parse_query_result_with_null_values():
    """Cells whose ``str`` is ``None`` come back as ``None`` in the table."""
    schema_columns = [
        {"name": "id", "type_name": "INT"},
        {"name": "name", "type_name": "STRING"},
        {"name": "created_at", "type_name": "TIMESTAMP"},
    ]
    typed_rows = [
        {"values": [{"str": "1"}, {"str": None}, {"str": "2023-10-01T00:00:00Z"}]},
        {"values": [{"str": "2"}, {"str": "Bob"}, {"str": None}]},
    ]
    resp = {
        "manifest": {"schema": {"columns": schema_columns}},
        "result": {"data_typed_array": typed_rows},
    }

    rendered = _parse_query_result(resp)

    expected_df = pd.DataFrame(
        {
            "id": [1, 2],
            "name": [None, "Bob"],
            "created_at": [datetime(2023, 10, 1).date(), None],
        }
    )
    assert rendered == expected_df.to_string()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question, why are we using the raw
_api
request here? Can we not use this method directly? https://databricks-sdk-py.readthedocs.io/en/stable/workspace/dashboards/genie.html#databricks.sdk.service.dashboards.GenieAPI.start_conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the polling mechanism with the
start_conversation
API does not work. I asked the Genie team and they recommended using the _api
for now. We can revisit when fixed.