From 28a512f6d342b975bb96b6df69839d2f33f95986 Mon Sep 17 00:00:00 2001 From: Cenz Wong <44856918+cenzwong@users.noreply.github.com> Date: Fri, 4 Oct 2024 16:08:27 +0000 Subject: [PATCH] feat: union_all unpack --- pysparky/utils.py | 16 ++++++++++------ tests/test_utils.py | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/pysparky/utils.py b/pysparky/utils.py index 6a7f2ab..248e956 100644 --- a/pysparky/utils.py +++ b/pysparky/utils.py @@ -24,37 +24,41 @@ def create_map_from_dict(dict_: dict[str, int]) -> Column: return F.create_map(list(map(F.lit, itertools.chain(*dict_.items())))) -def join_dataframes_on_column( - column_name: str, dataframes: list[DataFrame] -) -> DataFrame: +def join_dataframes_on_column(column_name: str, *dataframes: DataFrame) -> DataFrame: """ Joins a list of DataFrames on a specified column using an outer join. Args: column_name (str): The column name to join on. - dataframes (list): A list of DataFrames to join. + *dataframes (DataFrame): One or more DataFrames to join. Returns: DataFrame: The resulting DataFrame after performing the outer joins. """ + if not dataframes: + raise ValueError("At least one DataFrame must be provided") + joined_df = dataframes[0].select(F.col(column_name)) for sdf in dataframes: joined_df = joined_df.join(sdf, column_name, "outer").fillna(0) return joined_df -def union_dataframes(dataframes: list[DataFrame]) -> DataFrame: +def union_dataframes(*dataframes: DataFrame) -> DataFrame: """ Unions a list of DataFrames. Args: - dataframes (list): A list of DataFrames to union. + *dataframes (DataFrame): One or more DataFrames to union. Returns: DataFrame: The resulting DataFrame after performing the unions. 
""" # TODO: Check on the schema, if not align, raise error + if not dataframes: + raise ValueError("At least one DataFrame must be provided") + output_df = dataframes[0] for sdf in dataframes[1:]: output_df = output_df.union(sdf) diff --git a/tests/test_utils.py b/tests/test_utils.py index cbe60b9..c430b32 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -39,7 +39,7 @@ def test_join_dataframes_on_column(spark): dataframes = [df1, df2, df3] - result_df = utils.join_dataframes_on_column("id", dataframes) + result_df = utils.join_dataframes_on_column("id", *dataframes) result_data = result_df.collect() expected_data = [ @@ -58,6 +58,11 @@ def test_join_dataframes_on_column(spark): assert result_data == expected_result +def test_join_dataframes_on_column_no_input(): + with pytest.raises(ValueError, match="At least one DataFrame must be provided"): + utils.join_dataframes_on_column("col") + + def test_union_dataframes(spark): data1 = {"id": [1, 2, 3], "value": [10, 20, 30]} data2 = {"id": [4, 5, 6], "value": [40, 50, 60]} @@ -69,7 +74,7 @@ def test_union_dataframes(spark): dataframes = [df1, df2, df3] - result_df = utils.union_dataframes(dataframes) + result_df = utils.union_dataframes(*dataframes) result_data = result_df.collect() expected_data = [ @@ -90,5 +95,10 @@ def test_union_dataframes(spark): assert result_data == expected_result +def test_union_dataframes_no_input(): + with pytest.raises(ValueError, match="At least one DataFrame must be provided"): + utils.union_dataframes() + + if __name__ == "__main__": pytest.main([__file__])