Skip to content

Commit

Permalink
🔥 disallow untyped defs
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Apr 18, 2024
1 parent c9f0b5a commit 1453e8d
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ warn_redundant_casts = true
warn_unused_ignores = true
check_untyped_defs = true
no_implicit_reexport = true
disallow_untyped_defs = false # TODO: enable
disallow_untyped_defs = true
# disallow_any_generics = false
ignore_missing_imports = true
# allow_redefinition = true
Expand Down
22 changes: 12 additions & 10 deletions tsflex/chunking/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def _chunk_time_data(
min_chunk_dur: Optional[Union[str, pd.Timedelta]] = None,
max_chunk_dur: Optional[Union[str, pd.Timedelta]] = None,
sub_chunk_overlap: Union[str, pd.Timedelta] = "0s",
copy=True,
verbose=False,
):
copy: bool = True,
verbose: bool = False,
) -> List[List[pd.Series]]:
if min_chunk_dur is not None:
min_chunk_dur = parse_time_arg(min_chunk_dur)
if max_chunk_dur is not None:
Expand Down Expand Up @@ -61,7 +61,9 @@ def _chunk_time_data(
# Each list item can be seen as (t_start_chunk, t_end_chunk, chunk_list)
same_range_chunks: List[Tuple[pd.Timestamp, pd.Timestamp, List[pd.Series]]] = []

def print_verbose_time(sig, t_begin, t_end, msg=""):
def print_verbose_time(
sig: pd.Series, t_begin: pd.Timestamp, t_end: pd.Timestamp, msg: str = ""
) -> None:
fmt = "%Y-%m-%d %H:%M"
if not verbose:
return
Expand All @@ -81,7 +83,7 @@ def slice_time(
else:
return sig[t_begin:t_end]

def insert_chunk(chunk: pd.Series):
def insert_chunk(chunk: pd.Series) -> None:
"""Insert the chunk into `same_range_chunks`."""
t_chunk_start, t_chunk_end = chunk.index[[0, -1]]

Expand Down Expand Up @@ -194,9 +196,9 @@ def _chunk_sequence_data(
min_chunk_dur: Optional[float] = None,
max_chunk_dur: Optional[float] = None,
sub_chunk_overlap: float = 0,
copy=True,
verbose=False,
):
copy: bool = True,
verbose: bool = False,
) -> List[List[pd.Series]]:
raise NotImplementedError("Not implemented yet")


Expand All @@ -218,8 +220,8 @@ def chunk_data(
min_chunk_dur: Optional[Union[float, str, pd.Timedelta]] = None,
max_chunk_dur: Optional[Union[float, str, pd.Timedelta]] = None,
sub_chunk_overlap: Union[float, str, pd.Timedelta] = "0s", # TODO: make optional
copy=True,
verbose=False,
copy: bool = True,
verbose: bool = False,
) -> List[List[pd.Series]]:
"""Divide the time-series `data` in same time/sequence-range chunks.
Expand Down
2 changes: 1 addition & 1 deletion tsflex/features/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def __init__(
):
# Cast functions to FuncWrapper, this avoids creating multiple
# FuncWrapper objects for the same function in the FeatureDescriptor
def to_func_wrapper(f: Callable):
def to_func_wrapper(f: Callable) -> FuncWrapper:
return f if isinstance(f, FuncWrapper) else FuncWrapper(f)

functions = [to_func_wrapper(f) for f in to_list(functions)]
Expand Down
22 changes: 11 additions & 11 deletions tsflex/features/feature_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _check_feature_descriptors(
self,
skip_none: bool,
calc_stride: Optional[Union[float, pd.Timedelta, None]] = None,
):
) -> None:
"""Verify whether all added FeatureDescriptors imply the same-input data type.
If this condition is not met, a warning will be raised.
Expand Down Expand Up @@ -195,7 +195,7 @@ def _check_feature_descriptors(
category=RuntimeWarning,
)

def _add_feature(self, feature: FeatureDescriptor):
def _add_feature(self, feature: FeatureDescriptor) -> None:
"""Add a `FeatureDescriptor` instance to the collection.
Parameters
Expand Down Expand Up @@ -238,7 +238,7 @@ def add(
Union[FeatureDescriptor, MultipleFeatureDescriptors, FeatureCollection]
],
],
):
) -> None:
"""Add feature(s) to the FeatureCollection.
Parameters
Expand Down Expand Up @@ -324,13 +324,13 @@ def _executor_grouped(idx: int) -> pd.DataFrame:
f = function
if function.input_type is np.array:

def f(x: pd.DataFrame):
def f(x: pd.DataFrame) -> Any:
# pass the inputs as positional arguments of numpy array type
return function(*[x[c].values for c in cols_tuple])

else: # function.input_type is pd.Series

def f(x: pd.DataFrame):
def f(x: pd.DataFrame) -> Any:
# pass the inputs as positional arguments of pd.Series type
return function(*[x[c] for c in cols_tuple])

Expand Down Expand Up @@ -373,7 +373,7 @@ def _stroll_feat_generator(
[len(self._feature_desc_dict[k]) for k in keys_wins_strides]
)

def get_stroll_function(idx) -> Tuple[StridedRolling, FuncWrapper]:
def get_stroll_function(idx: int) -> Tuple[StridedRolling, FuncWrapper]:
key_idx = np.searchsorted(lengths, idx, "right") # right bc idx starts at 0
key, win = keys_wins_strides[key_idx]

Expand Down Expand Up @@ -416,7 +416,7 @@ def _group_feat_generator(
lengths = np.cumsum([len(self._feature_desc_dict[k]) for k in keys_wins])

def get_group_function(
idx,
idx: int,
) -> Tuple[pd.api.typing.DataFrameGroupBy, FuncWrapper,]:
key_idx = np.searchsorted(lengths, idx, "right") # right bc idx starts at 0
key, win = keys_wins[key_idx]
Expand All @@ -429,7 +429,7 @@ def get_group_function(

return get_group_function

def _check_no_multiple_windows(self, error_case: str):
def _check_no_multiple_windows(self, error_case: str) -> None:
"""Check whether there are no multiple windows in the feature collection.
Parameters
Expand Down Expand Up @@ -633,11 +633,11 @@ def _group_by_consecutive(

return df_grouped

def _calculate_group_by_consecutive(
def _calculate_group_by_consecutive( # type: ignore[no-untyped-def]
self,
data: Union[pd.Series, pd.DataFrame, List[Union[pd.Series, pd.DataFrame]]],
group_by: str,
return_df: Optional[bool] = False,
return_df: bool = False,
**calculate_kwargs,
) -> Union[List[pd.DataFrame], pd.DataFrame]:
"""Calculate features on each consecutive group of the data.
Expand Down Expand Up @@ -1261,7 +1261,7 @@ def calculate(
f_handler,
)

def serialize(self, file_path: Union[str, Path]):
def serialize(self, file_path: Union[str, Path]) -> None:
"""Serialize this FeatureCollection instance.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion tsflex/features/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class FuncWrapper(FrozenClass):
"""

def __init__(
def __init__( # type: ignore[no-untyped-def]
self,
func: Callable,
output_names: Optional[Union[List[str], str]] = None,
Expand Down
10 changes: 5 additions & 5 deletions tsflex/features/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__author__ = "Jeroen Van Der Donckt, Jonas Van Der Donckt"

import importlib
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -35,7 +35,7 @@ def seglearn_wrapper(func: Callable, func_name: Optional[str] = None) -> FuncWra
"""

def wrap_func(x: np.ndarray):
def wrap_func(x: np.ndarray) -> np.ndarray:
out = func(x.reshape(1, len(x)))
return out.flatten()

Expand Down Expand Up @@ -144,7 +144,7 @@ def tsfel_feature_dict_wrapper(features_dict: Dict) -> List[FuncWrapper]:
"""

def get_output_names(config: dict):
def get_output_names(config: dict) -> Union[str, List[str]]:
"""Create the output_names based on the configuration."""
nb_outputs = config["n_features"]
func_name = config["function"].split(".")[-1]
Expand Down Expand Up @@ -203,7 +203,7 @@ def tsfresh_combiner_wrapper(func: Callable, param: List[Dict]) -> FuncWrapper:
"""

def wrap_func(x: Union[np.ndarray, pd.Series]):
def wrap_func(x: Union[np.ndarray, pd.Series]) -> Tuple[Any, ...]:
out = func(x, param)
return tuple(t[1] for t in out)

Expand Down Expand Up @@ -330,7 +330,7 @@ def catch22_wrapper(catch22_all: Callable) -> FuncWrapper:
"""
catch22_names = catch22_all([0])["names"]

def wrap_catch22_all(x):
def wrap_catch22_all(x: np.ndarray) -> List[float]:
return catch22_all(x)["values"]

wrap_catch22_all.__name__ = "[wrapped]__" + _get_name(catch22_all)
Expand Down
2 changes: 1 addition & 1 deletion tsflex/features/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_function_stats(logging_file_path: str) -> pd.DataFrame:
.index.to_list()
)

def key_func(idx_level):
def key_func(idx_level): # type: ignore[no-untyped-def]
if all(idx in sorted_funcs for idx in idx_level):
return [sorted_funcs.index(idx) for idx in idx_level]
return idx_level
Expand Down
33 changes: 22 additions & 11 deletions tsflex/features/segmenter/strided_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __init__(
RuntimeWarning,
)

def _calc_nb_segments_for_stride(self, stride) -> int:
def _calc_nb_segments_for_stride(self, stride: T) -> int:
"""Calculate the number of output items (segments) for a given single stride."""
assert self.start is not None and self.end is not None # for mypy
nb_feats = max((self.end - self.start - self.window) // stride + 1, 0)
Expand Down Expand Up @@ -293,7 +293,10 @@ def _get_output_index(
)

def _construct_series_containers(
self, series_list, np_start_times, np_end_times
self,
series_list: List[pd.Series],
np_start_times: np.ndarray,
np_end_times: np.ndarray,
) -> List[StridedRolling._NumpySeriesContainer]:
series_containers: List[StridedRolling._NumpySeriesContainer] = []
for series in series_list:
Expand Down Expand Up @@ -487,7 +490,7 @@ def apply_func(self, func: FuncWrapper) -> pd.DataFrame:

# --------------------------------- STATIC METHODS ---------------------------------
@staticmethod
def _get_np_value(val):
def _get_np_value(val: Union[np.number, pd.Timestamp, pd.Timedelta]) -> np.number:
# Convert everything to int64
if isinstance(val, pd.Timestamp):
return val.to_datetime64()
Expand All @@ -505,7 +508,9 @@ def construct_output_index(

# ----------------------------- OVERRIDE THESE METHODS -----------------------------
@abstractmethod
def _update_start_end_indices_to_stroll_type(self, series_list: List[pd.Series]):
def _update_start_end_indices_to_stroll_type(
self, series_list: List[pd.Series]
) -> None:
# NOTE: This method will only be implemented (with code != pass) in the
# TimeIndexSampleStridedRolling
raise NotImplementedError
Expand All @@ -522,7 +527,7 @@ def _create_feat_col_name(self, feat_name: str) -> str:


class SequenceStridedRolling(StridedRolling):
def __init__(
def __init__( # type: ignore[no-untyped-def]
self,
data: Union[pd.Series, pd.DataFrame, List[Union[pd.Series, pd.DataFrame]]],
window: Union[int, float],
Expand All @@ -535,7 +540,9 @@ def __init__(
super().__init__(data, window, strides, *args, **kwargs)

# ------------------------------- Overridden methods -------------------------------
def _update_start_end_indices_to_stroll_type(self, series_list: List[pd.Series]):
def _update_start_end_indices_to_stroll_type(
self, series_list: List[pd.Series]
) -> None:
pass

def _parse_segment_idxs(self, segment_idxs: np.ndarray) -> np.ndarray:
Expand All @@ -554,7 +561,7 @@ def _create_feat_col_name(self, feat_name: str) -> str:


class TimeStridedRolling(StridedRolling):
def __init__(
def __init__( # type: ignore[no-untyped-def]
self,
data: Union[pd.Series, pd.DataFrame, List[Union[pd.Series, pd.DataFrame]]],
window: pd.Timedelta,
Expand Down Expand Up @@ -589,7 +596,9 @@ def _get_output_index(
return super()._get_output_index(start_idxs, end_idxs, name)

# ------------------------------- Overridden methods -------------------------------
def _update_start_end_indices_to_stroll_type(self, series_list: List[pd.Series]):
def _update_start_end_indices_to_stroll_type(
self, series_list: List[pd.Series]
) -> None:
pass

def _parse_segment_idxs(self, segment_idxs: np.ndarray) -> np.ndarray:
Expand All @@ -616,7 +625,7 @@ def _create_feat_col_name(self, feat_name: str) -> str:


class TimeIndexSampleStridedRolling(SequenceStridedRolling):
def __init__(
def __init__( # type: ignore[no-untyped-def]
self,
# TODO -> update arguments
data: Union[pd.Series, pd.DataFrame, List[Union[pd.Series, pd.DataFrame]]],
Expand Down Expand Up @@ -678,7 +687,9 @@ def apply_func(self, func: FuncWrapper) -> pd.DataFrame:
return df

# ---------------------------- Overridden methods ------------------------------
def _update_start_end_indices_to_stroll_type(self, series_list: List[pd.Series]):
def _update_start_end_indices_to_stroll_type(
self, series_list: List[pd.Series]
) -> None:
# update the start and end times to the sequence datatype
self.start, self.end = np.searchsorted(
series_list[0].index.values,
Expand All @@ -689,7 +700,7 @@ def _update_start_end_indices_to_stroll_type(self, series_list: List[pd.Series])

def _sliding_strided_window_1d(
data: np.ndarray, window: int, step: int, nb_segments: int
):
) -> np.ndarray:
"""View based sliding strided-window for 1-dimensional data.
Parameters
Expand Down
20 changes: 17 additions & 3 deletions tsflex/features/segmenter/strided_rolling_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

__author__ = "Jonas Van Der Donckt"

from typing import List, Optional, Union

import pandas as pd

from ...utils.attribute_parsing import AttributeParser, DataType
from .strided_rolling import (
SequenceStridedRolling,
Expand All @@ -26,7 +30,12 @@ class StridedRollingFactory:
}

@staticmethod
def get_segmenter(data, window, strides, **kwargs) -> StridedRolling:
def get_segmenter( # type: ignore[no-untyped-def]
data: Union[pd.Series, pd.DataFrame, List[Union[pd.Series, pd.DataFrame]]],
window: Union[int, float, pd.TimeDelta],
strides: Optional[List[Union[int, float, pd.TimeDelta]]],
**kwargs,
) -> StridedRolling:
"""Get the appropriate StridedRolling instance for the passed data.
The returned instance will be determined by the data its index type
Expand All @@ -35,9 +44,9 @@ def get_segmenter(data, window, strides, **kwargs) -> StridedRolling:
----------
data : Union[pd.Series, pd.DataFrame, List[Union[pd.Series, pd.DataFrame]]]
The data to segment.
window : Union[float, pd.TimeDelta]
window : Union[int, float, pd.Timedelta]
The window size to use for the segmentation.
strides : Union[List[Union[float, pd.TimeDelta]], None]
strides : Union[List[Union[int, float, pd.Timedelta]], None]
The stride(s) to use for the segmentation.
**kwargs : dict, optional
Additional keyword arguments, see the `StridedRolling` its documentation
Expand Down Expand Up @@ -74,6 +83,11 @@ def get_segmenter(data, window, strides, **kwargs) -> StridedRolling:
)
elif data_dtype == DataType.TIME and args_dtype == DataType.SEQUENCE:
# Note: this is very niche and thus requires advanced knowledge
assert isinstance(window, int)
if strides is not None:
assert isinstance(strides, list) and all(
isinstance(s, int) for s in strides
)
return TimeIndexSampleStridedRolling(data, window, strides, **kwargs)
elif data_dtype == DataType.SEQUENCE and args_dtype == DataType.TIME:
raise ValueError("Cannot segment a sequence-series with a time window")
Expand Down
Loading

0 comments on commit 1453e8d

Please sign in to comment.