Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add affected, inserted, updated, and deleted row counts to DatabricksAdapterResponse #883

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 71 additions & 6 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from databricks.sql.client import Connection as DatabricksSQLConnection
from databricks.sql.client import Cursor as DatabricksSQLCursor
from databricks.sql.exc import Error
from databricks.sql.types import Row
from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.adapters.contracts.connection import (
DEFAULT_QUERY_COMMENT,
Expand Down Expand Up @@ -179,12 +180,21 @@ def dbr_version(self) -> tuple[int, int]:
return self._dbr_version


@dataclass
class DatabricksQueryImpact:
    """Row-count statistics reported by Delta for DML statements.

    For MERGE, UPDATE, INSERT and DELETE commands Delta returns a single
    result row with these four counters; all fields default to ``None`` for
    statements that do not report them (e.g. plain SELECTs).

    NOTE(review): other engines call this concept "statistics" — a rename to
    ``DatabricksQueryStatistics`` was suggested but is not applied here.
    """

    # Total number of rows touched by the statement.
    num_affected_rows: Optional[int] = None
    # Rows changed in place by an UPDATE / MERGE ... WHEN MATCHED clause.
    num_updated_rows: Optional[int] = None
    # Rows removed by a DELETE or MERGE delete clause.
    num_deleted_rows: Optional[int] = None
    # Rows added by an INSERT or MERGE insert clause.
    num_inserted_rows: Optional[int] = None


class DatabricksSQLCursorWrapper:
    """Wrap a Databricks SQL cursor in a way that no-ops transactions"""

    # Underlying databricks-sql-connector cursor that all calls delegate to.
    _cursor: DatabricksSQLCursor
    # User-agent string for this session — presumably sent with requests;
    # TODO confirm how it is consumed (not visible in this chunk).
    _user_agent: str
    # Credentials used to configure the connection — assumed; verify caller.
    _creds: DatabricksCredentials
    # Holds a row pre-fetched by query_impact() so a later fetchone() call
    # can still return it; None when nothing is cached.
    _cache_fetchone: Optional[Row] = None

def __init__(self, cursor: DatabricksSQLCursor, creds: DatabricksCredentials, user_agent: str):
self._cursor = cursor
Expand All @@ -207,17 +217,52 @@ def close(self) -> None:
except Error as exc:
logger.warning(CursorCloseError(self._cursor, exc))

def fetchall(self) -> Sequence[Row]:
    """Return all remaining rows of the last executed query."""
    return self._cursor.fetchall()

def fetchone(self) -> Optional[tuple]:
def query_impact(self) -> DatabricksQueryImpact:
    """Return row-count statistics for the last executed statement.

    Delta returns for merge, update, insert and delete commands a single
    row containing:
    - num_affected_rows: the number of rows affected by the query
    - num_updated_rows: the number of rows updated by the query
    - num_deleted_rows: the number of rows deleted by the query
    - num_inserted_rows: the number of rows inserted by the query

    The row is consumed via ``fetchone`` and cached in ``_cache_fetchone``
    so that a later call to :meth:`fetchone` still sees it.

    Returns an empty ``DatabricksQueryImpact`` when the statement produced
    no result row or the row is not query metadata.
    """
    if not self._cache_fetchone:
        try:
            # Cache the result to be able to return it if fetchone is
            # called later.
            self._cache_fetchone = self._cursor.fetchone()
        except Error:
            # No result set available for this statement.
            return DatabricksQueryImpact()

    if not self._cache_fetchone:
        # fetchone() returned None (empty result set) — unclear whether
        # reachable for DML, but guard anyway (see review discussion).
        return DatabricksQueryImpact()

    # Cast the result to check that it is indeed query metadata; a row
    # with any other columns raises TypeError on the keyword expansion.
    try:
        return DatabricksQueryImpact(**self._cache_fetchone.asDict())
    except TypeError:
        return DatabricksQueryImpact()

def fetchone(self) -> Optional[Row]:
    """Return the next row of the result set, or None when exhausted.

    If ``query_impact`` already consumed (and cached) the next row, return
    the cached row and invalidate the cache.  The invalidation matters: the
    reviewer-suggested variant that keeps the cache forever would make every
    subsequent fetchone() return the same row.
    """
    if self._cache_fetchone:
        # If `fetchone` result was cached by `query_impact`, return it and
        # invalidate it so the next call advances the cursor again.
        row = self._cache_fetchone
        self._cache_fetchone = None
        return row
    return self._cursor.fetchone()

def fetchmany(self, size: int) -> Sequence[Row]:
    """Return up to ``size`` rows from the result set."""
    return self._cursor.fetchmany(size)

def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
# print(f"execute: {sql}")
# Invalidate fetchone cache
self._cache_fetchone = None
if sql.strip().endswith(";"):
sql = sql.strip()[:-1]
if bindings is not None:
Expand Down Expand Up @@ -300,6 +345,9 @@ def _get_comment_macro(self) -> Optional[str]:
@dataclass
class DatabricksAdapterResponse(AdapterResponse):
    """Adapter response extended with the query id and DML row counts."""

    # Databricks-assigned identifier of the executed query; "" when unknown.
    query_id: str = ""
    # Row counts reported by Delta for DML statements; None when the
    # statement did not report them (e.g. plain SELECTs).
    rows_updated: Optional[int] = None
    rows_deleted: Optional[int] = None
    rows_inserted: Optional[int] = None


@dataclass(init=False)
Expand Down Expand Up @@ -531,7 +579,7 @@ def execute(
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
try:
response = self.get_response(cursor)
response = self.get_response(cursor, include_impact=(not fetch))
if fetch:
table = self.get_result_from_cursor(cursor, limit)
else:
Expand Down Expand Up @@ -693,15 +741,32 @@ def exponential_backoff(attempt: int) -> int:
)

@classmethod
def get_response(
    cls, cursor: DatabricksSQLCursorWrapper, include_impact: bool = False
) -> DatabricksAdapterResponse:
    """Build an adapter response for the query run on ``cursor``.

    When ``include_impact`` is True (i.e. the caller did not fetch the
    result set), the metadata row is read via ``cursor.query_impact()`` and
    its counters are copied onto the response.
    """
    _query_id = getattr(cursor, "hex_query_id", None)
    if cursor is None:
        logger.debug("No cursor was provided. Query ID not available.")
        query_id = "N/A"
    else:
        # May still be None if the cursor exposes no hex_query_id.
        query_id = _query_id
    message = "OK"

    response = DatabricksAdapterResponse(
        _message=message,
        query_id=query_id,  # type: ignore
    )

    # If some query metadata are available, add them to the adapter response
    if include_impact:
        query_impact = cursor.query_impact()
        logger.debug(query_impact)
        response.rows_affected = query_impact.num_affected_rows
        response.rows_inserted = query_impact.num_inserted_rows
        response.rows_updated = query_impact.num_updated_rows
        response.rows_deleted = query_impact.num_deleted_rows

    return response


class ExtendedSessionConnectionManager(DatabricksConnectionManager):
Expand Down