feat: Support Bucket and Truncate transforms on write #1345
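The user-visible change: appending to a table whose partition spec uses a bucket or truncate transform no longer raises the "Not all partition types are supported for writes" error that the old test below asserted. A rough, illustrative sketch of the write path this enables (the catalog name, schema, and data here are made up for illustration, not taken from the PR):

```python
import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import BucketTransform
from pyiceberg.types import LongType, NestedField, StringType

# Any configured catalog works; "default" is just a placeholder name.
catalog = load_catalog("default")

schema = Schema(
    NestedField(1, "id", LongType(), required=False),
    NestedField(2, "name", StringType(), required=False),
)
# Partition by bucket(16) over the "id" column (source_id=1).
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")
)

tbl = catalog.create_table("default.bucketed_demo", schema=schema, partition_spec=spec)

# Before this PR, this append raised ValueError for non-identity transforms.
tbl.append(pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}))
```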
First file:
@@ -719,50 +719,105 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> None
@pytest.mark.parametrize(
    "spec",
    [
        # mixed with non-identity is not supported
        (
            PartitionSpec(
                PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
                PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
            )
        ),
Review comment: is this case supported now?
        # none of non-identity is supported
        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
Review comment: should we include
Reply: Yeah good question. Truncating binary isn't supported with iceberg-rust so I've excluded this test case for now: https://github.com/apache/iceberg-rust/blob/main/crates/iceberg/src/transform/truncate.rs#L132-L164
    ],
)
def test_unsupported_transform(
    spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
@pytest.mark.parametrize("format_version", [1, 2])
def test_truncate_transform(
    spec: PartitionSpec,
    spark: SparkSession,
    session_catalog: Catalog,
    arrow_table_with_null: pa.Table,
    format_version: int,
) -> None:
    identifier = "default.unsupported_transform"
    identifier = "default.truncate_transform"

    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(
    tbl = _create_table(
        session_catalog=session_catalog,
        identifier=identifier,
        schema=TABLE_SCHEMA,
        properties={"format-version": str(format_version)},
        data=[arrow_table_with_null],
        partition_spec=spec,
        properties={"format-version": "1"},
    )

    with pytest.raises(
        ValueError,
        match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
    ):
        tbl.append(arrow_table_with_null)
    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
    df = spark.table(identifier)
    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
    for col in arrow_table_with_null.column_names:
        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

    assert tbl.inspect.partitions().num_rows == 3
    files_df = spark.sql(
        f"""
        SELECT *
        FROM {identifier}.files
        """
    )
    assert files_df.count() == 3

@pytest.mark.integration
@pytest.mark.parametrize(
    "spec, expected_rows",
    [
        # none of non-identity is supported
        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
    ],
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_bucket_transform(
    spark: SparkSession,
    session_catalog: Catalog,
    arrow_table_with_null: pa.Table,
    spec: PartitionSpec,
    expected_rows: int,
    format_version: int,
) -> None:
    identifier = "default.bucket_transform"

    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

    tbl = _create_table(
        session_catalog=session_catalog,
        identifier=identifier,
        properties={"format-version": str(format_version)},
        data=[arrow_table_with_null],
        partition_spec=spec,
    )

    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
    df = spark.table(identifier)
    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
    for col in arrow_table_with_null.column_names:
        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

    assert tbl.inspect.partitions().num_rows == expected_rows
    files_df = spark.sql(
        f"""
        SELECT *
        FROM {identifier}.files
        """
    )
    assert files_df.count() == expected_rows

@pytest.mark.integration
Second file:
@@ -18,10 +18,11 @@
# pylint: disable=eval-used,protected-access,redefined-outer-name
from datetime import date
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import Any, Callable, Optional, Union
from uuid import UUID

import mmh3 as mmh3
import pyarrow as pa
import pytest
from pydantic import (
    BeforeValidator,

@@ -116,9 +117,6 @@
    timestamptz_to_micros,
)

if TYPE_CHECKING:
    import pyarrow as pa


@pytest.mark.parametrize(
    "test_input,test_type,expected",

@@ -1563,3 +1561,43 @@ def test_ymd_pyarrow_transforms(
    else:
        with pytest.raises(ValueError):
            transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])


@pytest.mark.parametrize(
    "source_type, input_arr, expected, num_buckets",
    [
        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
        (
            IntegerType(),
            pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
            pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
            10,
        ),
        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
    ],
)
def test_bucket_pyarrow_transforms(
    source_type: PrimitiveType,
    input_arr: Union[pa.Array, pa.ChunkedArray],
    expected: Union[pa.Array, pa.ChunkedArray],
    num_buckets: int,
Review comment (on lines +1580 to +1583): nit: wydt of reordering these for readability?
Reply: Hmm I think I feel indifferent here - there's something nice about having the input and expected arrays side by side
) -> None:
    transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
    assert expected == transform.pyarrow_transform(source_type)(input_arr)

@pytest.mark.parametrize(
    "source_type, input_arr, expected, width",
    [
        (StringType(), pa.array(["hello", "iceberg"]), pa.array(["hel", "ice"]), 3),
        (IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
    ],
)
def test_truncate_pyarrow_transforms(
    source_type: PrimitiveType,
    input_arr: Union[pa.Array, pa.ChunkedArray],
    expected: Union[pa.Array, pa.ChunkedArray],
    width: int,
) -> None:
    transform: Transform[Any, Any] = TruncateTransform(width=width)
    assert expected == transform.pyarrow_transform(source_type)(input_arr)
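As a usage note, the `pyarrow_transform` hook exercised by these unit tests can also be called directly on Arrow data. A minimal sketch, with import paths assumed from the usual PyIceberg module layout and expected values taken from the parametrized cases above:

```python
import pyarrow as pa

from pyiceberg.transforms import BucketTransform, TruncateTransform
from pyiceberg.types import IntegerType, StringType

# bucket[10] applied to the integers 1 and 2 (murmur3-based), as asserted above
bucket_fn = BucketTransform(num_buckets=10).pyarrow_transform(IntegerType())
print(bucket_fn(pa.array([1, 2])))  # expected: [6, 2]

# truncate[3] keeps the first three characters of each string
truncate_fn = TruncateTransform(width=3).pyarrow_transform(StringType())
print(truncate_fn(pa.array(["hello", "iceberg"])))  # expected: ["hel", "ice"]
```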