From 2f225556f6020088e0037b94ea5a730b1c1c073e Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Thu, 1 Jul 2021 15:55:52 +0200 Subject: [PATCH] Type annotations for pipe (#55) * Add type annotations for pipe() * add annotations in other places as well * fix formatting of import statement * update * format test Co-authored-by: Nanne Aben --- setup.py | 2 +- tests/snippets/test_frame.py | 18 ++++++++++++++++++ third_party/3/pandas/core/common.pyi | 9 +++++++-- third_party/3/pandas/core/generic.pyi | 9 +++++++-- third_party/3/pandas/core/groupby/groupby.pyi | 9 +++++++-- third_party/3/pandas/core/resample.pyi | 9 +++++++-- third_party/3/pandas/io/formats/style.pyi | 8 ++++++-- 7 files changed, 53 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 6ff8b8d..e032f1e 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ def list_packages(source_path: str = src_path) -> None: setup( name="pandas-stubs", package_dir={"": src_path}, - version="1.1.0.11", + version="1.1.0.12", description="Type annotations for Pandas", long_description=(open("README.md").read() if os.path.exists("README.md") else ""), diff --git a/tests/snippets/test_frame.py b/tests/snippets/test_frame.py index 1f57ad1..655ca46 100644 --- a/tests/snippets/test_frame.py +++ b/tests/snippets/test_frame.py @@ -500,3 +500,21 @@ def test_types_from_dict() -> None: pd.DataFrame.from_dict({'a': {'row1': 2}, 'b': {'row2': 4, 'row1': 4}}) pd.DataFrame.from_dict({'a': (1, 2, 3), 'b': (2, 4, 5)}) pd.DataFrame.from_dict(data={'col_1': {'a': 1}, 'col_2': {'a': 1, 'b': 2}}, orient="columns") + + +def test_pipe() -> None: + def foo(df: pd.DataFrame) -> pd.DataFrame: + return df + + df1: pd.DataFrame = pd.DataFrame({'a': [1]}).pipe(foo) + + df2: pd.DataFrame = ( + pd.DataFrame({'price': [10, 11, 9, 13, 14, 18, 17, 19], 'volume': [50, 60, 40, 100, 50, 100, 40, 50]}) + .assign(week_starting=pd.date_range('01/01/2018', periods=8, freq='W')) + .resample('M', on='week_starting') + .pipe(foo) + ) + + df3: pd.DataFrame = pd.DataFrame({'a': [1], 'b': [1]}).groupby('a').pipe(foo) + + df4: pd.DataFrame = pd.DataFrame({'a': [1], 'b': [1]}).style.pipe(foo) diff --git a/third_party/3/pandas/core/common.pyi b/third_party/3/pandas/core/common.pyi index 2d685b6..c892a1f 100644 --- a/third_party/3/pandas/core/common.pyi +++ b/third_party/3/pandas/core/common.pyi @@ -1,6 +1,6 @@ from pandas._typing import T as T -from typing import Any, Collection, Iterable, Optional, Union +from typing import Any, Collection, Iterable, Optional, Union, Callable, Tuple, TypeVar, overload class SettingWithCopyError(ValueError): ... class SettingWithCopyWarning(Warning): ... @@ -33,5 +33,10 @@ def apply_if_callable(maybe_callable: Any, obj: Any, **kwargs: Any) -> Any: ... def dict_compat(d: Any) -> Any: ... def standardize_mapping(into: Any) -> Any: ... def random_state(state: Optional[Any] = ...) -> Any: ... -def pipe(obj: Any, func: Any, *args: Any, **kwargs: Any) -> Any: ... def get_rename_function(mapper: Any) -> Any: ... + +PipeReturn = TypeVar("PipeReturn") +@overload +def pipe(obj: Any, func: Union[Callable[..., PipeReturn], Tuple[Callable[..., PipeReturn], str]], *args: Any, **kwargs: Any) -> PipeReturn: ... +@overload +def pipe(obj: Any, func: PipeReturn, *args: Any, **kwargs: Any) -> PipeReturn: ... diff --git a/third_party/3/pandas/core/generic.pyi b/third_party/3/pandas/core/generic.pyi index b717c23..6a082e3 100644 --- a/third_party/3/pandas/core/generic.pyi +++ b/third_party/3/pandas/core/generic.pyi @@ -2,6 +2,7 @@ from __future__ import annotations import sys from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union, AnyStr, overload +from pandas.core.resample import Resampler if sys.version_info >= (3, 8): from typing import Literal @@ -24,6 +25,7 @@ from pandas.core.internals import BlockManager bool_t = bool Self = TypeVar('Self', bound=NDFrame) +PipeReturn = TypeVar('PipeReturn') class NDFrame(PandasObject, SelectionMixin, indexing.IndexingMixin): __array_priority__: int = ... @@ -110,7 +112,10 @@ class NDFrame(PandasObject, SelectionMixin, indexing.IndexingMixin): def head(self, n: int = ...) -> FrameOrSeries: ... def tail(self, n: int = ...) -> FrameOrSeries: ... def sample(self, n: int = ..., frac: float = ..., replace: bool_t = ..., weights: Union[str, ArrayLike] = ..., random_state: Union[int, np.random.RandomState] = ..., axis: Optional[AxisOption] = ...) -> FrameOrSeries: ... - def pipe(self, func: Any, *args: Any, **kwargs: Any) -> Any: ... + @overload + def pipe(self: Any, func: Union[Callable[..., PipeReturn], Tuple[Callable[..., PipeReturn], str]], *args: Any, **kwargs: Any) -> PipeReturn: ... + @overload + def pipe(self: Any, func: PipeReturn, *args: Any, **kwargs: Any) -> PipeReturn: ... def __finalize__(self, other: Any, method: Any = ..., **kwargs: Any) -> FrameOrSeries: ... @property def values(self) -> np.ndarray: ... @@ -140,7 +145,7 @@ class NDFrame(PandasObject, SelectionMixin, indexing.IndexingMixin): def asfreq(self, freq: Any, method: Any = ..., how: Optional[str]=..., normalize: bool_t=..., fill_value: Any = ...) -> FrameOrSeries: ... def at_time(self, time: Any, asof: bool_t=..., axis: Any = ...) -> FrameOrSeries: ... def between_time(self, start_time: Any, end_time: Any, include_start: bool_t=..., include_end: bool_t=..., axis: Any = ...) -> FrameOrSeries: ... - def resample(self, rule: Any, axis: Any = ..., closed: Optional[str]=..., label: Optional[str]=..., convention: str=..., kind: Optional[str]=..., loffset: Any = ..., base: int=..., on: Any = ..., level: Any = ..., origin: Union[Timestamp, str] = ..., offset: Union[Timedelta, str] = ...) -> Any: ... + def resample(self, rule: Any, axis: Any = ..., closed: Optional[str]=..., label: Optional[str]=..., convention: str=..., kind: Optional[str]=..., loffset: Any = ..., base: int=..., on: Any = ..., level: Any = ..., origin: Union[Timestamp, str] = ..., offset: Union[Timedelta, str] = ...) -> Resampler: ... def first(self, offset: Any) -> FrameOrSeries: ... def last(self, offset: Any) -> FrameOrSeries: ... def rank(self, axis: Any = ..., method: str=..., numeric_only: Optional[bool_t]=..., na_option: str=..., ascending: bool_t=..., pct: bool_t=...) -> FrameOrSeries: ... diff --git a/third_party/3/pandas/core/groupby/groupby.pyi b/third_party/3/pandas/core/groupby/groupby.pyi index 8afc826..c1bccf3 100644 --- a/third_party/3/pandas/core/groupby/groupby.pyi +++ b/third_party/3/pandas/core/groupby/groupby.pyi @@ -15,7 +15,9 @@ from pandas.core.series import Series as Series from pandas.core.sorting import get_group_index_sorter as get_group_index_sorter from pandas.errors import AbstractMethodError as AbstractMethodError from pandas.util._decorators import Appender as Appender, Substitution as Substitution, cache_readonly as cache_readonly -from typing import Any, List, Optional, Union, Hashable, Callable, Mapping +from typing import Any, List, Optional, Union, Hashable, Callable, Mapping, TypeVar, Tuple, overload + +PipeReturn = TypeVar('PipeReturn') class GroupByPlot(PandasObject): @@ -53,7 +55,10 @@ class _GroupBy(PandasObject, SelectionMixin): @property def indices(self) -> Any: ... def __getattr__(self, attr: str) -> Any: ... - def pipe(self, func: Any, *args: Any, **kwargs: Any) -> Any: ... + @overload + def pipe(self: Any, func: Union[Callable[..., PipeReturn], Tuple[Callable[..., PipeReturn], str]], *args: Any, **kwargs: Any) -> PipeReturn: ... + @overload + def pipe(self: Any, func: PipeReturn, *args: Any, **kwargs: Any) -> PipeReturn: ... plot: Any = ... def get_group(self, name: Any, obj: Optional[Any] = ...) -> Any: ... def __iter__(self) -> Any: ... diff --git a/third_party/3/pandas/core/resample.pyi b/third_party/3/pandas/core/resample.pyi index f3ead7e..c34cd2c 100644 --- a/third_party/3/pandas/core/resample.pyi +++ b/third_party/3/pandas/core/resample.pyi @@ -5,7 +5,9 @@ from pandas.core.groupby.grouper import Grouper as Grouper from pandas.core.indexes.datetimes import DatetimeIndex as DatetimeIndex from pandas.core.indexes.period import PeriodIndex as PeriodIndex from pandas.core.indexes.timedeltas import TimedeltaIndex as TimedeltaIndex -from typing import Any, Optional +from typing import Any, Optional, TypeVar, Union, Callable, Tuple, overload + +PipeReturn = TypeVar('PipeReturn') class Resampler(_GroupBy, ShallowMixin): groupby: Any = ... @@ -26,7 +28,10 @@ class Resampler(_GroupBy, ShallowMixin): def obj(self) -> Any: ... @property def ax(self) -> Any: ... - def pipe(self, func: Any, *args: Any, **kwargs: Any) -> Any: ... + @overload + def pipe(self: Any, func: Union[Callable[..., PipeReturn], Tuple[Callable[..., PipeReturn], str]], *args: Any, **kwargs: Any) -> PipeReturn: ... + @overload + def pipe(self: Any, func: PipeReturn, *args: Any, **kwargs: Any) -> PipeReturn: ... def aggregate(self, func: Any, *args: Any, **kwargs: Any) -> Any: ... agg: Any = ... apply: Any = ... diff --git a/third_party/3/pandas/io/formats/style.pyi b/third_party/3/pandas/io/formats/style.pyi index 50afd4a..b05700a 100644 --- a/third_party/3/pandas/io/formats/style.pyi +++ b/third_party/3/pandas/io/formats/style.pyi @@ -1,10 +1,11 @@ from pandas._config import get_option as get_option from pandas.util._decorators import Appender as Appender -from typing import Any, Optional +from typing import Any, Optional, Callable, TypeVar, Tuple, Union, overload jinja2: Any has_mpl: bool no_mpl_message: str +PipeReturn = TypeVar('PipeReturn') class Styler: loader: Any = ... @@ -51,4 +52,7 @@ class Styler: def highlight_min(self, subset: Optional[Any] = ..., color: str = ..., axis: int = ...) -> Any: ... @classmethod def from_custom_template(cls, searchpath: Any, name: Any) -> Any: ... - def pipe(self, func: Any, *args: Any, **kwargs: Any) -> Any: ... + @overload + def pipe(self: Any, func: Union[Callable[..., PipeReturn], Tuple[Callable[..., PipeReturn], str]], *args: Any, **kwargs: Any) -> PipeReturn: ... + @overload + def pipe(self: Any, func: PipeReturn, *args: Any, **kwargs: Any) -> PipeReturn: ...