🦘 test groupby logging
jvdd committed Oct 23, 2023
1 parent f3a9496 commit 8dee606
Showing 2 changed files with 96 additions and 2 deletions.
94 changes: 93 additions & 1 deletion tests/test_features_logging.py
@@ -6,6 +6,7 @@
import warnings

import numpy as np
import pytest

from tsflex.features import (
FeatureCollection,
@@ -17,7 +18,7 @@
)
from tsflex.utils.data import flatten

from .utils import dummy_data, logging_file_path
from .utils import dummy_data, dummy_group_data, logging_file_path

test_path = os.path.abspath(os.path.dirname(__file__))

@@ -651,6 +652,97 @@ def test_simple_features_logging_segment_start_and_end_idxs_overrule_stride_and_
assert all(series_names_df["duration %"]["mean"] > 0)


@pytest.mark.parametrize("group_by", ["group_by_all", "group_by_consecutive"])
def test_simple_features_logging_groupby(dummy_group_data, logging_file_path, group_by):
# Use the grouped dummy data with a plain range index
dummy_data = dummy_group_data.reset_index(drop=True)
fd = FeatureDescriptor(
function=np.sum,
series_name="number_sold",
window=50,
stride=120,
)
fc = FeatureCollection(feature_descriptors=fd)
fc.add(
MultipleFeatureDescriptors(
np.min, series_names=["product", "number_sold"], windows=50, strides=20
)
)
for fd in flatten(fc._feature_desc_dict.values()):
assert fd.stride is not None

assert set(fc.get_required_series()) == set(["number_sold", "product"])
assert len(fc.get_required_series()) == 2

assert not os.path.exists(logging_file_path)

# Sequential (n_jobs <= 1), otherwise file_path gets cleared
_ = fc.calculate(
dummy_data, logging_file_path=logging_file_path, n_jobs=1, **{group_by: "store"}
)

assert os.path.exists(logging_file_path)
logging_df = get_feature_logs(logging_file_path)

assert all(
logging_df.columns.values
== [
"log_time",
"function",
"series_names",
"window",
"stride",
"output_names",
"duration",
"duration %",
]
)

assert len(logging_df) == 3
assert logging_df.select_dtypes(include=[np.datetime64]).columns.values == [
"log_time"
]
assert logging_df.select_dtypes(include=[np.timedelta64]).columns.values == [
"duration"
]

assert np.isclose(logging_df["duration %"].sum(), 100, atol=0.5)

assert set(logging_df["function"].values) == set(["amin", "sum"])
assert set(logging_df["series_names"].values) == set(
["(number_sold,)", "(product,)"]
)
assert set(logging_df["output_names"].values) == set(
[
"number_sold__sum__w=manual",
"number_sold__amin__w=manual",
"product__amin__w=manual",
]
)
assert all(logging_df["window"] == "manual")
assert all(logging_df["stride"] == "manual")

function_stats_df = get_function_stats(logging_file_path)
assert len(function_stats_df) == 2
assert set(function_stats_df.index) == set(
[(s, "manual", "manual") for s in ["sum", "amin"]]
)
assert all(function_stats_df["duration"]["mean"] > 0)
assert function_stats_df["duration"]["count"].sum() == 3
assert np.isclose(function_stats_df["duration %"]["sum"].sum(), 100, atol=0.5)
assert all(function_stats_df["duration %"]["mean"] > 0)

series_names_df = get_series_names_stats(logging_file_path)
assert len(series_names_df) == 2
assert set(series_names_df.index) == set(
[(s, "manual", "manual") for s in ["(number_sold,)", "(product,)"]]
)
assert all(series_names_df["duration"]["mean"] > 0)
assert series_names_df["duration"]["count"].sum() == 3
assert np.isclose(series_names_df["duration %"]["sum"].sum(), 100, atol=0.5)
assert all(series_names_df["duration %"]["mean"] > 0)


def test_file_warning_features_logging(dummy_data, logging_file_path):
fd = FeatureDescriptor(
function=np.sum,
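For reference, a minimal standalone sketch of the group-by logging flow the new test exercises. The DataFrame below is a stand-in for the dummy_group_data fixture and the log-file path is hypothetical; FeatureCollection, FeatureDescriptor, and get_feature_logs are used as in the test above.

import numpy as np
import pandas as pd

from tsflex.features import FeatureCollection, FeatureDescriptor, get_feature_logs

# Stand-in for the dummy_group_data fixture: a flat table with a grouping
# column ("store") and two value columns ("product", "number_sold").
rng = np.random.default_rng(42)
df = pd.DataFrame(
    {
        "store": np.repeat([1, 2, 3], 100),
        "product": np.tile(np.arange(10), 30),
        "number_sold": rng.integers(0, 20, size=300),
    }
)

fc = FeatureCollection(
    feature_descriptors=FeatureDescriptor(
        function=np.sum, series_name="number_sold", window=50, stride=120
    )
)

# Sequential (n_jobs=1), as in the test, so the log file is not cleared by
# worker processes; group_by_all groups the data on the "store" column.
_ = fc.calculate(
    df,
    group_by_all="store",
    logging_file_path="feature_logs.log",
    n_jobs=1,
)

# The logged window and stride are "manual" for grouped calculation,
# matching the assertions in the test above.
print(get_feature_logs("feature_logs.log")[["function", "series_names", "window", "stride"]])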
4 changes: 3 additions & 1 deletion tsflex/features/feature_collection.py
@@ -338,7 +338,9 @@ def f(x: pd.DataFrame):
out = np.array(list(map(f, [data.iloc[idx] for idx in group_indices.values()])))

# Aggregate function output in a dictionary
output_names = ["|".join(cols) + "__" + o for o in function.output_names]
output_names = [
"|".join(cols) + "__" + o + "__w=manual" for o in function.output_names
]
feat_out = _process_func_output(out, group_ids, output_names, str(function))
# Log the function execution time
_log_func_execution(
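The change above appends a __w=manual suffix to the group-by output names (compare the old and new output_names lines), which is what the new test asserts on (e.g. number_sold__sum__w=manual). A tiny illustration with hypothetical stand-ins for the variables in the changed lines:

# Hypothetical stand-ins for the variables in the changed lines:
cols = ("number_sold",)          # series columns of the grouped data
function_output_names = ["sum"]  # function.output_names

output_names = [
    "|".join(cols) + "__" + o + "__w=manual" for o in function_output_names
]
print(output_names)  # ['number_sold__sum__w=manual']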
