Skip to content

Commit

Permalink
🔍 review
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasvdd committed Jan 5, 2024
1 parent 9526a32 commit 7e2d5c6
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions tsflex/features/feature_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ def _executor_stroll(idx: int) -> pd.DataFrame:
"""Executor function for the StridedRolling.apply_func method.
Strided rolling feature calculation occurs when either;
- a `window` and `stride` argument is stored in the `FeatureDescriptor` object
(or the `stride` argument is passed to the `calculate` method)
- a `window` and `stride` argument are stored in the `FeatureDescriptor` object
- the `window` is stored in the `FeatureDescriptor` object and the `stride`
argument is passed to the `calculate` method, potentially overriding the
`stride`
- segment indices are passed to the `calculate` method
- a `group_by_consecutive` argument is passed to the `calculate` method (since
we calculate the segment indices for the consecutive groups)
Expand Down Expand Up @@ -339,9 +341,11 @@ def f(x: pd.DataFrame):

# Aggregate function output in a dictionary
output_names = [
"|".join(cols) + "__" + o + "__w=manual" for o in function.output_names
StridedRolling.construct_output_index(cols, feat_name, win_str="manual")
for feat_name in function.output_names
]
feat_out = _process_func_output(out, group_ids, output_names, str(function))

# Log the function execution time
_log_func_execution(
t_start, function, tuple(cols), "manual", "manual", output_names
Expand All @@ -353,18 +357,16 @@ def _group_feat_generator(
self,
grouped_df: pd.api.typing.DataFrameGroupBy,
) -> Callable[[int], Tuple[pd.api.typing.DataFrameGroupBy, FuncWrapper,],]:
keys_wins_strides = list(self._feature_desc_dict.keys())
lengths = np.cumsum(
[len(self._feature_desc_dict[k]) for k in keys_wins_strides]
)
keys_wins = list(self._feature_desc_dict.keys())
lengths = np.cumsum([len(self._feature_desc_dict[k]) for k in keys_wins])

def get_group_function(
idx,
) -> Tuple[pd.api.typing.DataFrameGroupBy, FuncWrapper,]:
key_idx = np.searchsorted(lengths, idx, "right") # right bc idx starts at 0
key, win = keys_wins_strides[key_idx]
key, win = keys_wins[key_idx]

feature = self._feature_desc_dict[keys_wins_strides[key_idx]][
feature = self._feature_desc_dict[keys_wins[key_idx]][
idx - lengths[key_idx]
]
function: FuncWrapper = feature.function
Expand Down

0 comments on commit 7e2d5c6

Please sign in to comment.