Skip to content

Commit e8638cd

Browse files
[builtins] Canonicalized set/frozenset signatures (#15470)
1 parent ceaa3e1 commit e8638cd

File tree

2 files changed

+55
-12
lines changed

2 files changed

+55
-12
lines changed

stdlib/@tests/test_cases/builtins/check_set.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,46 @@ def test_set_difference(x: set[Literal["foo", "bar"]], y: set[str], z: set[int])
1212
assert_type(z - x, set[int])
1313
assert_type(y - z, set[str])
1414
assert_type(z - y, set[int])
15+
16+
17+
def test_set_interface_overlapping_type(s: set[Literal["foo", "bar"]], y: set[str], key: str) -> None:
18+
s.add(key) # type: ignore
19+
s.discard(key)
20+
s.remove(key) # type: ignore
21+
s.difference_update(y)
22+
s.intersection_update(y)
23+
s.symmetric_difference_update(y) # type: ignore
24+
s.update(y) # type: ignore
25+
26+
assert_type(s.difference(y), set[Literal["foo", "bar"]])
27+
assert_type(s.intersection(y), set[Literal["foo", "bar"]])
28+
assert_type(s.isdisjoint(y), bool)
29+
assert_type(s.issubset(y), bool)
30+
assert_type(s.issuperset(y), bool)
31+
assert_type(s.symmetric_difference(y), set[str])
32+
assert_type(s.union(y), set[str])
33+
34+
assert_type(s - y, set[Literal["foo", "bar"]])
35+
assert_type(s & y, set[Literal["foo", "bar"]])
36+
assert_type(s | y, set[str])
37+
assert_type(s ^ y, set[str])
38+
39+
s -= y
40+
s &= y
41+
s |= y # type: ignore
42+
s ^= y # type: ignore
43+
44+
45+
def test_frozenset_interface(s: frozenset[Literal["foo", "bar"]], y: frozenset[str]) -> None:
46+
assert_type(s.difference(y), frozenset[Literal["foo", "bar"]])
47+
assert_type(s.intersection(y), frozenset[Literal["foo", "bar"]])
48+
assert_type(s.isdisjoint(y), bool)
49+
assert_type(s.issubset(y), bool)
50+
assert_type(s.issuperset(y), bool)
51+
assert_type(s.symmetric_difference(y), frozenset[str])
52+
assert_type(s.union(y), frozenset[str])
53+
54+
assert_type(s - y, frozenset[Literal["foo", "bar"]])
55+
assert_type(s & y, frozenset[Literal["foo", "bar"]])
56+
assert_type(s | y, frozenset[str])
57+
assert_type(s ^ y, frozenset[str])

stdlib/builtins.pyi

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,16 +1262,16 @@ class set(MutableSet[_T]):
12621262
def __init__(self, iterable: Iterable[_T], /) -> None: ...
12631263
def add(self, element: _T, /) -> None: ...
12641264
def copy(self) -> set[_T]: ...
1265-
def difference(self, *s: Iterable[Any]) -> set[_T]: ...
1266-
def difference_update(self, *s: Iterable[Any]) -> None: ...
1267-
def discard(self, element: _T, /) -> None: ...
1268-
def intersection(self, *s: Iterable[Any]) -> set[_T]: ...
1269-
def intersection_update(self, *s: Iterable[Any]) -> None: ...
1270-
def isdisjoint(self, s: Iterable[Any], /) -> bool: ...
1271-
def issubset(self, s: Iterable[Any], /) -> bool: ...
1272-
def issuperset(self, s: Iterable[Any], /) -> bool: ...
1265+
def difference(self, *s: Iterable[object]) -> set[_T]: ...
1266+
def difference_update(self, *s: Iterable[object]) -> None: ...
1267+
def discard(self, element: object, /) -> None: ...
1268+
def intersection(self, *s: Iterable[object]) -> set[_T]: ...
1269+
def intersection_update(self, *s: Iterable[object]) -> None: ...
1270+
def isdisjoint(self, s: Iterable[object], /) -> bool: ...
1271+
def issubset(self, s: Iterable[object], /) -> bool: ...
1272+
def issuperset(self, s: Iterable[object], /) -> bool: ...
12731273
def remove(self, element: _T, /) -> None: ...
1274-
def symmetric_difference(self, s: Iterable[_T], /) -> set[_T]: ...
1274+
def symmetric_difference(self, s: Iterable[_S], /) -> set[_T | _S]: ...
12751275
def symmetric_difference_update(self, s: Iterable[_T], /) -> None: ...
12761276
def union(self, *s: Iterable[_S]) -> set[_T | _S]: ...
12771277
def update(self, *s: Iterable[_T]) -> None: ...
@@ -1303,15 +1303,15 @@ class frozenset(AbstractSet[_T_co]):
13031303
def copy(self) -> frozenset[_T_co]: ...
13041304
def difference(self, *s: Iterable[object]) -> frozenset[_T_co]: ...
13051305
def intersection(self, *s: Iterable[object]) -> frozenset[_T_co]: ...
1306-
def isdisjoint(self, s: Iterable[_T_co], /) -> bool: ...
1306+
def isdisjoint(self, s: Iterable[object], /) -> bool: ...
13071307
def issubset(self, s: Iterable[object], /) -> bool: ...
13081308
def issuperset(self, s: Iterable[object], /) -> bool: ...
1309-
def symmetric_difference(self, s: Iterable[_T_co], /) -> frozenset[_T_co]: ...
1309+
def symmetric_difference(self, s: Iterable[_S], /) -> frozenset[_T_co | _S]: ...
13101310
def union(self, *s: Iterable[_S]) -> frozenset[_T_co | _S]: ...
13111311
def __len__(self) -> int: ...
13121312
def __contains__(self, o: object, /) -> bool: ...
13131313
def __iter__(self) -> Iterator[_T_co]: ...
1314-
def __and__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ...
1314+
def __and__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ...
13151315
def __or__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ...
13161316
def __sub__(self, value: AbstractSet[object], /) -> frozenset[_T_co]: ...
13171317
def __xor__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ...

0 commit comments

Comments
 (0)