Skip to content

Commit

Permalink
feat: union_all unpack
Browse files Browse the repository at this point in the history
  • Loading branch information
cenzwong committed Oct 4, 2024
1 parent b3ff377 commit 28a512f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
16 changes: 10 additions & 6 deletions pysparky/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,41 @@ def create_map_from_dict(dict_: dict[str, int]) -> Column:
return F.create_map(list(map(F.lit, itertools.chain(*dict_.items()))))


def join_dataframes_on_column(
column_name: str, dataframes: list[DataFrame]
) -> DataFrame:
def join_dataframes_on_column(column_name: str, *dataframes: DataFrame) -> DataFrame:
    """
    Outer-join any number of DataFrames on a shared key column.

    The result starts from the key column of the first DataFrame; every
    supplied DataFrame (including the first) is then outer-joined onto it in
    order. After each join, nulls introduced by non-matching keys are replaced
    with 0 via ``fillna(0)``.

    Args:
        column_name (str): Name of the column to join on.
        *dataframes (DataFrame): One or more DataFrames to join.

    Returns:
        DataFrame: The combined DataFrame after all outer joins.

    Raises:
        ValueError: If no DataFrame is supplied.
    """
    if not dataframes:
        raise ValueError("At least one DataFrame must be provided")

    # Seed with just the key column so every frame contributes its own columns.
    merged = dataframes[0].select(F.col(column_name))
    for frame in dataframes:
        merged = merged.join(frame, column_name, "outer").fillna(0)
    return merged


def union_dataframes(dataframes: list[DataFrame]) -> DataFrame:
def union_dataframes(*dataframes: DataFrame) -> DataFrame:
"""
Unions a list of DataFrames.
Args:
dataframes (list): A list of DataFrames to union.
*dataframes (DataFrame): A list of DataFrames to union.
Returns:
DataFrame: The resulting DataFrame after performing the unions.
"""
# TODO: Check on the schema, if not align, raise error

if not dataframes:
raise ValueError("At least one DataFrame must be provided")

output_df = dataframes[0]
for sdf in dataframes[1:]:
output_df = output_df.union(sdf)
Expand Down
14 changes: 12 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_join_dataframes_on_column(spark):

dataframes = [df1, df2, df3]

result_df = utils.join_dataframes_on_column("id", dataframes)
result_df = utils.join_dataframes_on_column("id", *dataframes)
result_data = result_df.collect()

expected_data = [
Expand All @@ -58,6 +58,11 @@ def test_join_dataframes_on_column(spark):
assert result_data == expected_result


def test_join_dataframes_on_column_no_input():
    """Calling join_dataframes_on_column with no DataFrames must raise ValueError."""
    try:
        utils.join_dataframes_on_column("col")
    except ValueError as exc:
        assert "At least one DataFrame must be provided" in str(exc)
    else:
        pytest.fail("expected ValueError when no DataFrame is provided")


def test_union_dataframes(spark):
data1 = {"id": [1, 2, 3], "value": [10, 20, 30]}
data2 = {"id": [4, 5, 6], "value": [40, 50, 60]}
Expand All @@ -69,7 +74,7 @@ def test_union_dataframes(spark):

dataframes = [df1, df2, df3]

result_df = utils.union_dataframes(dataframes)
result_df = utils.union_dataframes(*dataframes)
result_data = result_df.collect()

expected_data = [
Expand All @@ -90,5 +95,10 @@ def test_union_dataframes(spark):
assert result_data == expected_result


def test_union_dataframes_no_input():
    """Calling union_dataframes with no DataFrames must raise ValueError."""
    try:
        utils.union_dataframes()
    except ValueError as exc:
        assert "At least one DataFrame must be provided" in str(exc)
    else:
        pytest.fail("expected ValueError when no DataFrame is provided")


if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest invocation).
    pytest.main([__file__])

0 comments on commit 28a512f

Please sign in to comment.