# Column type names that are converted to Python ``int`` / ``float``.
_INTEGER_TYPE_NAMES = ("INT", "LONG", "SHORT", "BYTE")
_FLOAT_TYPE_NAMES = ("FLOAT", "DOUBLE", "DECIMAL")

# Message statuses after which polling can never yield an answer.
# NOTE(review): taken from the Genie Conversation API's documented message
# statuses -- confirm against the API version in use.
_TERMINAL_FAILURE_STATUSES = frozenset({"FAILED", "CANCELLED", "QUERY_RESULT_EXPIRED"})

# Seconds to wait between two polling requests.
_POLL_INTERVAL_SECONDS = 5


def _convert_value(type_name, value):
    """Convert one typed cell (a dict with an optional ``"str"`` entry) to a Python value."""
    str_value = value.get("str", None)
    if str_value is None:
        return None

    if type_name in _INTEGER_TYPE_NAMES:
        return int(str_value)
    if type_name in _FLOAT_TYPE_NAMES:
        return float(str_value)
    if type_name == "BOOLEAN":
        return str_value.lower() == "true"
    if type_name in ("DATE", "TIMESTAMP"):
        # Only the leading ``YYYY-MM-DD`` part is parsed, so a TIMESTAMP is
        # reduced to its calendar date (existing behaviour, kept on purpose).
        return datetime.strptime(str_value[:10], "%Y-%m-%d").date()
    if type_name == "BINARY":
        return bytes(str_value, "utf-8")
    return str_value


def _first_attachment(resp, kind):
    """Return the first attachment of ``resp`` that has a ``kind`` entry, or ``None``."""
    attachments = resp.get("attachments") or []
    return next((item for item in attachments if kind in item), None)


def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
    """Render a statement response as a printable table.

    Args:
        resp: The ``statement_response`` dict. ``resp["result"]`` holds the
            rows under ``"data_typed_array"`` and
            ``resp["manifest"]["schema"]["columns"]`` describes each column
            with ``"name"`` and ``"type_name"``.

    Returns:
        The string ``"EMPTY"`` when the response carries no result, otherwise
        the ``DataFrame.to_string()`` rendering of the converted rows.
    """
    # Check for an empty result first: such a response needs no manifest.
    output = resp.get("result")
    if not output:
        return "EMPTY"

    columns = resp["manifest"]["schema"]["columns"]
    header = [str(col["name"]) for col in columns]
    rows = [
        [_convert_value(column["type_name"], value) for column, value in zip(columns, item["values"])]
        for item in output.get("data_typed_array") or []
    ]
    return pd.DataFrame(rows, columns=header).to_string()


class Genie:
    """Small client for the Genie conversation endpoints of one space."""

    def __init__(self, space_id):
        """Create a client bound to ``space_id`` using a default ``WorkspaceClient``."""
        self.space_id = space_id
        workspace_client = WorkspaceClient()
        self.genie = workspace_client.genie
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def start_conversation(self, content):
        """POST ``content`` as the first message of a new conversation and return the response."""
        resp = self.genie._api.do(
            "POST",
            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
            body={"content": content},
            headers=self.headers,
        )
        return resp

    def create_message(self, conversation_id, content):
        """POST ``content`` as a follow-up message in ``conversation_id`` and return the response."""
        resp = self.genie._api.do(
            "POST",
            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
            body={"content": content},
            headers=self.headers,
        )
        return resp

    def poll_for_result(self, conversation_id, message_id):
        """Poll a message until it yields an answer and return that answer.

        Args:
            conversation_id: Conversation the message belongs to.
            message_id: Message to poll.

        Returns:
            The text answer, the rendered query result (see
            ``_parse_query_result``), or ``None`` when the message or its
            query ended without a usable result.
        """
        message_path = (
            f"/api/2.0/genie/spaces/{self.space_id}"
            f"/conversations/{conversation_id}/messages/{message_id}"
        )

        def poll_query_results():
            while True:
                resp = self.genie._api.do(
                    "GET",
                    f"{message_path}/query-result",
                    headers=self.headers,
                )["statement_response"]
                state = resp["status"]["state"]
                if state == "SUCCEEDED":
                    return _parse_query_result(resp)
                if state in ("RUNNING", "PENDING"):
                    logging.debug("Waiting for query result...")
                    time.sleep(_POLL_INTERVAL_SECONDS)
                    continue
                # The state lives under resp["status"]; reading resp["state"]
                # here used to raise KeyError for every failed statement.
                logging.debug("No query result: %s", state)
                return None

        def poll_result():
            while True:
                resp = self.genie._api.do("GET", message_path, headers=self.headers)
                status = resp["status"]
                if status == "EXECUTING_QUERY":
                    # The SQL is only logged; a missing attachment must not abort polling.
                    query = _first_attachment(resp, "query")
                    if query is not None:
                        logging.debug("SQL: %s", query["query"]["query"])
                    return poll_query_results()
                if status == "COMPLETED":
                    text = _first_attachment(resp, "text")
                    if text is not None:
                        return text["text"]["content"]
                    # The message may complete between two polls, so the
                    # EXECUTING_QUERY status is never observed; fetch the
                    # query result when only a query attachment is present.
                    if _first_attachment(resp, "query") is not None:
                        return poll_query_results()
                    return None
                if status in _TERMINAL_FAILURE_STATUSES:
                    # Without this branch a failed message was polled forever.
                    logging.debug("Genie did not produce a result: %s", status)
                    return None
                logging.debug("Waiting...: %s", status)
                time.sleep(_POLL_INTERVAL_SECONDS)

        return poll_result()

    def ask_question(self, question):
        """Start a conversation with ``question`` and return the polled answer."""
        resp = self.start_conversation(question)
        # TODO (prithvi): return the query and the result
        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
@pytest.fixture
def mock_workspace_client():
    """Patch ``WorkspaceClient`` in the genie module and yield the mocked instance."""
    with patch("databricks_ai_bridge.genie.WorkspaceClient") as workspace_client_cls:
        yield workspace_client_cls.return_value


@pytest.fixture
def genie(mock_workspace_client):
    """Build a ``Genie`` client on top of the patched workspace client."""
    return Genie(space_id="test_space_id")


def test_start_conversation(genie, mock_workspace_client):
    """``start_conversation`` posts the content and returns the raw response."""
    api_do = mock_workspace_client.genie._api.do
    api_do.return_value = {"conversation_id": "123"}

    assert genie.start_conversation("Hello") == {"conversation_id": "123"}
    api_do.assert_called_once_with(
        "POST",
        "/api/2.0/genie/spaces/test_space_id/start-conversation",
        body={"content": "Hello"},
        headers=genie.headers,
    )


def test_create_message(genie, mock_workspace_client):
    """``create_message`` posts to the conversation's messages endpoint."""
    api_do = mock_workspace_client.genie._api.do
    api_do.return_value = {"message_id": "456"}

    assert genie.create_message("123", "Hello again") == {"message_id": "456"}
    api_do.assert_called_once_with(
        "POST",
        "/api/2.0/genie/spaces/test_space_id/conversations/123/messages",
        body={"content": "Hello again"},
        headers=genie.headers,
    )


def test_poll_for_result_completed(genie, mock_workspace_client):
    """A COMPLETED message yields the content of its text attachment."""
    completed_message = {"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]}
    mock_workspace_client.genie._api.do.side_effect = [completed_message]

    assert genie.poll_for_result("123", "456") == "Result"


def test_poll_for_result_executing_query(genie, mock_workspace_client):
    """An EXECUTING_QUERY message is followed by fetching and rendering the query result."""
    executing_message = {
        "status": "EXECUTING_QUERY",
        "attachments": [{"query": {"query": "SELECT *"}}],
    }
    succeeded_statement = {
        "statement_response": {
            "status": {"state": "SUCCEEDED"},
            "manifest": {"schema": {"columns": []}},
            "result": {
                "data_typed_array": [],
            },
        }
    }
    mock_workspace_client.genie._api.do.side_effect = [executing_message, succeeded_statement]

    assert genie.poll_for_result("123", "456") == pd.DataFrame().to_string()


def test_ask_question(genie, mock_workspace_client):
    """``ask_question`` starts a conversation and returns the polled answer."""
    started = {"conversation_id": "123", "message_id": "456"}
    completed_message = {"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]}
    mock_workspace_client.genie._api.do.side_effect = [started, completed_message]

    assert genie.ask_question("What is the meaning of life?") == "Answer"


def test_parse_query_result_empty():
    """A response whose result is ``None`` is rendered as ``"EMPTY"``."""
    payload = {"manifest": {"schema": {"columns": []}}, "result": None}

    assert _parse_query_result(payload) == "EMPTY"


def test_parse_query_result_with_data():
    """INT, STRING and TIMESTAMP cells are converted and rendered as a table."""
    schema_columns = [
        {"name": "id", "type_name": "INT"},
        {"name": "name", "type_name": "STRING"},
        {"name": "created_at", "type_name": "TIMESTAMP"},
    ]
    typed_rows = [
        {"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]},
        {"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]},
    ]
    payload = {
        "manifest": {"schema": {"columns": schema_columns}},
        "result": {"data_typed_array": typed_rows},
    }

    expected = pd.DataFrame(
        {
            "id": [1, 2],
            "name": ["Alice", "Bob"],
            "created_at": [datetime(2023, 10, 1).date(), datetime(2023, 10, 2).date()],
        }
    )
    assert _parse_query_result(payload) == expected.to_string()


def test_parse_query_result_with_null_values():
    """Cells whose ``"str"`` entry is ``None`` are rendered as missing values."""
    schema_columns = [
        {"name": "id", "type_name": "INT"},
        {"name": "name", "type_name": "STRING"},
        {"name": "created_at", "type_name": "TIMESTAMP"},
    ]
    typed_rows = [
        {"values": [{"str": "1"}, {"str": None}, {"str": "2023-10-01T00:00:00Z"}]},
        {"values": [{"str": "2"}, {"str": "Bob"}, {"str": None}]},
    ]
    payload = {
        "manifest": {"schema": {"columns": schema_columns}},
        "result": {"data_typed_array": typed_rows},
    }

    expected = pd.DataFrame(
        {
            "id": [1, 2],
            "name": [None, "Bob"],
            "created_at": [datetime(2023, 10, 1).date(), None],
        }
    )
    assert _parse_query_result(payload) == expected.to_string()