Commit

feat: changing columns name to pure function

cenzwong committed Oct 17, 2024
1 parent e899d7b commit bece33a
Showing 11 changed files with 142 additions and 179 deletions.
Binary file modified dist/pysparky-0.1.0-py3-none-any.whl
Binary file modified dist/pysparky-0.1.0.tar.gz
86 changes: 46 additions & 40 deletions example/dev.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -14,7 +14,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -23,6 +23,19 @@
"text": [
"3.5.2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"24/10/17 09:54:54 WARN Utils: Your hostname, codespaces-0aafae resolves to a loopback address: 127.0.0.1; using 10.0.10.147 instead (on interface eth0)\n",
"24/10/17 09:54:54 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"24/10/17 09:54:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
"24/10/17 09:54:56 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n",
"24/10/17 09:54:56 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n"
]
}
],
"source": [
@@ -43,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -52,49 +65,42 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Column<'trim(regexp_replace(hi, \\s+, , 1))'>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F_.single_space_and_trim(\"hi\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Column<'CASE WHEN (hi = 1) THEN Ture END'>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"+---+\n",
"|ID1|\n",
"+---+\n",
"| 1|\n",
"| 2|\n",
"| 3|\n",
"| 4|\n",
"+---+\n",
"\n",
"+---+\n",
"|ID1|\n",
"+---+\n",
"| 1|\n",
"| 2|\n",
"| 3|\n",
"| 4|\n",
"+---+\n",
"\n"
]
}
],
"source": [
"F_.when_mapping(\"hi\", {1: \"Ture\"})"
"list_ = [1, 2, 3, 4]\n",
"column_names = [\"ID1\"]\n",
"df = convert_1d_list_to_dataframe(spark, list_, column_names, axis=\"column\")\n",
"expected_data = [(1,), (2,), (3,), (4,)]\n",
"expected_df = spark.createDataFrame(expected_data, schema=column_names)\n",
"df.show()\n",
"expected_df.show()"
]
},
{
96 changes: 0 additions & 96 deletions pysparky/decorator.py
@@ -1,101 +1,5 @@
import functools

from pyspark.sql import functions as F


def pyspark_column_or_name_enabler(*param_names):
"""
A decorator to enable PySpark functions to accept either column names (as strings) or Column objects.
Parameters:
param_names (str): Names of the parameters that should be converted from strings to Column objects.
Returns:
function: The decorated function with specified parameters converted to Column objects if they are strings.
Example
@pyspark_column_or_name_enabler("column_or_name")
def your_function(column_or_name):
return column_or_name.startswith(bins)
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Convert args to a list to modify them
# args: This is the list of argument of the function.
# Get the parameter indices from the function signature
# func.__code__.co_varnames : Return the function parameter names as a tuple.
# param_names : the list of parameter from the decorator

# Merging the args into Kwargs
args_name_used = func.__code__.co_varnames[: len(args)]
kw_from_args = dict(zip(args_name_used, args))
kwargs = kw_from_args | kwargs

# print(kwargs)
# transform all the input param
for param_name in param_names:
# if it is string, wrap it as string, else do nth
kwargs[param_name] = (
F.col(kwargs[param_name])
if isinstance(kwargs[param_name], str)
else kwargs[param_name]
)

return func(**kwargs)

return wrapper

return decorator


def column_name_or_column_names_enabler(*param_names):
"""
A decorator to enable PySpark functions to accept either column names (as strings) or Column objects.
Parameters:
param_names (str): Names of the parameters that should be converted from strings to Column objects.
Returns:
function: The decorated function with specified parameters converted to Column objects if they are strings.
Example
@pyspark_column_or_name_enabler("column_or_name")
def your_function(column_or_name):
return column_or_name.startswith(bins)
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Convert args to a list to modify them
# args: This is the list of argument of the function.
# Get the parameter indices from the function signature
# func.__code__.co_varnames : Return the function parameter names as a tuple.
# param_names : the list of parameter from the decorator

# Merging the args into Kwargs
args_name_used = func.__code__.co_varnames[: len(args)]
kw_from_args = dict(zip(args_name_used, args))
kwargs = kw_from_args | kwargs

# print(kwargs)
# transform all the input param
for param_name in param_names:
# if it is string, wrap it as string, else do nth
kwargs[param_name] = (
[kwargs[param_name]]
if isinstance(kwargs[param_name], str)
else kwargs[param_name]
)

return func(**kwargs)

return wrapper

return decorator


def extension_enabler(cls):
"""
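For context, a minimal standalone sketch of the args-to-kwargs merge the removed decorators relied on: positional arguments are matched to parameter names via func.__code__.co_varnames and folded into kwargs so every argument can be coerced by name. The normalize_args decorator and greet function below are hypothetical, not part of the repo (requires Python 3.9+ for the dict union).

import functools

def normalize_args(func):
    # Hypothetical decorator illustrating the merge step from the removed code.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Map positional arguments onto their parameter names...
        args_name_used = func.__code__.co_varnames[: len(args)]
        # ...and merge them into kwargs so they can be inspected uniformly.
        kwargs = dict(zip(args_name_used, args)) | kwargs
        return func(**kwargs)
    return wrapper

@normalize_args
def greet(greeting, name="world"):
    return f"{greeting}, {name}"

print(greet("hello", name="spark"))  # -> hello, spark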
24 changes: 24 additions & 0 deletions pysparky/enabler.py
@@ -23,3 +23,27 @@ def column_or_name_enabler(*columns: ColumnOrName) -> tuple[Column, ...]:
lambda column: F.col(column) if isinstance(column, str) else column, columns
)
)


def column_name_or_column_names_enabler(
column_names: str | list[str],
) -> list[str]:
"""
Ensures that the input is always returned as a list of column names.
Parameters:
column_names (str | list[str]): A single column name (as a string) or a list of column names.
Returns:
list[str]: A list containing the column names.
Example:
>>> column_name_or_column_names_enabler("col1")
['col1']
>>> column_name_or_column_names_enabler(["col1", "col2"])
['col1', 'col2']
"""

column_names = [column_names] if isinstance(column_names, str) else column_names

return column_names
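
A short sketch of how the new pure function is meant to be called from inside other helpers, mirroring the change to convert_1d_list_to_dataframe further down; select_columns and df are hypothetical names:

from pysparky import enabler

def select_columns(df, column_names):
    # Accept either a single column name or a list of names.
    column_names = enabler.column_name_or_column_names_enabler(column_names)
    return df.select(*column_names)

# Both calls are expected to behave identically:
# select_columns(df, "col1")
# select_columns(df, ["col1", "col2"])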
4 changes: 2 additions & 2 deletions pysparky/functions/general.py
@@ -67,7 +67,6 @@ def chain(self, func, *args, **kwargs) -> Column:


@decorator.extension_enabler(Column)
@decorator.pyspark_column_or_name_enabler("column_or_name")
def startswiths(
column_or_name: ColumnOrName, list_of_strings: list[str]
) -> pyspark.sql.Column:
@@ -81,10 +80,11 @@ def startswiths(
Returns:
Column: A PySpark Column expression that evaluates to True if the column starts with any string in the list, otherwise False.
"""
(column,) = column_or_name_enabler(column_or_name)

return functools.reduce(
operator.or_,
map(column_or_name.startswith, list_of_strings),
map(column.startswith, list_of_strings),
F.lit(False),
).alias(f"startswiths_len{len(list_of_strings)}")

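A hedged usage sketch of the refactored startswiths: the diff shows it now unpacks its input through column_or_name_enabler, so a column name string or a Column object should both work. The sample DataFrame is made up, the import path is inferred from the file layout, and a running SparkSession named spark is assumed.

from pyspark.sql import functions as F
from pysparky.functions.general import startswiths

df = spark.createDataFrame([("apple",), ("banana",), ("cherry",)], ["fruit"])

# String name and Column object are expected to produce the same result:
# a boolean column aliased "startswiths_len2".
df.select(startswiths("fruit", ["app", "ban"])).show()
df.select(startswiths(F.col("fruit"), ["app", "ban"])).show()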
1 change: 0 additions & 1 deletion pysparky/functions/math_.py
@@ -7,7 +7,6 @@


@decorator.extension_enabler(Column)
@decorator.pyspark_column_or_name_enabler("lat1", "long1", "lat2", "long2")
def haversine_distance(
lat1: ColumnOrName,
long1: ColumnOrName,
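The same decorator removal applies to haversine_distance; the rest of its signature is cut off in this diff, so the sketch below passes explicit Column objects and assumes the four coordinate parameters shown are sufficient. The import path and sample coordinates are assumptions.

from pyspark.sql import functions as F
from pysparky.functions.math_ import haversine_distance

# One row with New York and London coordinates (illustrative values only).
df = spark.createDataFrame(
    [(40.7128, -74.0060, 51.5074, -0.1278)],
    ["lat1", "long1", "lat2", "long2"],
)
df.select(
    haversine_distance(F.col("lat1"), F.col("long1"), F.col("lat2"), F.col("long2"))
).show()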
35 changes: 24 additions & 11 deletions pysparky/spark_ext.py
@@ -3,11 +3,11 @@
from pyspark.sql import Column, DataFrame, SparkSession
from pyspark.sql import functions as F

from pysparky import decorator
from pysparky import decorator, enabler


@decorator.extension_enabler(SparkSession)
def column_function(spark, column_obj: Column) -> DataFrame:
def column_function(spark: SparkSession, column_obj: Column) -> DataFrame:
"""
Evaluates a Column expression in the context of a single-row DataFrame.
@@ -71,7 +71,10 @@ def column_function(spark, column_obj: Column) -> DataFrame:

@decorator.extension_enabler(SparkSession)
def convert_dict_to_dataframe(
spark, dict_: dict[str, Any], column_names: list[str], explode: bool = False
spark: SparkSession,
dict_: dict[str, Any],
column_names: list[str],
explode: bool = False,
) -> DataFrame:
"""
Transforms a dictionary with list values into a Spark DataFrame.
@@ -111,8 +114,13 @@


@decorator.extension_enabler(SparkSession)
@decorator.column_name_or_column_names_enabler("column_names")
def convert_1d_list_to_dataframe(spark, list_, column_names, axis="column"):
# @decorator.column_name_or_column_names_enabler("column_names")
def convert_1d_list_to_dataframe(
spark: SparkSession,
list_: list[Any],
column_names: str | list[str],
axis: str = "column",
) -> DataFrame:
"""
Converts a 1-dimensional list into a PySpark DataFrame.
@@ -156,20 +164,25 @@ def convert_1d_list_to_dataframe(spark, list_, column_names, axis="column"):
| 1| 2| 3| 4|
+---+---+---+---+
"""
column_names = enabler.column_name_or_column_names_enabler(column_names)

if axis not in ["column", "row"]:
raise AttributeError
raise AttributeError(
f"Invalid axis value: {axis}. Acceptable values are 'column' or 'row'."
)

if axis == "column":
tuple_list = ((x,) for x in list_)
output_sdf = spark.createDataFrame(tuple_list, schema=column_names)
tuple_list = ((x,) for x in list_) # type: ignore
elif axis == "row":
tuple_list = (tuple(list_),)
output_sdf = spark.createDataFrame(tuple_list, schema=column_names)
tuple_list = (tuple(list_),) # type: ignore

output_sdf = spark.createDataFrame(tuple_list, schema=column_names)

return output_sdf


@decorator.extension_enabler(SparkSession)
def createDataFrame_from_dict(spark, dict_: dict) -> DataFrame:
def createDataFrame_from_dict(spark: SparkSession, dict_: dict) -> DataFrame:
"""
Creates a Spark DataFrame from a dictionary in a pandas-like style.
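A short sketch exercising the refactored convert_1d_list_to_dataframe: a bare string column name is now accepted (via column_name_or_column_names_enabler), both axis values are covered, and anything else raises the new, more descriptive AttributeError. The import path and a running SparkSession named spark are assumptions.

from pysparky.spark_ext import convert_1d_list_to_dataframe

list_ = [1, 2, 3, 4]

# axis="column": one row per element, a single column.
convert_1d_list_to_dataframe(spark, list_, "ID1", axis="column").show()

# axis="row": a single row, one column per element.
convert_1d_list_to_dataframe(spark, list_, ["c1", "c2", "c3", "c4"], axis="row").show()

# Any other axis value raises the descriptive error added in this commit.
try:
    convert_1d_list_to_dataframe(spark, list_, "ID1", axis="diagonal")
except AttributeError as exc:
    print(exc)  # Invalid axis value: diagonal. Acceptable values are 'column' or 'row'.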
29 changes: 1 addition & 28 deletions tests/test_decorator.py
@@ -3,34 +3,7 @@
from pyspark.sql import functions as F

# Now import the decorators
from pysparky.decorator import (extension_enabler,
pyspark_column_or_name_enabler)


def test_pyspark_column_or_name_enabler(spark):
# It require spark session in the decorator
@pyspark_column_or_name_enabler("col1", "col2")
def test_function(col1, col2, col3):
return col1, col2, col3

# Test with string input
result = test_function("name1", "name2", "name3")
assert isinstance(result[0], Column)
assert isinstance(result[1], Column)
assert isinstance(result[2], str)

# Test with Column input
col_input = F.col("col_name")
result = test_function(col_input, "name2", col_input)
assert result[0] is col_input
assert isinstance(result[1], Column)
assert result[2] is col_input

# Test with keyword arguments
result = test_function(col1="name1", col2=Column("col2"), col3="name3")
assert isinstance(result[0], Column)
assert isinstance(result[1], Column)
assert isinstance(result[2], str)
from pysparky.decorator import extension_enabler


def test_extension_enabler():
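With the decorator test removed, a hedged sketch of an equivalent test for the new pure function, based on the doctests in pysparky/enabler.py; the test file and its placement (e.g. tests/test_enabler.py) are not part of this commit:

from pysparky.enabler import column_name_or_column_names_enabler


def test_column_name_or_column_names_enabler():
    # A single string is wrapped into a one-element list.
    assert column_name_or_column_names_enabler("col1") == ["col1"]
    # A list of names passes through unchanged.
    assert column_name_or_column_names_enabler(["col1", "col2"]) == ["col1", "col2"]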