Skip to content

Commit

Permalink
feat(python): support use of KùzuDB via pl.read_database (#14822)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Mar 3, 2024
1 parent a61d5d6 commit baacf3d
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 20 deletions.
29 changes: 22 additions & 7 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@


class _ArrowDriverProperties_(TypedDict):
fetch_all: str # name of the method that fetches all arrow data
fetch_batches: str | None # name of the method that fetches arrow data in batches
exact_batch_size: bool | None # whether indicated batch size is respected exactly
repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator)
# name of the method that fetches all arrow data; tuple form
# calls the fetch_all method with the given chunk size (int)
fetch_all: str | tuple[str, int]
# name of the method that fetches arrow data in batches
fetch_batches: str | None
# indicate whether the given batch size is respected exactly
exact_batch_size: bool | None
# repeat batch calls (if False, the batch call is a generator)
repeat_batch_calls: bool


_ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = {
Expand Down Expand Up @@ -64,6 +69,13 @@ class _ArrowDriverProperties_(TypedDict):
"exact_batch_size": True,
"repeat_batch_calls": False,
},
"kuzu": {
# 'get_as_arrow' currently takes a mandatory chunk size
"fetch_all": ("get_as_arrow", 10_000),
"fetch_batches": None,
"exact_batch_size": None,
"repeat_batch_calls": False,
},
"snowflake": {
"fetch_all": "fetch_arrow_all",
"fetch_batches": "fetch_arrow_batches",
Expand Down Expand Up @@ -153,7 +165,7 @@ def __exit__(
) -> None:
# if we created it and are finished with it, we can
# close the cursor (but NOT the connection)
if self.can_close_cursor:
if self.can_close_cursor and hasattr(self.cursor, "close"):
self.cursor.close()

def __repr__(self) -> str:
Expand All @@ -169,8 +181,11 @@ def _arrow_batches(
"""Yield Arrow data in batches, or as a single 'fetchall' batch."""
fetch_batches = driver_properties["fetch_batches"]
if not iter_batches or fetch_batches is None:
fetch_method = driver_properties["fetch_all"]
yield getattr(self.result, fetch_method)()
fetch_method, sz = driver_properties["fetch_all"], []
if isinstance(fetch_method, tuple):
fetch_method, chunk_size = fetch_method
sz = [chunk_size]
yield getattr(self.result, fetch_method)(*sz)
else:
size = batch_size if driver_properties["exact_batch_size"] else None
repeat_batch_calls = driver_properties["repeat_batch_calls"]
Expand Down
6 changes: 0 additions & 6 deletions py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,11 @@ class SeriesBuffers(TypedDict):
# minimal protocol definitions that can reasonably represent
# an executable connection, cursor, or equivalent object
class BasicConnection(Protocol): # noqa: D101
def close(self) -> None:
"""Close the connection."""

def cursor(self, *args: Any, **kwargs: Any) -> Any:
"""Return a cursor object."""


class BasicCursor(Protocol): # noqa: D101
def close(self) -> None:
"""Close the cursor."""

def execute(self, *args: Any, **kwargs: Any) -> Any:
"""Execute a query."""

Expand Down
3 changes: 2 additions & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ module = [
"fsspec.*",
"gevent",
"hvplot.*",
"kuzu",
"matplotlib.*",
"moto.server",
"openpyxl",
Expand Down Expand Up @@ -179,7 +180,7 @@ ignore = [
]

[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["D100", "D103", "B018", "FBT001"]
"tests/**/*.py" = ["D100", "D102", "D103", "B018", "FBT001"]

[tool.ruff.lint.pycodestyle]
max-doc-length = 88
Expand Down
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows'
# TODO: Remove version constraint for connectorx when Python 3.12 is supported:
# https://github.com/sfu-db/connector-x/issues/527
connectorx; python_version <= '3.11'
kuzu
# Cloud
cloudpickle
fsspec
Expand Down
4 changes: 4 additions & 0 deletions py-polars/tests/unit/io/files/graph-data/follows.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Adam,Karissa,2020
Adam,Zhang,2020
Karissa,Zhang,2021
Zhang,Noura,2022
4 changes: 4 additions & 0 deletions py-polars/tests/unit/io/files/graph-data/user.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Adam,30
Karissa,40
Zhang,50
Noura,25
62 changes: 57 additions & 5 deletions py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def __init__(
test_data=test_data,
)

def close(self) -> None: # noqa: D102
def close(self) -> None:
pass

def cursor(self) -> Any: # noqa: D102
def cursor(self) -> Any:
return self._cursor


Expand All @@ -143,10 +143,10 @@ def __getattr__(self, item: str) -> Any:
return self.resultset
super().__getattr__(item) # type: ignore[misc]

def close(self) -> Any: # noqa: D102
def close(self) -> Any:
pass

def execute(self, query: str) -> Any: # noqa: D102
def execute(self, query: str) -> Any:
return self


Expand All @@ -161,7 +161,7 @@ def __init__(
self.batched = batched
self.n_calls = 1

def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if self.repeat_batched_calls:
res = self.test_data[: None if self.n_calls else 0]
self.n_calls -= 1
Expand Down Expand Up @@ -632,3 +632,55 @@ def test_read_database_cx_credentials(uri: str) -> None:
# can reasonably mitigate the issue.
with pytest.raises(BaseException, match=r"fakedb://\*\*\*:\*\*\*@\w+"):
pl.read_database_uri("SELECT * FROM data", uri=uri)


@pytest.mark.write_disk()
def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:
    # exercise `pl.read_database` against a kuzu graph database connection
    import kuzu

    def _fwd_slashes(path: Path) -> str:
        # every path string handed to kuzu has backslashes swapped for "/"
        return str(path).replace("\\", "/")

    # drop any existing file at the target database location before creating it
    tmp_path.mkdir(exist_ok=True)
    kuzu_test_db = tmp_path / "kuzu_test.db"
    if kuzu_test_db.exists():
        kuzu_test_db.unlink()

    graph_data = io_files_path / "graph-data"
    users = _fwd_slashes(graph_data / "user.csv")
    follows = _fwd_slashes(graph_data / "follows.csv")

    conn = kuzu.Connection(kuzu.Database(_fwd_slashes(kuzu_test_db)))

    # declare the schema first, then load the node and relationship tables
    for statement in (
        "CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))",
        "CREATE REL TABLE Follows(FROM User TO User, since INT64)",
        f'COPY User FROM "{users}"',
        f'COPY Follows FROM "{follows}"',
    ):
        conn.execute(statement)

    # node query: one row per user
    expected_users = pl.DataFrame(
        {
            "u.name": ["Adam", "Karissa", "Zhang", "Noura"],
            "u.age": [30, 40, 50, 25],
        }
    )
    df1 = pl.read_database(
        query="MATCH (u:User) RETURN u.name, u.age",
        connection=conn,
    )
    assert_frame_equal(df1, expected_users)

    # relationship query: one row per "follows" edge
    expected_follows = pl.DataFrame(
        {
            "a.name": ["Adam", "Adam", "Karissa", "Zhang"],
            "f.since": [2020, 2020, 2021, 2022],
            "b.name": ["Karissa", "Zhang", "Zhang", "Noura"],
        }
    )
    df2 = pl.read_database(
        query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name",
        connection=conn,
    )
    assert_frame_equal(df2, expected_follows)
2 changes: 1 addition & 1 deletion py-polars/tests/unit/utils/test_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def hello(oof: str, rab: str, ham: str) -> None: ...

class Foo: # noqa: D101
@deprecate_nonkeyword_arguments(allowed_args=["self", "baz"], version="0.1.2")
def bar( # noqa: D102
def bar(
self, baz: str, ham: str | None = None, foobar: str | None = None
) -> None: ...

Expand Down

0 comments on commit baacf3d

Please sign in to comment.