Commit
🧹 cleanup groupby nan behavior
jvdd committed Oct 14, 2023
1 parent 913ed6e commit 6c61ec9
Showing 2 changed files with 64 additions and 21 deletions.
77 changes: 61 additions & 16 deletions tests/test_features_feature_collection.py
@@ -179,6 +179,44 @@ def assert_results(data, res_data):
    assert_results(data_count_max, grouped_res_df_max)


+@pytest.mark.parametrize("group_by", ["group_by_all", "group_by_consecutive"])
+def test_multiple_series_multiple_features_group_by(dummy_group_data, group_by):
+    def sum_2(x: np.ndarray, y: np.ndarray) -> float:
+        return np.sum(x) + np.sum(y)
+
+    fd1 = FeatureDescriptor(function=np.sum, series_name="number_sold")
+    fd2 = FeatureDescriptor(function=np.sum, series_name="product")
+    fd3 = FeatureDescriptor(function=sum_2, series_name=("number_sold", "product"))
+
+    fc = FeatureCollection(feature_descriptors=[fd1, fd2, fd3])
+
+    assert set(fc.get_required_series()) == set(["number_sold", "product"])
+    assert fc.get_nb_output_features() == 3
+
+    res_list = fc.calculate(
+        dummy_group_data, return_df=False, n_jobs=1, **{group_by: "store"}
+    )
+    res_df = fc.calculate(
+        dummy_group_data, return_df=True, n_jobs=1, **{group_by: "store"}
+    )
+
+    assert isinstance(res_list, list)
+    assert isinstance(res_df, pd.DataFrame)
+
+    concatted_df = pd.concat(res_list, axis=1)
+    assert len(concatted_df.columns) == len(res_df.columns)
+    concatted_df = concatted_df[res_df.columns]  # assure column order is the same
+
+    assert_frame_equal(concatted_df, res_df)
+
+    postfix = "__w=manual" if "consecutive" in group_by else ""
+    assert all(
+        res_df["number_sold__sum" + postfix].values
+        + res_df["product__sum" + postfix].values
+        == res_df["number_sold|product__sum_2" + postfix].values
+    )
+
+
@pytest.mark.parametrize("group_by", ["group_by_all", "group_by_consecutive"])
def test_group_by_with_nan_values(dummy_group_data, group_by):
    fd = FeatureDescriptor(
@@ -217,12 +255,10 @@ def test_group_by_with_nan_values(dummy_group_data, group_by):

@pytest.mark.parametrize("group_by", ["group_by_all", "group_by_consecutive"])
def test_group_by_with_unequal_lengths(group_by):
-    fd = FeatureDescriptor(
-        function=np.sum,
-        series_name="count",
-    )
+    fd1 = FeatureDescriptor(function=np.sum, series_name="count")
+    fd2 = FeatureDescriptor(function=np.nansum, series_name="count")

-    fc = FeatureCollection(feature_descriptors=fd)
+    fc = FeatureCollection(feature_descriptors=[fd1, fd2])

    # create the dummy data
    s_group = pd.Series(
Expand Down Expand Up @@ -264,24 +300,31 @@ def test_group_by_with_unequal_lengths(group_by):
data=np.arange(30),
name="count",
)
res_list = fc.calculate([s_group, s_val], return_df=True, **{group_by: "user_id"})
res_list = fc.calculate(
[s_group, s_val], return_df=True, n_jobs=1, **{group_by: "user_id"}
)
res_list2 = fc.calculate(
[s_group2, s_val2], return_df=True, **{group_by: "user_id"}
[s_group2, s_val2], return_df=True, n_jobs=1, **{group_by: "user_id"}
)
col = "count__sum"
col = "count__nansum"
col += "__w=manual" if "consecutive" in group_by else ""
res_list2[col] = res_list2[col].astype(res_list.dtypes[col])
correct_res_list = fc.calculate(
[s_group, s_val2], return_df=True, **{group_by: "user_id"}
[s_group, s_val2], return_df=True, n_jobs=1, **{group_by: "user_id"}
)

assert_frame_equal(res_list, res_list2)
for c in res_list.columns:
# compare (compare_col) only with nan-safe col in case of group_by_all
compare_col = c if "consecutive" in group_by else col
assert np.all(
res_list[c]
== res_list2.loc[res_list.index, compare_col].astype(res_list.dtypes[c])
)
assert_frame_equal(res_list, correct_res_list)


@pytest.mark.parametrize("group_by", ["group_by_all", "group_by_consecutive"])
def test_group_by_non_aligned_indices(group_by):
fd = FeatureDescriptor(function=np.sum, series_name="count")
fd = FeatureDescriptor(function=np.nansum, series_name="count")
fc = FeatureCollection(feature_descriptors=fd)

# create the dummy data
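Side note (not part of this commit): the tests above switch the expected column to "count__nansum" presumably because, after this cleanup, rows with NaN in a value column are no longer dropped before grouping; only rows whose group key is NaN are ignored. A minimal NumPy-only sketch of why a plain-sum feature is not NaN-safe in that case:

import numpy as np

# A group whose value column picked up a NaN (e.g. after aligning series of
# unequal length) propagates that NaN through np.sum but not through np.nansum.
vals = np.array([1.0, 2.0, np.nan])
print(np.sum(vals))     # nan  -> a plain "count__sum" feature is not comparable
print(np.nansum(vals))  # 3.0  -> the "count__nansum" feature stays NaN-safe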
@@ -302,19 +345,21 @@ def test_group_by_non_aligned_indices(group_by):
    ).reset_index()
    grouped_non_nan_df_sums = non_nan_df.groupby("groups").sum()

-    col = "count__sum"
+    col = "count__nansum"
    col += "__w=manual" if "consecutive" in group_by else ""
    new_res_list = pd.DataFrame(
        {"groups": res_list["user_id"], "values": res_list[col]}
    )
    new_res_list = new_res_list.set_index("groups")

-    assert_frame_equal(new_res_list, grouped_non_nan_df_sums)
+    assert_frame_equal(
+        new_res_list.loc[grouped_non_nan_df_sums.index], grouped_non_nan_df_sums
+    )


@pytest.mark.parametrize("group_by", ["group_by_all", "group_by_consecutive"])
def test_group_by_with_numeric_index(group_by):
-    fd = FeatureDescriptor(function=np.sum, series_name="count")
+    fd = FeatureDescriptor(function=np.nansum, series_name="count")
    fc = FeatureCollection(feature_descriptors=fd)

    s_group = pd.Series(
@@ -349,7 +394,7 @@ def test_group_by_with_numeric_index(group_by):
    s_df = pd.DataFrame({"groups": s_group, "values": s_val})

    data_counts = s_df.groupby("groups")["values"].sum()
-    col = "count__sum"
+    col = "count__nansum"
    col += "__w=manual" if "consecutive" in group_by else ""
    result_data_counts = res_df.groupby("user_id")[col].sum()

8 changes: 3 additions & 5 deletions tsflex/features/feature_collection.py
@@ -478,9 +478,7 @@ def _group_by_all(
        df = pd.DataFrame(series_dict)
        assert col_name in df.columns

-        # Drop all rows with NaN values
-        df.dropna(inplace=True)
-
+        # GroupBy ignores all rows with NaN values for the column on which we group
        return df.groupby(col_name)

    def _calculate_group_by_all(
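Side note (not part of this commit): the new comment relies on pandas' default groupby behavior of excluding NaN group keys, which makes the explicit dropna redundant here. A small self-contained sketch of that behavior; the column names are illustrative only:

import pandas as pd

df = pd.DataFrame({"store": ["a", "a", None, "b"], "number_sold": [1, 2, 3, 4]})
# With the default dropna=True, the row whose "store" key is missing is excluded
# from the groups, while NaN values in other columns are kept within their group.
print(df.groupby("store")["number_sold"].sum())  # a -> 3, b -> 4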
@@ -553,8 +551,8 @@ def _group_by_consecutive(

        assert col_name in df.columns

-        # Drop all rows with NaN values
-        df.dropna(inplace=True)
+        # Drop all rows with NaN values for the column on which we group
+        df.dropna(subset=[col_name], inplace=True)

        df_cum = (
            (df[col_name] != df[col_name].shift(1))
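Side note (not part of this commit): a short sketch of the difference between the old blanket df.dropna() and the new df.dropna(subset=[col_name]); the column names are illustrative, not taken from tsflex:

import numpy as np
import pandas as pd

df = pd.DataFrame({"store": ["a", "a", None, "b"], "count": [1.0, np.nan, 2.0, 3.0]})
print(len(df.dropna()))                  # 2 -> old: a NaN in any column drops the row
print(len(df.dropna(subset=["store"])))  # 3 -> new: only a NaN group key drops the row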
