|
23 | 23 | Attrs, |
24 | 24 | Bool, |
25 | 25 | Capitalize, |
| 26 | + DeepPartial, |
26 | 27 | DropAnnotations, |
27 | 28 | FromUnion, |
28 | 29 | GenericCallable, |
|
36 | 37 | IsAssignable, |
37 | 38 | IsEquivalent, |
38 | 39 | Iter, |
| 40 | + KeyOf, |
39 | 41 | Length, |
40 | 42 | Lowercase, |
41 | 43 | Member, |
42 | 44 | Members, |
43 | 45 | NewProtocol, |
| 46 | + Omit, |
44 | 47 | Overloaded, |
45 | 48 | Param, |
| 49 | + Partial, |
| 50 | + Pick, |
46 | 51 | RaiseError, |
| 52 | + Required, |
47 | 53 | Slice, |
48 | 54 | SpecialFormEllipsis, |
49 | 55 | StrConcat, |
| 56 | + Template, |
50 | 57 | Uncapitalize, |
51 | 58 | UpdateClass, |
52 | 59 | Uppercase, |
@@ -1282,3 +1289,279 @@ def _eval_NewProtocol(*etyps: Member, ctx): |
1282 | 1289 | cls.__init__ = dct["__init__"] |
1283 | 1290 |
|
1284 | 1291 | return cls |
| 1292 | + |
| 1293 | + |
| 1294 | +@type_eval.register_evaluator(KeyOf) |
| 1295 | +@_lift_over_unions |
| 1296 | +def _eval_KeyOf(tp, *, ctx): |
| 1297 | + """Evaluate KeyOf[T] to get all member names as a tuple of Literals.""" |
| 1298 | + tp = _eval_types(tp, ctx) |
| 1299 | + hints = get_annotated_type_hints( |
| 1300 | + tp, include_extras=True, attrs_only=True, ctx=ctx |
| 1301 | + ) |
| 1302 | + |
| 1303 | + if not hints: |
| 1304 | + return typing.Literal[()] |
| 1305 | + |
| 1306 | + # Extract member names and create tuple of Literal types |
| 1307 | + names = [] |
| 1308 | + for name in hints: |
| 1309 | + names.append(typing.Literal[name]) |
| 1310 | + |
| 1311 | + # Return as tuple of Literal types (use unpacking to make it hashable) |
| 1312 | + return tuple[*names] # type: ignore[return-value] |
| 1313 | + |
| 1314 | + |
| 1315 | +@type_eval.register_evaluator(Template) |
| 1316 | +def _eval_Template(*parts, ctx): |
| 1317 | + """Evaluate Template to concatenate all string parts.""" |
| 1318 | + evaluated_parts = [] |
| 1319 | + for part in parts: |
| 1320 | + evaled = _eval_types(part, ctx) |
| 1321 | + if _typing_inspect.is_generic_alias(evaled): |
| 1322 | + if evaled.__origin__ is typing.Literal: |
| 1323 | + # Extract literal string value |
| 1324 | + lit_val = evaled.__args__[0] |
| 1325 | + if isinstance(lit_val, str): |
| 1326 | + evaluated_parts.append(lit_val) |
| 1327 | + else: |
| 1328 | + raise TypeError( |
| 1329 | + f"Template parts must be string literals, got {lit_val}" |
| 1330 | + ) |
| 1331 | + else: |
| 1332 | + raise TypeError( |
| 1333 | + f"Template parts must be string literals, got {evaled}" |
| 1334 | + ) |
| 1335 | + elif isinstance(evaled, str): |
| 1336 | + # Plain string (shouldn't happen but handle it) |
| 1337 | + evaluated_parts.append(evaled) |
| 1338 | + else: |
| 1339 | + raise TypeError( |
| 1340 | + f"Template parts must be string literals, got {type(evaled)}" |
| 1341 | + ) |
| 1342 | + |
| 1343 | + return typing.Literal["".join(evaluated_parts)] |
| 1344 | + |
| 1345 | + |
| 1346 | +@type_eval.register_evaluator(DeepPartial) |
| 1347 | +def _eval_DeepPartial(tp, *, ctx): |
| 1348 | + """Evaluate DeepPartial[T] to create a class with all fields optional.""" |
| 1349 | + from typing import get_args |
| 1350 | + |
| 1351 | + tp = _eval_types(tp, ctx) |
| 1352 | + |
| 1353 | + # Get attributes using Attrs to get Member objects |
| 1354 | + attrs_result = _eval_Attrs(tp, ctx=ctx) |
| 1355 | + attrs_args = get_args(attrs_result) |
| 1356 | + |
| 1357 | + if not attrs_args: |
| 1358 | + return tp |
| 1359 | + |
| 1360 | + new_annotations = {} |
| 1361 | + |
| 1362 | + for member in attrs_args: |
| 1363 | + # Get the member name |
| 1364 | + name_result = _eval_types(member.name, ctx) |
| 1365 | + name = ( |
| 1366 | + get_args(name_result)[0] |
| 1367 | + if hasattr(name_result, "__args__") |
| 1368 | + else name_result |
| 1369 | + ) |
| 1370 | + |
| 1371 | + # Get the member type |
| 1372 | + type_result = _eval_types(member.type, ctx) |
| 1373 | + |
| 1374 | + # Check if this is a complex type (class with its own attributes) |
| 1375 | + if isinstance(type_result, type): |
| 1376 | + try: |
| 1377 | + nested_attrs = _eval_Attrs(type_result, ctx=ctx) |
| 1378 | + nested_args = get_args(nested_attrs) |
| 1379 | + if nested_args: |
| 1380 | + try: |
| 1381 | + nested_partial = _eval_DeepPartial(type_result, ctx=ctx) |
| 1382 | + new_annotations[name] = nested_partial | None |
| 1383 | + except NameError, TypeError: |
| 1384 | + new_annotations[name] = type_result | None |
| 1385 | + else: |
| 1386 | + new_annotations[name] = type_result | None |
| 1387 | + except NameError, TypeError, AttributeError: |
| 1388 | + new_annotations[name] = type_result | None |
| 1389 | + else: |
| 1390 | + new_annotations[name] = type_result | None |
| 1391 | + |
| 1392 | + class_name = ( |
| 1393 | + f"DeepPartial_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}" |
| 1394 | + ) |
| 1395 | + return type(class_name, (), {"__annotations__": new_annotations}) |
| 1396 | + |
| 1397 | + |
| 1398 | +@type_eval.register_evaluator(Partial) |
| 1399 | +def _eval_Partial(tp, *, ctx): |
| 1400 | + """Evaluate Partial[T] to create a class with all fields optional (non-recursive).""" |
| 1401 | + from typing import get_args |
| 1402 | + |
| 1403 | + tp = _eval_types(tp, ctx) |
| 1404 | + |
| 1405 | + # Get attributes using Attrs |
| 1406 | + attrs_result = _eval_Attrs(tp, ctx=ctx) |
| 1407 | + attrs_args = get_args(attrs_result) |
| 1408 | + |
| 1409 | + if not attrs_args: |
| 1410 | + return tp |
| 1411 | + |
| 1412 | + new_annotations = {} |
| 1413 | + |
| 1414 | + for member in attrs_args: |
| 1415 | + name_result = _eval_types(member.name, ctx) |
| 1416 | + name = ( |
| 1417 | + get_args(name_result)[0] |
| 1418 | + if hasattr(name_result, "__args__") |
| 1419 | + else name_result |
| 1420 | + ) |
| 1421 | + |
| 1422 | + try: |
| 1423 | + type_result = _eval_types(member.type, ctx) |
| 1424 | + new_annotations[name] = type_result | None |
| 1425 | + except NameError, TypeError, AttributeError: |
| 1426 | + new_annotations[name] = typing.Any | None |
| 1427 | + |
| 1428 | + class_name = ( |
| 1429 | + f"Partial_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}" |
| 1430 | + ) |
| 1431 | + return type(class_name, (), {"__annotations__": new_annotations}) |
| 1432 | + |
| 1433 | + |
| 1434 | +@type_eval.register_evaluator(Required) |
| 1435 | +def _eval_Required(tp, *, ctx): |
| 1436 | + """Evaluate Required[T] to remove Optional from all fields.""" |
| 1437 | + from typing import get_args |
| 1438 | + |
| 1439 | + tp = _eval_types(tp, ctx) |
| 1440 | + |
| 1441 | + attrs_result = _eval_Attrs(tp, ctx=ctx) |
| 1442 | + attrs_args = get_args(attrs_result) |
| 1443 | + |
| 1444 | + if not attrs_args: |
| 1445 | + return tp |
| 1446 | + |
| 1447 | + new_annotations = {} |
| 1448 | + |
| 1449 | + for member in attrs_args: |
| 1450 | + name_result = _eval_types(member.name, ctx) |
| 1451 | + name = ( |
| 1452 | + get_args(name_result)[0] |
| 1453 | + if hasattr(name_result, "__args__") |
| 1454 | + else name_result |
| 1455 | + ) |
| 1456 | + |
| 1457 | + type_result = _eval_types(member.type, ctx) |
| 1458 | + |
| 1459 | + # Remove None from union types |
| 1460 | + if isinstance(type_result, types.UnionType): |
| 1461 | + non_none_args = [ |
| 1462 | + arg for arg in type_result.__args__ if arg is not type(None) |
| 1463 | + ] |
| 1464 | + if len(non_none_args) == 1: |
| 1465 | + new_annotations[name] = non_none_args[0] |
| 1466 | + elif len(non_none_args) > 1: |
| 1467 | + new_annotations[name] = types.UnionType(*non_none_args) |
| 1468 | + else: |
| 1469 | + new_annotations[name] = type_result |
| 1470 | + elif ( |
| 1471 | + hasattr(type_result, "__origin__") |
| 1472 | + and type_result.__origin__ is typing.Union |
| 1473 | + ): |
| 1474 | + non_none_args = [ |
| 1475 | + arg for arg in get_args(type_result) if arg is not type(None) |
| 1476 | + ] |
| 1477 | + if len(non_none_args) == 1: |
| 1478 | + new_annotations[name] = non_none_args[0] |
| 1479 | + elif len(non_none_args) > 1: |
| 1480 | + new_annotations[name] = typing.Union[*non_none_args] |
| 1481 | + else: |
| 1482 | + new_annotations[name] = type_result |
| 1483 | + else: |
| 1484 | + new_annotations[name] = type_result |
| 1485 | + |
| 1486 | + class_name = ( |
| 1487 | + f"Required_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}" |
| 1488 | + ) |
| 1489 | + return type(class_name, (), {"__annotations__": new_annotations}) |
| 1490 | + |
| 1491 | + |
| 1492 | +@type_eval.register_evaluator(Pick) |
| 1493 | +def _eval_Pick(tp, keys, *, ctx): |
| 1494 | + """Evaluate Pick[T, K] to create a class with only specified fields.""" |
| 1495 | + from typing import get_args |
| 1496 | + |
| 1497 | + tp = _eval_types(tp, ctx) |
| 1498 | + keys = _eval_types(keys, ctx) |
| 1499 | + |
| 1500 | + key_names = tuple(get_args(keys)) if hasattr(keys, "__args__") else () |
| 1501 | + |
| 1502 | + attrs_result = _eval_Attrs(tp, ctx=ctx) |
| 1503 | + attrs_args = get_args(attrs_result) |
| 1504 | + |
| 1505 | + if not attrs_args: |
| 1506 | + return tp |
| 1507 | + |
| 1508 | + new_annotations = {} |
| 1509 | + |
| 1510 | + for member in attrs_args: |
| 1511 | + name_result = _eval_types(member.name, ctx) |
| 1512 | + name = ( |
| 1513 | + get_args(name_result)[0] |
| 1514 | + if hasattr(name_result, "__args__") |
| 1515 | + else name_result |
| 1516 | + ) |
| 1517 | + |
| 1518 | + if name in key_names: |
| 1519 | + try: |
| 1520 | + type_result = _eval_types(member.type, ctx) |
| 1521 | + new_annotations[name] = type_result |
| 1522 | + except NameError, TypeError, AttributeError: |
| 1523 | + new_annotations[name] = typing.Any |
| 1524 | + |
| 1525 | + class_name = ( |
| 1526 | + f"Pick_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}" |
| 1527 | + ) |
| 1528 | + return type(class_name, (), {"__annotations__": new_annotations}) |
| 1529 | + |
| 1530 | + |
| 1531 | +@type_eval.register_evaluator(Omit) |
| 1532 | +def _eval_Omit(tp, keys, *, ctx): |
| 1533 | + """Evaluate Omit[T, K] to create a class excluding specified fields.""" |
| 1534 | + from typing import get_args |
| 1535 | + |
| 1536 | + tp = _eval_types(tp, ctx) |
| 1537 | + keys = _eval_types(keys, ctx) |
| 1538 | + |
| 1539 | + key_names = set(get_args(keys)) if hasattr(keys, "__args__") else set() |
| 1540 | + |
| 1541 | + attrs_result = _eval_Attrs(tp, ctx=ctx) |
| 1542 | + attrs_args = get_args(attrs_result) |
| 1543 | + |
| 1544 | + if not attrs_args: |
| 1545 | + return tp |
| 1546 | + |
| 1547 | + new_annotations = {} |
| 1548 | + |
| 1549 | + for member in attrs_args: |
| 1550 | + name_result = _eval_types(member.name, ctx) |
| 1551 | + name = ( |
| 1552 | + get_args(name_result)[0] |
| 1553 | + if hasattr(name_result, "__args__") |
| 1554 | + else name_result |
| 1555 | + ) |
| 1556 | + |
| 1557 | + if name not in key_names: |
| 1558 | + try: |
| 1559 | + type_result = _eval_types(member.type, ctx) |
| 1560 | + new_annotations[name] = type_result |
| 1561 | + except NameError, TypeError, AttributeError: |
| 1562 | + new_annotations[name] = typing.Any |
| 1563 | + |
| 1564 | + class_name = ( |
| 1565 | + f"Omit_{tp.__name__ if hasattr(tp, '__name__') else 'Anonymous'}" |
| 1566 | + ) |
| 1567 | + return type(class_name, (), {"__annotations__": new_annotations}) |
0 commit comments