Skip to content

Commit

Permalink
feat: remove schema check for union
Browse files Browse the repository at this point in the history
  • Loading branch information
cenzwong committed Oct 4, 2024
1 parent 4bbf790 commit 2c14afc
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 25 deletions.
Binary file modified dist/pysparky-0.1.0-py3-none-any.whl
Binary file not shown.
Binary file modified dist/pysparky-0.1.0.tar.gz
Binary file not shown.
6 changes: 1 addition & 5 deletions pysparky/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,4 @@ def union_dataframes(*dataframes: DataFrame | list[DataFrame]) -> DataFrame:
if isinstance(dataframes[0], list):
dataframes = dataframes[0]

# Check if all DataFrames have the same schema
if not all(sdf.schema == dataframes[0].schema for sdf in dataframes):
raise ValueError("All DataFrames must have the same schema")

return reduce(DataFrame.union, dataframes)
return reduce(lambda df1, df2: df1.union(df2), dataframes)
15 changes: 15 additions & 0 deletions run_pytests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Run the pytest suite in-process and fail loudly if it does not pass.

The script invokes ``pytest.main`` with a fixed argument list, prints the
resulting exit code, and terminates with a non-zero status (and an error
message on stderr) when the run did not finish with ``pytest.ExitCode.OK``.
"""
import sys

import pytest

# Make modules in the current working directory importable by the tests.
sys.path.append(".")

# Do not write .pyc files while importing modules during the run.
sys.dont_write_bytecode = True

# "-p no:cacheprovider" disables pytest's cache plugin for this run.
args = ["--verbose", "-p", "no:cacheprovider"]
# args += ["-k", "test_dataframe_transform"] # uncomment for specific test

result = pytest.main(args)

print(f"{result=}")
if result != pytest.ExitCode.OK:
    # Explicit exit instead of ``assert``: assertions are removed when Python
    # runs with -O, which would let a failing test run report success.
    sys.exit("Test run was not successful.")
20 changes: 0 additions & 20 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,26 +119,6 @@ def test_union_dataframes(spark):
assert result_data == expected_result


def test_union_different_schema(spark):
    """Check that union_dataframes raises ValueError for mismatched schemas."""

    def two_column_schema(second_column):
        # Nullable string "name" column followed by one nullable integer column.
        return T.StructType(
            [
                T.StructField("name", T.StringType(), True),
                T.StructField(second_column, T.IntegerType(), True),
            ]
        )

    # The frames share "name" but differ in the name of the second column.
    people = spark.createDataFrame([("Alice", 30)], two_column_schema("age"))
    salaries = spark.createDataFrame([("Bob", 50000)], two_column_schema("salary"))

    expected_message = "All DataFrames must have the same schema"
    with pytest.raises(ValueError, match=expected_message):
        utils.union_dataframes(people, salaries)


def test_union_list_dataframes(spark):
data1 = {"id": [1, 2, 3], "value": [10, 20, 30]}
data2 = {"id": [4, 5, 6], "value": [40, 50, 60]}
Expand Down

0 comments on commit 2c14afc

Please sign in to comment.