Skip to content

Commit fba80da

Browse files
codewizdaveclaude
andcommitted
Add type utilities: KeyOf, Template, DeepPartial, Partial, Required, Pick, Omit
Add the following type utilities: - KeyOf[T]: Returns all member names as a tuple of Literal types - Template[*Parts]: Template literal string builder - DeepPartial[T]: Make all fields recursively optional - Partial[T]: Make all fields optional (non-recursive) - Required[T]: Remove Optional from all fields - Pick[T, K]: Pick specific fields from a type - Omit[T, K]: Omit specific fields from a type Add comprehensive tests for all new types. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a802a6b commit fba80da

3 files changed

Lines changed: 814 additions & 0 deletions

File tree

packages/typemap/src/typemap/type_eval/_eval_operators.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Attrs,
2424
Bool,
2525
Capitalize,
26+
DeepPartial,
2627
DropAnnotations,
2728
FromUnion,
2829
GenericCallable,
@@ -36,17 +37,23 @@
3637
IsAssignable,
3738
IsEquivalent,
3839
Iter,
40+
KeyOf,
3941
Length,
4042
Lowercase,
4143
Member,
4244
Members,
4345
NewProtocol,
46+
Omit,
4447
Overloaded,
4548
Param,
49+
Partial,
50+
Pick,
4651
RaiseError,
52+
Required,
4753
Slice,
4854
SpecialFormEllipsis,
4955
StrConcat,
56+
Template,
5057
Uncapitalize,
5158
UpdateClass,
5259
Uppercase,
@@ -1282,3 +1289,279 @@ def _eval_NewProtocol(*etyps: Member, ctx):
12821289
cls.__init__ = dct["__init__"]
12831290

12841291
return cls
1292+
1293+
1294+
@type_eval.register_evaluator(KeyOf)
1295+
@_lift_over_unions
1296+
def _eval_KeyOf(tp, *, ctx):
1297+
"""Evaluate KeyOf[T] to get all member names as a tuple of Literals."""
1298+
tp = _eval_types(tp, ctx)
1299+
hints = get_annotated_type_hints(
1300+
tp, include_extras=True, attrs_only=True, ctx=ctx
1301+
)
1302+
1303+
if not hints:
1304+
return typing.Literal[()]
1305+
1306+
# Extract member names and create tuple of Literal types
1307+
names = []
1308+
for name in hints:
1309+
names.append(typing.Literal[name])
1310+
1311+
# Return as tuple of Literal types (use unpacking to make it hashable)
1312+
return tuple[*names] # type: ignore[return-value]
1313+
1314+
1315+
@type_eval.register_evaluator(Template)
1316+
def _eval_Template(*parts, ctx):
1317+
"""Evaluate Template to concatenate all string parts."""
1318+
evaluated_parts = []
1319+
for part in parts:
1320+
evaled = _eval_types(part, ctx)
1321+
if _typing_inspect.is_generic_alias(evaled):
1322+
if evaled.__origin__ is typing.Literal:
1323+
# Extract literal string value
1324+
lit_val = evaled.__args__[0]
1325+
if isinstance(lit_val, str):
1326+
evaluated_parts.append(lit_val)
1327+
else:
1328+
raise TypeError(
1329+
f"Template parts must be string literals, got {lit_val}"
1330+
)
1331+
else:
1332+
raise TypeError(
1333+
f"Template parts must be string literals, got {evaled}"
1334+
)
1335+
elif isinstance(evaled, str):
1336+
# Plain string (shouldn't happen but handle it)
1337+
evaluated_parts.append(evaled)
1338+
else:
1339+
raise TypeError(
1340+
f"Template parts must be string literals, got {type(evaled)}"
1341+
)
1342+
1343+
return typing.Literal["".join(evaluated_parts)]
1344+
1345+
1346+
@type_eval.register_evaluator(DeepPartial)
1347+
def _eval_DeepPartial(tp, *, ctx):
1348+
"""Evaluate DeepPartial[T] to create a class with all fields optional."""
1349+
from typing import get_args
1350+
1351+
tp = _eval_types(tp, ctx)
1352+
1353+
# Get attributes using Attrs to get Member objects
1354+
attrs_result = _eval_Attrs(tp, ctx=ctx)
1355+
attrs_args = get_args(attrs_result)
1356+
1357+
if not attrs_args:
1358+
return tp
1359+
1360+
new_annotations = {}
1361+
1362+
for member in attrs_args:
1363+
# Get the member name
1364+
name_result = _eval_types(member.name, ctx)
1365+
name = (
1366+
get_args(name_result)[0]
1367+
if hasattr(name_result, "__args__")
1368+
else name_result
1369+
)
1370+
1371+
# Get the member type
1372+
type_result = _eval_types(member.type, ctx)
1373+
1374+
# Check if this is a complex type (class with its own attributes)
1375+
if isinstance(type_result, type):
1376+
try:
1377+
nested_attrs = _eval_Attrs(type_result, ctx=ctx)
1378+
nested_args = get_args(nested_attrs)
1379+
if nested_args:
1380+
try:
1381+
nested_partial = _eval_DeepPartial(type_result, ctx=ctx)
1382+
new_annotations[name] = nested_partial | None
1383+
except NameError, TypeError:
1384+
new_annotations[name] = type_result | None
1385+
else:
1386+
new_annotations[name] = type_result | None
1387+
except NameError, TypeError, AttributeError:
1388+
new_annotations[name] = type_result | None
1389+
else:
1390+
new_annotations[name] = type_result | None
1391+
1392+
class_name = (
1393+
f"DeepPartial_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}"
1394+
)
1395+
return type(class_name, (), {"__annotations__": new_annotations})
1396+
1397+
1398+
@type_eval.register_evaluator(Partial)
1399+
def _eval_Partial(tp, *, ctx):
1400+
"""Evaluate Partial[T] to create a class with all fields optional (non-recursive)."""
1401+
from typing import get_args
1402+
1403+
tp = _eval_types(tp, ctx)
1404+
1405+
# Get attributes using Attrs
1406+
attrs_result = _eval_Attrs(tp, ctx=ctx)
1407+
attrs_args = get_args(attrs_result)
1408+
1409+
if not attrs_args:
1410+
return tp
1411+
1412+
new_annotations = {}
1413+
1414+
for member in attrs_args:
1415+
name_result = _eval_types(member.name, ctx)
1416+
name = (
1417+
get_args(name_result)[0]
1418+
if hasattr(name_result, "__args__")
1419+
else name_result
1420+
)
1421+
1422+
try:
1423+
type_result = _eval_types(member.type, ctx)
1424+
new_annotations[name] = type_result | None
1425+
except NameError, TypeError, AttributeError:
1426+
new_annotations[name] = typing.Any | None
1427+
1428+
class_name = (
1429+
f"Partial_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}"
1430+
)
1431+
return type(class_name, (), {"__annotations__": new_annotations})
1432+
1433+
1434+
@type_eval.register_evaluator(Required)
1435+
def _eval_Required(tp, *, ctx):
1436+
"""Evaluate Required[T] to remove Optional from all fields."""
1437+
from typing import get_args
1438+
1439+
tp = _eval_types(tp, ctx)
1440+
1441+
attrs_result = _eval_Attrs(tp, ctx=ctx)
1442+
attrs_args = get_args(attrs_result)
1443+
1444+
if not attrs_args:
1445+
return tp
1446+
1447+
new_annotations = {}
1448+
1449+
for member in attrs_args:
1450+
name_result = _eval_types(member.name, ctx)
1451+
name = (
1452+
get_args(name_result)[0]
1453+
if hasattr(name_result, "__args__")
1454+
else name_result
1455+
)
1456+
1457+
type_result = _eval_types(member.type, ctx)
1458+
1459+
# Remove None from union types
1460+
if isinstance(type_result, types.UnionType):
1461+
non_none_args = [
1462+
arg for arg in type_result.__args__ if arg is not type(None)
1463+
]
1464+
if len(non_none_args) == 1:
1465+
new_annotations[name] = non_none_args[0]
1466+
elif len(non_none_args) > 1:
1467+
new_annotations[name] = types.UnionType(*non_none_args)
1468+
else:
1469+
new_annotations[name] = type_result
1470+
elif (
1471+
hasattr(type_result, "__origin__")
1472+
and type_result.__origin__ is typing.Union
1473+
):
1474+
non_none_args = [
1475+
arg for arg in get_args(type_result) if arg is not type(None)
1476+
]
1477+
if len(non_none_args) == 1:
1478+
new_annotations[name] = non_none_args[0]
1479+
elif len(non_none_args) > 1:
1480+
new_annotations[name] = typing.Union[*non_none_args]
1481+
else:
1482+
new_annotations[name] = type_result
1483+
else:
1484+
new_annotations[name] = type_result
1485+
1486+
class_name = (
1487+
f"Required_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}"
1488+
)
1489+
return type(class_name, (), {"__annotations__": new_annotations})
1490+
1491+
1492+
@type_eval.register_evaluator(Pick)
1493+
def _eval_Pick(tp, keys, *, ctx):
1494+
"""Evaluate Pick[T, K] to create a class with only specified fields."""
1495+
from typing import get_args
1496+
1497+
tp = _eval_types(tp, ctx)
1498+
keys = _eval_types(keys, ctx)
1499+
1500+
key_names = tuple(get_args(keys)) if hasattr(keys, "__args__") else ()
1501+
1502+
attrs_result = _eval_Attrs(tp, ctx=ctx)
1503+
attrs_args = get_args(attrs_result)
1504+
1505+
if not attrs_args:
1506+
return tp
1507+
1508+
new_annotations = {}
1509+
1510+
for member in attrs_args:
1511+
name_result = _eval_types(member.name, ctx)
1512+
name = (
1513+
get_args(name_result)[0]
1514+
if hasattr(name_result, "__args__")
1515+
else name_result
1516+
)
1517+
1518+
if name in key_names:
1519+
try:
1520+
type_result = _eval_types(member.type, ctx)
1521+
new_annotations[name] = type_result
1522+
except NameError, TypeError, AttributeError:
1523+
new_annotations[name] = typing.Any
1524+
1525+
class_name = (
1526+
f"Pick_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}"
1527+
)
1528+
return type(class_name, (), {"__annotations__": new_annotations})
1529+
1530+
1531+
@type_eval.register_evaluator(Omit)
1532+
def _eval_Omit(tp, keys, *, ctx):
1533+
"""Evaluate Omit[T, K] to create a class excluding specified fields."""
1534+
from typing import get_args
1535+
1536+
tp = _eval_types(tp, ctx)
1537+
keys = _eval_types(keys, ctx)
1538+
1539+
key_names = set(get_args(keys)) if hasattr(keys, "__args__") else set()
1540+
1541+
attrs_result = _eval_Attrs(tp, ctx=ctx)
1542+
attrs_args = get_args(attrs_result)
1543+
1544+
if not attrs_args:
1545+
return tp
1546+
1547+
new_annotations = {}
1548+
1549+
for member in attrs_args:
1550+
name_result = _eval_types(member.name, ctx)
1551+
name = (
1552+
get_args(name_result)[0]
1553+
if hasattr(name_result, "__args__")
1554+
else name_result
1555+
)
1556+
1557+
if name not in key_names:
1558+
try:
1559+
type_result = _eval_types(member.type, ctx)
1560+
new_annotations[name] = type_result
1561+
except NameError, TypeError, AttributeError:
1562+
new_annotations[name] = typing.Any
1563+
1564+
class_name = (
1565+
f"Omit_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}"
1566+
)
1567+
return type(class_name, (), {"__annotations__": new_annotations})

0 commit comments

Comments
 (0)