add check for columns in test
alanpo1 committed Jul 23, 2024
1 parent 85e235d commit 66841df
Showing 2 changed files with 10 additions and 2 deletions.
@@ -185,7 +185,7 @@ def read_mltable_in_spark(mltable_path: str):
         raise se


-def save_spark_df_as_mltable(metrics_df, folder_path: str):
+def save_spark_df_as_mltable(metrics_df: DataFrame, folder_path: str):
     """Save spark dataframe as mltable."""
     base_path = folder_path.rstrip('/')
     output_path_pattern = base_path + "/data/*.parquet"
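For context on the hunk above, the new metrics_df: DataFrame annotation presumably refers to pyspark.sql.DataFrame; the import is not shown in this diff, so that is an assumption. A minimal sketch of the annotated signature under that assumption, with the rest of the body elided:

from pyspark.sql import DataFrame


def save_spark_df_as_mltable(metrics_df: DataFrame, folder_path: str):
    """Save spark dataframe as mltable."""
    base_path = folder_path.rstrip('/')
    output_path_pattern = base_path + "/data/*.parquet"
    ...  # remainder of the function is unchanged by this commit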
assets/model_monitoring/components/tests/unit/test_io_utils.py (10 changes: 9 additions & 1 deletion)
@@ -8,6 +8,7 @@
 from src.shared_utilities.io_utils import (
     init_spark,
     save_spark_df_as_mltable,
+    try_read_mltable_in_spark_with_error,
 )


@@ -19,5 +20,12 @@ def test_save_spark_df_as_mltable_duplicate_case_sensitive_columns(self):
         """Test the save dataframe as mltable functionality."""
         spark = init_spark()
         production_df = spark.createDataFrame([(1, "c", "bob"), (2, "d", "BOB")], ["id", "age", "Id"])
-        save_spark_df_as_mltable(production_df, "localData")
+        save_spark_df_as_mltable(production_df, "./localData")
         assert os.path.exists("localData/") is True
+        assert os.path.exists("localData/data/") is True
+
+        saved_df = try_read_mltable_in_spark_with_error("./localData", "test-data")
+        print("Debug logs for saved dataframe:")
+        saved_df.show()
+        assert "id" in saved_df.columns
+        assert "Id" in saved_df.columns
