Skip to content

Commit 67a1748

Browse files
authored
TYP: collections.abc.Callable (#1526)
* Callable * missing cases * one missing case
1 parent e2de228 commit 67a1748

File tree

14 files changed

+120
-67
lines changed

14 files changed

+120
-67
lines changed

pandas-stubs/_typing.pyi

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ from typing import (
2222
SupportsIndex,
2323
TypeAlias,
2424
TypedDict,
25-
Union,
2625
overload,
2726
)
2827

@@ -596,18 +595,25 @@ IndexKeyFunc: TypeAlias = Callable[[Index], Index | AnyArrayLike] | None
596595

597596
# types of `func` kwarg for DataFrame.aggregate and Series.aggregate
598597
# More specific than what is in pandas
599-
# following Union is here to make it ty compliant https://github.com/astral-sh/ty/issues/591
600-
AggFuncTypeBase: TypeAlias = Union[Callable, str, np.ufunc] # noqa: UP007
601-
AggFuncTypeDictSeries: TypeAlias = Mapping[HashableT, AggFuncTypeBase]
598+
AggFuncTypeBase: TypeAlias = Callable[P, Any] | str | np.ufunc
599+
AggFuncTypeDictSeries: TypeAlias = Mapping[HashableT, AggFuncTypeBase[P]]
602600
AggFuncTypeDictFrame: TypeAlias = Mapping[
603-
HashableT, AggFuncTypeBase | list[AggFuncTypeBase]
601+
HashableT, AggFuncTypeBase[P] | Sequence[AggFuncTypeBase[P]]
604602
]
605-
AggFuncTypeSeriesToFrame: TypeAlias = list[AggFuncTypeBase] | AggFuncTypeDictSeries
603+
AggFuncTypeSeriesToFrame: TypeAlias = (
604+
Sequence[AggFuncTypeBase[P]] | AggFuncTypeDictSeries[HashableT, P]
605+
)
606606
AggFuncTypeFrame: TypeAlias = (
607-
AggFuncTypeBase | list[AggFuncTypeBase] | AggFuncTypeDictFrame
607+
AggFuncTypeBase[P]
608+
| Sequence[AggFuncTypeBase[P]]
609+
| AggFuncTypeDictFrame[HashableT, P]
610+
)
611+
AggFuncTypeDict: TypeAlias = (
612+
AggFuncTypeDictSeries[HashableT, P] | AggFuncTypeDictFrame[HashableT, P]
613+
)
614+
AggFuncType: TypeAlias = (
615+
AggFuncTypeBase[P] | Sequence[AggFuncTypeBase[P]] | AggFuncTypeDict[HashableT, P]
608616
)
609-
AggFuncTypeDict: TypeAlias = AggFuncTypeDictSeries | AggFuncTypeDictFrame
610-
AggFuncType: TypeAlias = AggFuncTypeBase | list[AggFuncTypeBase] | AggFuncTypeDict
611617

612618
# Not used in stubs
613619
# AggObjType = Union[
@@ -694,7 +700,9 @@ CompressionOptions: TypeAlias = (
694700

695701
# types in DataFrameFormatter
696702
FormattersType: TypeAlias = (
697-
list[Callable] | tuple[Callable, ...] | Mapping[str | int, Callable]
703+
list[Callable[..., Any]]
704+
| tuple[Callable[..., Any], ...]
705+
| Mapping[str | int, Callable[..., Any]]
698706
)
699707
# ColspaceType = Mapping[Hashable, Union[str, int]] not used in stubs
700708
FloatFormatType: TypeAlias = str | Callable[[float], str] | EngFormatter

pandas-stubs/core/frame.pyi

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
273273
| list[HashableT]
274274
| slice
275275
| _IndexSliceTuple
276-
| Callable,
277-
MaskType | Iterable[HashableT] | IndexType | Callable,
276+
| Callable[..., Any],
277+
MaskType | Iterable[HashableT] | IndexType | Callable[..., Any],
278278
]
279279
),
280280
) -> _T: ...
@@ -1268,7 +1268,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
12681268
def combine(
12691269
self,
12701270
other: DataFrame,
1271-
func: Callable,
1271+
func: Callable[..., Any],
12721272
fill_value: Scalar | None = None,
12731273
overwrite: _bool = True,
12741274
) -> Self: ...
@@ -1278,7 +1278,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
12781278
other: DataFrame | Series,
12791279
join: UpdateJoin = "left",
12801280
overwrite: _bool = True,
1281-
filter_func: Callable | None = ...,
1281+
filter_func: Callable[..., Any] | None = ...,
12821282
errors: IgnoreRaise = "ignore",
12831283
) -> None: ...
12841284
@overload
@@ -1516,21 +1516,21 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
15161516
@overload
15171517
def aggregate( # pyright: ignore[reportOverlappingOverload]
15181518
self,
1519-
func: AggFuncTypeBase | AggFuncTypeDictSeries,
1519+
func: AggFuncTypeBase[...] | AggFuncTypeDictSeries[Any, ...],
15201520
axis: Axis = 0,
15211521
**kwargs: Any,
15221522
) -> Series: ...
15231523
@overload
15241524
def aggregate(
15251525
self,
1526-
func: list[AggFuncTypeBase] | AggFuncTypeDictFrame | None = ...,
1526+
func: list[AggFuncTypeBase[...]] | AggFuncTypeDictFrame[Any, ...] | None = ...,
15271527
axis: Axis = 0,
15281528
**kwargs: Any,
15291529
) -> Self: ...
15301530
agg = aggregate
15311531
def transform(
15321532
self,
1533-
func: AggFuncTypeFrame,
1533+
func: AggFuncTypeFrame[..., Any],
15341534
axis: Axis = 0,
15351535
*args: Any,
15361536
**kwargs: Any,
@@ -1684,7 +1684,10 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
16841684

16851685
# Add spacing between apply() overloads and remaining annotations
16861686
def map(
1687-
self, func: Callable, na_action: Literal["ignore"] | None = None, **kwargs: Any
1687+
self,
1688+
func: Callable[..., Any],
1689+
na_action: Literal["ignore"] | None = None,
1690+
**kwargs: Any,
16881691
) -> Self: ...
16891692
def join(
16901693
self,
@@ -2332,7 +2335,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
23322335
| Callable[[DataFrame], DataFrame]
23332336
| Callable[[Any], _bool]
23342337
),
2335-
other: Scalar | Series | DataFrame | Callable | NAType | None = ...,
2338+
other: Scalar | Series | DataFrame | Callable[..., Any] | NAType | None = ...,
23362339
*,
23372340
inplace: Literal[True],
23382341
axis: Axis | None = ...,
@@ -2349,7 +2352,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
23492352
| Callable[[DataFrame], DataFrame]
23502353
| Callable[[Any], _bool]
23512354
),
2352-
other: Scalar | Series | DataFrame | Callable | NAType | None = ...,
2355+
other: Scalar | Series | DataFrame | Callable[..., Any] | NAType | None = ...,
23532356
*,
23542357
inplace: Literal[False] = False,
23552358
axis: Axis | None = ...,
@@ -2510,8 +2513,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
25102513
def rename_axis(
25112514
self,
25122515
*,
2513-
index: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
2514-
columns: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
2516+
index: (
2517+
_str | Sequence[_str] | dict[_str | int, _str] | Callable[..., Any] | None
2518+
) = ...,
2519+
columns: (
2520+
_str | Sequence[_str] | dict[_str | int, _str] | Callable[..., Any] | None
2521+
) = ...,
25152522
copy: _bool = ...,
25162523
inplace: Literal[True],
25172524
) -> None: ...
@@ -2520,8 +2527,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
25202527
def rename_axis(
25212528
self,
25222529
*,
2523-
index: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
2524-
columns: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
2530+
index: (
2531+
_str | Sequence[_str] | dict[_str | int, _str] | Callable[..., Any] | None
2532+
) = ...,
2533+
columns: (
2534+
_str | Sequence[_str] | dict[_str | int, _str] | Callable[..., Any] | None
2535+
) = ...,
25252536
copy: _bool = ...,
25262537
inplace: Literal[False] = False,
25272538
) -> Self: ...

pandas-stubs/core/groupby/generic.pyi

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
7878
@overload
7979
def aggregate(
8080
self,
81-
func: list[AggFuncTypeBase],
81+
func: list[AggFuncTypeBase[...]],
8282
/,
8383
*args: Any,
8484
engine: WindowingEngine = ...,
@@ -88,7 +88,7 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
8888
@overload
8989
def aggregate(
9090
self,
91-
func: AggFuncTypeBase | None = ...,
91+
func: AggFuncTypeBase[...] | None = ...,
9292
/,
9393
*args: Any,
9494
engine: WindowingEngine = ...,
@@ -109,16 +109,20 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
109109
@overload
110110
def transform(
111111
self,
112-
func: Callable,
113-
*args: Any,
114-
**kwargs: Any,
112+
func: Callable[Concatenate[Series, P], Any],
113+
*args: P.args,
114+
**kwargs: P.kwargs,
115115
) -> Series: ...
116116
@overload
117117
def transform(
118118
self, func: TransformReductionListType, *args: Any, **kwargs: Any
119119
) -> Series: ...
120120
def filter(
121-
self, func: Callable | str, dropna: bool = ..., *args: Any, **kwargs: Any
121+
self,
122+
func: Callable[Concatenate[Series, P], Any] | str,
123+
dropna: bool = ...,
124+
*args: P.args,
125+
**kwargs: P.kwargs,
122126
) -> Series: ...
123127
def nunique(self, dropna: bool = ...) -> Series[int]: ...
124128
# describe delegates to super() method but here it has keyword-only parameters
@@ -257,7 +261,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
257261
@overload
258262
def aggregate(
259263
self,
260-
func: AggFuncTypeFrame | None = ...,
264+
func: AggFuncTypeFrame[..., Any] | None = ...,
261265
*args: Any,
262266
engine: WindowingEngine = ...,
263267
engine_kwargs: WindowingEngineKwargs = ...,
@@ -266,7 +270,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
266270
@overload
267271
def aggregate(
268272
self,
269-
func: AggFuncTypeFrame | None = None,
273+
func: AggFuncTypeFrame[..., Any] | None = None,
270274
/,
271275
**kwargs: Any,
272276
) -> DataFrame: ...
@@ -283,16 +287,20 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
283287
@overload
284288
def transform(
285289
self,
286-
func: Callable,
287-
*args: Any,
288-
**kwargs: Any,
290+
func: Callable[Concatenate[DataFrame, P], Any],
291+
*args: P.args,
292+
**kwargs: P.kwargs,
289293
) -> DataFrame: ...
290294
@overload
291295
def transform(
292296
self, func: TransformReductionListType, *args: Any, **kwargs: Any
293297
) -> DataFrame: ...
294298
def filter(
295-
self, func: Callable, dropna: bool = ..., *args: Any, **kwargs: Any
299+
self,
300+
func: Callable[Concatenate[DataFrame, P], Any],
301+
dropna: bool = ...,
302+
*args: P.args,
303+
**kwargs: P.kwargs,
296304
) -> DataFrame: ...
297305
@overload
298306
def __getitem__(self, key: Scalar) -> SeriesGroupBy[Any, ByT]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ _ResamplerGroupBy: TypeAlias = (
7878

7979
class GroupBy(BaseGroupBy[NDFrameT]):
8080
def __getattr__(self, attr: str) -> Any: ...
81-
def apply(self, func: Callable | str, *args: Any, **kwargs: Any) -> NDFrameT: ...
81+
def apply(
82+
self,
83+
func: Callable[Concatenate[NDFrameT, P], Any] | str,
84+
*args: P.args,
85+
**kwargs: P.kwargs,
86+
) -> NDFrameT: ...
8287
@final
8388
@overload
8489
def any(self: GroupBy[Series], skipna: bool = ...) -> Series[bool]: ...

pandas-stubs/core/indexes/base.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,10 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
408408
) -> Self: ...
409409
def copy(self, name: Hashable = ..., deep: bool = False) -> Self: ...
410410
def format(
411-
self, name: bool = ..., formatter: Callable | None = ..., na_rep: _str = ...
411+
self,
412+
name: bool = ...,
413+
formatter: Callable[..., Any] | None = ...,
414+
na_rep: _str = ...,
412415
) -> list[_str]: ...
413416
def to_series(
414417
self, index: Index | None = None, name: Hashable | None = None

pandas-stubs/core/indexes/multi.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class MultiIndex(Index):
9494
def format(
9595
self,
9696
name: bool | None = ...,
97-
formatter: Callable | None = ...,
97+
formatter: Callable[..., Any] | None = ...,
9898
na_rep: str | None = ...,
9999
names: bool = ...,
100100
space: int = ...,

0 commit comments

Comments
 (0)