Skip to content

Commit a69b73a

Browse files
committed
more specific type hints for batched(..., strict=True)
1 parent 82c9b97 commit a69b73a

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
5+
if sys.version_info >= (3, 13):
6+
from itertools import batched
7+
8+
from typing_extensions import assert_type
9+
10+
def check_batched_strict_literal() -> None:
11+
assert_type(batched([0], 1, strict=True), batched[tuple[int]])
12+
assert_type(batched([0, 0], 2, strict=True), batched[tuple[int, int]])
13+
assert_type(batched([0, 0, 0], 3, strict=True), batched[tuple[int, int, int]])
14+
assert_type(batched([0, 0, 0, 0], 4, strict=True), batched[tuple[int, int, int, int]])
15+
assert_type(batched([0, 0, 0, 0, 0], 5, strict=True), batched[tuple[int, int, int, int, int]])
16+
17+
def check_batched_non_strict() -> None:
18+
assert_type(batched([0], 2), batched[tuple[int, ...]])
19+
assert_type(batched([0], 2, strict=False), batched[tuple[int, ...]])
20+
21+
def check_batched_strict_non_literal() -> None:
22+
assert_type(batched([0, 0, 0], (lambda: 3)(), strict=True), batched[tuple[int, ...]])

stdlib/itertools.pyi

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,26 @@ if sys.version_info >= (3, 12):
343343
@disjoint_base
344344
class batched(Generic[_T_co]):
345345
if sys.version_info >= (3, 13):
346-
def __new__(cls, iterable: Iterable[_T_co], n: int, *, strict: bool = False) -> Self: ...
346+
@overload
347+
def __new__(cls, iterable: Iterable[_T], n: Literal[1], *, strict: Literal[True]) -> batched[tuple[_T]]: ...
348+
@overload
349+
def __new__(cls, iterable: Iterable[_T], n: Literal[2], *, strict: Literal[True]) -> batched[tuple[_T, _T]]: ...
350+
@overload
351+
def __new__(
352+
cls, iterable: Iterable[_T], n: Literal[3], *, strict: Literal[True]
353+
) -> batched[tuple[_T, _T, _T]]: ...
354+
@overload
355+
def __new__(
356+
cls, iterable: Iterable[_T], n: Literal[4], *, strict: Literal[True]
357+
) -> batched[tuple[_T, _T, _T, _T]]: ...
358+
@overload
359+
def __new__(
360+
cls, iterable: Iterable[_T], n: Literal[5], *, strict: Literal[True]
361+
) -> batched[tuple[_T, _T, _T, _T, _T]]: ...
362+
@overload
363+
def __new__(cls, iterable: Iterable[_T], n: int, *, strict: bool = False) -> batched[tuple[_T, ...]]: ...
347364
else:
348-
def __new__(cls, iterable: Iterable[_T_co], n: int) -> Self: ...
365+
def __new__(cls, iterable: Iterable[_T], n: int) -> batched[tuple[_T, ...]]: ...
349366

350367
def __iter__(self) -> Self: ...
351-
def __next__(self) -> tuple[_T_co, ...]: ...
368+
def __next__(self) -> _T_co: ...

0 commit comments

Comments
 (0)