feat: changing column enabler from decorator to function
cenzwong committed Oct 17, 2024
1 parent ea99b1e commit e899d7b
Showing 13 changed files with 185 additions and 72 deletions.
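At a high level, the diff below replaces the pyspark_column_or_name_enabler decorator with an explicit column_or_name_enabler helper that each function calls and unpacks itself. A minimal sketch of the two styles, using hypothetical upper_old/upper_new wrappers rather than any function actually in the library:

from pyspark.sql import Column
from pyspark.sql import functions as F

from pysparky import decorator
from pysparky.enabler import column_or_name_enabler


# Old style (removed in this commit): the decorator resolves the named
# argument to a Column before the function body runs.
@decorator.pyspark_column_or_name_enabler("column_or_name")
def upper_old(column_or_name) -> Column:
    return F.upper(column_or_name)


# New style (introduced in this commit): the function resolves its own
# argument through the helper and unpacks the returned tuple.
def upper_new(column_or_name) -> Column:
    (column,) = column_or_name_enabler(column_or_name)
    return F.upper(column)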
Binary file modified dist/pysparky-0.1.0-py3-none-any.whl
Binary file not shown.
Binary file modified dist/pysparky-0.1.0.tar.gz
Binary file not shown.
105 changes: 79 additions & 26 deletions example/dev.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -14,7 +14,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -23,25 +23,6 @@
"text": [
"3.5.2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"24/10/03 15:42:41 WARN Utils: Your hostname, codespaces-0aafae resolves to a loopback address: 127.0.0.1; using 10.0.1.110 instead (on interface eth0)\n",
"24/10/03 15:42:41 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"24/10/03 15:42:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
"24/10/03 15:42:44 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"24/10/03 15:42:56 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors\n"
]
}
],
"source": [
@@ -60,13 +41,62 @@
"spark = SparkSession.builder.getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from pysparky import functions as F_"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Column<'trim(regexp_replace(hi, \\s+, , 1))'>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F_.single_space_and_trim(\"hi\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Column<'CASE WHEN (hi = 1) THEN Ture END'>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F_.when_mapping(\"hi\", {1: \"Ture\"})"
]
},
{
"cell_type": "code",
"execution_count": 3,
@@ -128,11 +158,7 @@
"se.convert_1d_list_to_dataframe(\n",
" spark, my_list, [\"ID1\", \"ID2\", \"ID3\", \"ID4\"], axis=\"row\"\n",
").show()\n",
"se.convert_1d_list_to_dataframe(spark, my_list, \"ID\", axis=\"column\").show()\n",
"spark.convert_1d_list_to_dataframe(\n",
" my_list, [\"ID1\", \"ID2\", \"ID3\", \"ID4\"], axis=\"row\"\n",
").show()\n",
"spark.convert_1d_list_to_dataframe(my_list, \"ID\", axis=\"column\").show()"
"se.convert_1d_list_to_dataframe(spark, my_list, \"ID\", axis=\"column\").show()"
]
},
{
@@ -201,6 +227,33 @@
"result == [(3, 4)]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Column<'hello'>,)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hello_columns"
]
},
{
"cell_type": "code",
"execution_count": null,
1 change: 0 additions & 1 deletion pysparky.egg-info/PKG-INFO
@@ -35,7 +35,6 @@ python -m build

# TODO
- Change pytest test case
- Build mkdocs -> to make it standard to ingest to MkDocs
- Build wheels for PyPi

# Reference:
3 changes: 3 additions & 0 deletions pysparky.egg-info/SOURCES.txt
@@ -4,11 +4,13 @@ pyproject.toml
pysparky/__init__.py
pysparky/debug.py
pysparky/decorator.py
pysparky/enabler.py
pysparky/quality.py
pysparky/reader_options.py
pysparky/schema_ext.py
pysparky/spark_ext.py
pysparky/transformation_ext.py
pysparky/typing.py
pysparky/utils.py
pysparky.egg-info/PKG-INFO
pysparky.egg-info/SOURCES.txt
@@ -21,6 +23,7 @@ pysparky/functions/general.py
pysparky/functions/math_.py
tests/test_debug.py
tests/test_decorator.py
tests/test_enabler.py
tests/test_quality.py
tests/test_schema_ext.py
tests/test_spark_ext.py
25 changes: 25 additions & 0 deletions pysparky/enabler.py
@@ -0,0 +1,25 @@
from pyspark.sql import Column
from pyspark.sql import functions as F

from pysparky.typing import ColumnOrName


def column_or_name_enabler(*columns: ColumnOrName) -> tuple[Column, ...]:
"""
Enables PySpark functions to accept either column names (as strings) or Column objects.
Parameters:
columns (ColumnOrName): Column names (as strings) or Column objects to be converted.
Returns:
tuple[Column]: A tuple of Column objects.
Example:
>>> column_or_name_enabler("col1", "col2", F.col("col3"))
(Column<b'col1'>, Column<b'col2'>, Column<b'col3'>)
"""
return tuple(
map(
lambda column: F.col(column) if isinstance(column, str) else column, columns
)
)
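A quick usage sketch of the new helper (the column names are placeholders):

from pyspark.sql import functions as F

from pysparky.enabler import column_or_name_enabler

# Strings and Column objects are normalized to Column objects in one pass.
cols = column_or_name_enabler("col1", "col2", F.col("col3"))

# Single-argument call sites unpack the one-element tuple, which is the
# pattern adopted throughout functions/general.py in this commit.
(column,) = column_or_name_enabler("col1")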
11 changes: 6 additions & 5 deletions pysparky/functions/conditions.py
@@ -1,17 +1,18 @@
from functools import reduce
from operator import and_, or_
from typing import Union

from pyspark.sql import Column
from pyspark.sql import functions as F

from pysparky.typing import ColumnOrName

def condition_and(*conditions: Union[Column, str]) -> Column:

def condition_and(*conditions: ColumnOrName) -> Column:
"""
Combines multiple conditions using logical AND.
Args:
*conditions (Union[Column, str]): Multiple PySpark Column objects or SQL expression strings representing conditions.
*conditions (ColumnOrName): Multiple PySpark Column objects or SQL expression strings representing conditions.
Returns:
Column: A single PySpark Column object representing the combined condition.
@@ -29,12 +30,12 @@ def condition_and(*conditions: Union[Column, str]) -> Column:
return reduce(and_, parsed_conditions, F.lit(True))


def condition_or(*conditions: Union[Column, str]) -> Column:
def condition_or(*conditions: ColumnOrName) -> Column:
"""
Combines multiple conditions using logical OR.
Args:
*conditions (Union[Column, str]): Multiple PySpark Column objects or SQL expression strings representing conditions.
*conditions (ColumnOrName): Multiple PySpark Column objects or SQL expression strings representing conditions.
Returns:
Column: A single PySpark Column object representing the combined condition.
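As an illustration (the import path is assumed from the file layout, and the sample data is made up), the two combinators can be used directly in a filter:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from pysparky.functions.conditions import condition_and, condition_or

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])

# Per the docstrings, conditions may be Column objects or SQL expression strings.
df.filter(condition_and(F.col("id") > 0, "val = 'a'")).show()
df.filter(condition_or(F.col("id") > 1, "val = 'a'")).show()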
36 changes: 20 additions & 16 deletions pysparky/functions/general.py
@@ -7,6 +7,8 @@
from pyspark.sql import functions as F

from pysparky import decorator, utils
from pysparky.enabler import column_or_name_enabler
from pysparky.typing import ColumnOrName


@decorator.extension_enabler(Column)
@@ -67,7 +69,7 @@ def chain(self, func, *args, **kwargs) -> Column:
@decorator.extension_enabler(Column)
@decorator.pyspark_column_or_name_enabler("column_or_name")
def startswiths(
column_or_name: str | Column, list_of_strings: list[str]
column_or_name: ColumnOrName, list_of_strings: list[str]
) -> pyspark.sql.Column:
"""
Creates a PySpark Column expression to check if the given column starts with any string in the list.
@@ -88,9 +90,8 @@ def startswiths(


@decorator.extension_enabler(Column)
@decorator.pyspark_column_or_name_enabler("column_or_name")
def replace_strings_to_none(
column_or_name: str | Column,
column_or_name: ColumnOrName,
list_of_null_string: list[str],
customize_output: Any = None,
) -> pyspark.sql.Column:
@@ -104,14 +105,13 @@ def replace_strings_to_none(
Column: A Spark DataFrame column with the values replaced.
"""

return F.when(column_or_name.isin(list_of_null_string), customize_output).otherwise(
column_or_name
)
(column,) = column_or_name_enabler(column_or_name)

return F.when(column.isin(list_of_null_string), customize_output).otherwise(column)
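A hedged usage sketch of the rewritten function (import path assumed from the file layout; sample values made up). Passing a plain column name now works because the helper call above resolves it to a Column:

from pyspark.sql import SparkSession

from pysparky.functions.general import replace_strings_to_none

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("NA",), ("x",), ("null",)], ["val"])

# "NA" and "null" become NULL; every other value passes through unchanged.
df.select(replace_strings_to_none("val", ["NA", "null"]).alias("val")).show()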


@decorator.extension_enabler(Column)
@decorator.pyspark_column_or_name_enabler("column_or_name")
def single_space_and_trim(column_or_name: str | Column) -> Column:
def single_space_and_trim(column_or_name: ColumnOrName) -> Column:
"""
Replaces multiple white spaces with a single space and trims the column.
@@ -126,8 +126,7 @@ def single_space_and_trim(column_or_name: str | Column) -> Column:


@decorator.extension_enabler(Column)
@decorator.pyspark_column_or_name_enabler("column_or_name")
def get_value_from_map(column_or_name: str | Column, dict_: dict) -> Column:
def get_value_from_map(column_or_name: ColumnOrName, dict_: dict) -> Column:
"""
Retrieves a value from a map (dictionary) using a key derived from a specified column in a DataFrame.
@@ -153,12 +152,13 @@ def get_value_from_map(column_or_name: str | Column, dict_: dict) -> Column:
| 2| b|
+----------+-----+
"""
return utils.create_map_from_dict(dict_)[column_or_name]
(column,) = column_or_name_enabler(column_or_name)

return utils.create_map_from_dict(dict_)[column]


@decorator.extension_enabler(Column)
@decorator.pyspark_column_or_name_enabler("column_or_name")
def when_mapping(column_or_name: Column, dict_: dict) -> Column:
def when_mapping(column_or_name: ColumnOrName, dict_: dict) -> Column:
"""
Applies a series of conditional mappings to a PySpark Column based on a dictionary of conditions and values.
@@ -169,7 +169,11 @@ def when_mapping(column_or_name: Column, dict_: dict) -> Column:
Returns:
Column: A new PySpark Column with the conditional mappings applied.
"""
result_column = F # initiate as an functions
for condition, value in dict_.items():
result_column = result_column.when(column_or_name == condition, value)
(column,) = column_or_name_enabler(column_or_name)

def reducer(result_column: Column, condition_value: tuple[Any, Any]) -> Column:
condition, value = condition_value
return result_column.when(column == condition, value)

result_column: Column = functools.reduce(reducer, dict_.items(), F) # type: ignore
return result_column
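To see what the reduce builds, a small sketch (the "status" column and mapping values are illustrative; import path assumed from the file layout):

from pyspark.sql import functions as F

from pysparky.functions.general import when_mapping

# The reduce over dict_.items() chains one F.when per entry, so this call ...
mapped = when_mapping("status", {1: "new", 2: "done"})

# ... is equivalent to writing the chain out by hand:
manual = F.when(F.col("status") == 1, "new").when(F.col("status") == 2, "done")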