diff --git a/dist/pysparky-0.1.0-py3-none-any.whl b/dist/pysparky-0.1.0-py3-none-any.whl index ca7c4e6..704d214 100644 Binary files a/dist/pysparky-0.1.0-py3-none-any.whl and b/dist/pysparky-0.1.0-py3-none-any.whl differ diff --git a/dist/pysparky-0.1.0.tar.gz b/dist/pysparky-0.1.0.tar.gz index d13b1ae..259b46a 100644 Binary files a/dist/pysparky-0.1.0.tar.gz and b/dist/pysparky-0.1.0.tar.gz differ diff --git a/example/dev.ipynb b/example/dev.ipynb index 95b4b41..392bb46 100644 --- a/example/dev.ipynb +++ b/example/dev.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 13, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -23,6 +23,19 @@ "text": [ "3.5.2\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "24/10/17 09:54:54 WARN Utils: Your hostname, codespaces-0aafae resolves to a loopback address: 127.0.0.1; using 10.0.10.147 instead (on interface eth0)\n", + "24/10/17 09:54:54 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n", + "Setting default log level to \"WARN\".\n", + "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", + "24/10/17 09:54:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", + "24/10/17 09:54:56 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n", + "24/10/17 09:54:56 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n" + ] } ], "source": [ @@ -43,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -52,49 +65,42 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Column<'trim(regexp_replace(hi, \\s+, , 1))'>" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "F_.single_space_and_trim(\"hi\")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "Column<'CASE WHEN (hi = 1) THEN Ture END'>" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "+---+\n", + "|ID1|\n", + "+---+\n", + "| 1|\n", + "| 2|\n", + "| 3|\n", + "| 4|\n", + "+---+\n", + "\n", + "+---+\n", + "|ID1|\n", + "+---+\n", + "| 1|\n", + "| 2|\n", + "| 3|\n", + "| 4|\n", + "+---+\n", + "\n" + ] } ], "source": [ - "F_.when_mapping(\"hi\", {1: \"Ture\"})" + "list_ = [1, 2, 3, 4]\n", + "column_names = [\"ID1\"]\n", + "df = convert_1d_list_to_dataframe(spark, list_, column_names, axis=\"column\")\n", + "expected_data = [(1,), (2,), (3,), (4,)]\n", + "expected_df = spark.createDataFrame(expected_data, schema=column_names)\n", + "df.show()\n", + "expected_df.show()" ] }, { diff --git a/pysparky/decorator.py b/pysparky/decorator.py index 94d60d0..e7f81ec 100644 --- a/pysparky/decorator.py +++ b/pysparky/decorator.py @@ -1,101 +1,5 @@ import functools -from pyspark.sql import functions as F - - -def pyspark_column_or_name_enabler(*param_names): - """ - A decorator to enable PySpark functions to accept either column 
names (as strings) or Column objects. - - Parameters: - param_names (str): Names of the parameters that should be converted from strings to Column objects. - - Returns: - function: The decorated function with specified parameters converted to Column objects if they are strings. - - Example - @pyspark_column_or_name_enabler("column_or_name") - def your_function(column_or_name): - return column_or_name.startswith(bins) - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - # Convert args to a list to modify them - # args: This is the list of argument of the function. - # Get the parameter indices from the function signature - # func.__code__.co_varnames : Return the function parameter names as a tuple. - # param_names : the list of parameter from the decorator - - # Merging the args into Kwargs - args_name_used = func.__code__.co_varnames[: len(args)] - kw_from_args = dict(zip(args_name_used, args)) - kwargs = kw_from_args | kwargs - - # print(kwargs) - # transform all the input param - for param_name in param_names: - # if it is string, wrap it as string, else do nth - kwargs[param_name] = ( - F.col(kwargs[param_name]) - if isinstance(kwargs[param_name], str) - else kwargs[param_name] - ) - - return func(**kwargs) - - return wrapper - - return decorator - - -def column_name_or_column_names_enabler(*param_names): - """ - A decorator to enable PySpark functions to accept either column names (as strings) or Column objects. - - Parameters: - param_names (str): Names of the parameters that should be converted from strings to Column objects. - - Returns: - function: The decorated function with specified parameters converted to Column objects if they are strings. - - Example - @pyspark_column_or_name_enabler("column_or_name") - def your_function(column_or_name): - return column_or_name.startswith(bins) - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - # Convert args to a list to modify them - # args: This is the list of argument of the function. - # Get the parameter indices from the function signature - # func.__code__.co_varnames : Return the function parameter names as a tuple. - # param_names : the list of parameter from the decorator - - # Merging the args into Kwargs - args_name_used = func.__code__.co_varnames[: len(args)] - kw_from_args = dict(zip(args_name_used, args)) - kwargs = kw_from_args | kwargs - - # print(kwargs) - # transform all the input param - for param_name in param_names: - # if it is string, wrap it as string, else do nth - kwargs[param_name] = ( - [kwargs[param_name]] - if isinstance(kwargs[param_name], str) - else kwargs[param_name] - ) - - return func(**kwargs) - - return wrapper - - return decorator - def extension_enabler(cls): """ diff --git a/pysparky/enabler.py b/pysparky/enabler.py index bfcd85a..f0ef984 100644 --- a/pysparky/enabler.py +++ b/pysparky/enabler.py @@ -23,3 +23,27 @@ def column_or_name_enabler(*columns: ColumnOrName) -> tuple[Column, ...]: lambda column: F.col(column) if isinstance(column, str) else column, columns ) ) + + +def column_name_or_column_names_enabler( + column_names: str | list[str], +) -> list[str]: + """ + Ensures that the input is always returned as a list of column names. + + Parameters: + column_names (str | list[str]): A single column name (as a string) or a list of column names. + + Returns: + list[str]: A list containing the column names. 
+ + Example: + >>> column_name_or_column_names_enabler("col1") + ['col1'] + >>> column_name_or_column_names_enabler(["col1", "col2"]) + ['col1', 'col2'] + """ + + column_names = [column_names] if isinstance(column_names, str) else column_names + + return column_names diff --git a/pysparky/functions/general.py b/pysparky/functions/general.py index 03d31b0..1cbc0a5 100644 --- a/pysparky/functions/general.py +++ b/pysparky/functions/general.py @@ -67,7 +67,6 @@ def chain(self, func, *args, **kwargs) -> Column: @decorator.extension_enabler(Column) -@decorator.pyspark_column_or_name_enabler("column_or_name") def startswiths( column_or_name: ColumnOrName, list_of_strings: list[str] ) -> pyspark.sql.Column: @@ -81,10 +80,11 @@ def startswiths( Returns: Column: A PySpark Column expression that evaluates to True if the column starts with any string in the list, otherwise False. """ + (column,) = column_or_name_enabler(column_or_name) return functools.reduce( operator.or_, - map(column_or_name.startswith, list_of_strings), + map(column.startswith, list_of_strings), F.lit(False), ).alias(f"startswiths_len{len(list_of_strings)}") diff --git a/pysparky/functions/math_.py b/pysparky/functions/math_.py index 58c8ba1..9192a31 100644 --- a/pysparky/functions/math_.py +++ b/pysparky/functions/math_.py @@ -7,7 +7,6 @@ @decorator.extension_enabler(Column) -@decorator.pyspark_column_or_name_enabler("lat1", "long1", "lat2", "long2") def haversine_distance( lat1: ColumnOrName, long1: ColumnOrName, diff --git a/pysparky/spark_ext.py b/pysparky/spark_ext.py index a6d6b1c..0356ef9 100644 --- a/pysparky/spark_ext.py +++ b/pysparky/spark_ext.py @@ -3,11 +3,11 @@ from pyspark.sql import Column, DataFrame, SparkSession from pyspark.sql import functions as F -from pysparky import decorator +from pysparky import decorator, enabler @decorator.extension_enabler(SparkSession) -def column_function(spark, column_obj: Column) -> DataFrame: +def column_function(spark: SparkSession, column_obj: Column) -> DataFrame: """ Evaluates a Column expression in the context of a single-row DataFrame. @@ -71,7 +71,10 @@ def column_function(spark, column_obj: Column) -> DataFrame: @decorator.extension_enabler(SparkSession) def convert_dict_to_dataframe( - spark, dict_: dict[str, Any], column_names: list[str], explode: bool = False + spark: SparkSession, + dict_: dict[str, Any], + column_names: list[str], + explode: bool = False, ) -> DataFrame: """ Transforms a dictionary with list values into a Spark DataFrame. @@ -111,8 +114,13 @@ def convert_dict_to_dataframe( @decorator.extension_enabler(SparkSession) -@decorator.column_name_or_column_names_enabler("column_names") -def convert_1d_list_to_dataframe(spark, list_, column_names, axis="column"): +# @decorator.column_name_or_column_names_enabler("column_names") +def convert_1d_list_to_dataframe( + spark: SparkSession, + list_: list[Any], + column_names: str | list[str], + axis: str = "column", +) -> DataFrame: """ Converts a 1-dimensional list into a PySpark DataFrame. @@ -156,20 +164,25 @@ def convert_1d_list_to_dataframe(spark, list_, column_names, axis="column"): | 1| 2| 3| 4| +---+---+---+---+ """ + column_names = enabler.column_name_or_column_names_enabler(column_names) + if axis not in ["column", "row"]: - raise AttributeError + raise AttributeError( + f"Invalid axis value: {axis}. Acceptable values are 'column' or 'row'." 
+ ) if axis == "column": - tuple_list = ((x,) for x in list_) - output_sdf = spark.createDataFrame(tuple_list, schema=column_names) + tuple_list = ((x,) for x in list_) # type: ignore elif axis == "row": - tuple_list = (tuple(list_),) - output_sdf = spark.createDataFrame(tuple_list, schema=column_names) + tuple_list = (tuple(list_),) # type: ignore + + output_sdf = spark.createDataFrame(tuple_list, schema=column_names) + return output_sdf @decorator.extension_enabler(SparkSession) -def createDataFrame_from_dict(spark, dict_: dict) -> DataFrame: +def createDataFrame_from_dict(spark: SparkSession, dict_: dict) -> DataFrame: """ Creates a Spark DataFrame from a dictionary in a pandas-like style. diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 2dcef4c..1668a62 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -3,34 +3,7 @@ from pyspark.sql import functions as F # Now import the decorators -from pysparky.decorator import (extension_enabler, - pyspark_column_or_name_enabler) - - -def test_pyspark_column_or_name_enabler(spark): - # It require spark session in the decorator - @pyspark_column_or_name_enabler("col1", "col2") - def test_function(col1, col2, col3): - return col1, col2, col3 - - # Test with string input - result = test_function("name1", "name2", "name3") - assert isinstance(result[0], Column) - assert isinstance(result[1], Column) - assert isinstance(result[2], str) - - # Test with Column input - col_input = F.col("col_name") - result = test_function(col_input, "name2", col_input) - assert result[0] is col_input - assert isinstance(result[1], Column) - assert result[2] is col_input - - # Test with keyword arguments - result = test_function(col1="name1", col2=Column("col2"), col3="name3") - assert isinstance(result[0], Column) - assert isinstance(result[1], Column) - assert isinstance(result[2], str) +from pysparky.decorator import extension_enabler def test_extension_enabler(): diff --git a/tests/test_enabler.py b/tests/test_enabler.py index 459ac49..64da21b 100644 --- a/tests/test_enabler.py +++ b/tests/test_enabler.py @@ -23,5 +23,23 @@ def test_column_or_name_enabler_with_1(spark): assert isinstance(col1, Column) +def test_column_name_or_column_names_enabler_with_string(): + result = enabler.column_name_or_column_names_enabler("col1") + expected = ["col1"] + assert result == expected + + +def test_column_name_or_column_names_enabler_with_list(): + result = enabler.column_name_or_column_names_enabler(["col1", "col2"]) + expected = ["col1", "col2"] + assert result == expected + + +def test_column_name_or_column_names_enabler_with_empty_list(): + result = enabler.column_name_or_column_names_enabler([]) + expected = [] + assert result == expected + + if __name__ == "__main__": pytest.main() diff --git a/tests/test_spark_ext.py b/tests/test_spark_ext.py index 38ab37c..57956e1 100644 --- a/tests/test_spark_ext.py +++ b/tests/test_spark_ext.py @@ -1,6 +1,7 @@ import pytest -from pysparky.spark_ext import createDataFrame_from_dict +from pysparky.spark_ext import (convert_1d_list_to_dataframe, + createDataFrame_from_dict) def test_createDataFrame_from_dict(spark): @@ -17,5 +18,30 @@ def test_createDataFrame_from_dict(spark): assert result_df.columns == expected_columns +def test_convert_1d_list_to_dataframe_column(spark): + list_ = [1, 2, 3, 4] + column_names = "ID1" + df = convert_1d_list_to_dataframe(spark, list_, column_names, axis="column") + expected_data = [(1,), (2,), (3,), (4,)] + expected_df = spark.createDataFrame(expected_data, 
schema=[column_names]) + assert df.collect() == expected_df.collect() + + +def test_convert_1d_list_to_dataframe_row(spark): + list_ = [1, 2, 3, 4] + column_names = ["ID1", "ID2", "ID3", "ID4"] + df = convert_1d_list_to_dataframe(spark, list_, column_names, axis="row") + expected_data = [(1, 2, 3, 4)] + expected_df = spark.createDataFrame(expected_data, schema=column_names) + assert df.collect() == expected_df.collect() + + +def test_convert_1d_list_to_dataframe_invalid_axis(spark): + list_ = [1, 2, 3, 4] + column_names = ["numbers"] + with pytest.raises(AttributeError): + convert_1d_list_to_dataframe(spark, list_, column_names, axis="invalid") + + if __name__ == "__main__": pytest.main([__file__])
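A minimal usage sketch of the helper-function style this change moves to (argument normalisation done by plain functions in pysparky/enabler.py instead of the deleted decorators). Only the imported names and call signatures come from the diff above; the local SparkSession setup and the sample values are illustrative assumptions.

    # Sketch, assuming a local Spark install; imports mirror the modules touched in this diff.
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    from pysparky import enabler
    from pysparky.spark_ext import convert_1d_list_to_dataframe

    spark = SparkSession.builder.master("local[1]").getOrCreate()  # illustrative session

    # column_or_name_enabler normalises str | Column inputs to Column objects,
    # which startswiths() now does inline instead of via the removed decorator.
    name_col, city_col = enabler.column_or_name_enabler("name", F.col("city"))

    # column_name_or_column_names_enabler normalises str | list[str] to list[str],
    # which convert_1d_list_to_dataframe() now calls instead of the removed decorator.
    assert enabler.column_name_or_column_names_enabler("ID1") == ["ID1"]

    # After the change, a bare string is accepted for column_names.
    df = convert_1d_list_to_dataframe(spark, [1, 2, 3, 4], "ID1", axis="column")
    df.show()

Calling the converters explicitly keeps each call site self-describing and drops the keyword-merging logic (func.__code__.co_varnames) that the deleted decorators depended on; an unsupported axis now raises AttributeError with a descriptive message, as exercised by test_convert_1d_list_to_dataframe_invalid_axis.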