diff --git a/src/anthropic/_models.py b/src/anthropic/_models.py
index dc00516b..0f666a10 100644
--- a/src/anthropic/_models.py
+++ b/src/anthropic/_models.py
@@ -607,6 +607,26 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]
     args = get_args(type_)
 
     if is_union(origin):
+        # For a discriminated union we can resolve the matching variant up-front from
+        # the discriminator value. That lets us validate just that one variant instead
+        # of the whole union, which is materially cheaper -- validating a union forces
+        # every member to be considered -- while returning an identical object. This is
+        # the hot path for streaming, where every event is decoded against the
+        # `RawMessageStreamEvent` union (#1649). When the data doesn't validate against
+        # the resolved variant we fall through to the existing handling below, so the
+        # behaviour for invalid / non-discriminated data is unchanged.
+        variant_type: type | None = None
+        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
+        if discriminator and is_mapping(value):
+            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
+            if variant_value and isinstance(variant_value, str):
+                variant_type = discriminator.mapping.get(variant_value)
+        if variant_type is not None:
+            try:
+                return validate_type(type_=cast("type[object]", variant_type), value=value)
+            except Exception:
+                pass
+
         try:
             return validate_type(type_=cast("type[object]", original_type or type_), value=value)
         except Exception:
@@ -626,13 +646,8 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]
         #
         # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
         # we'd end up constructing `FooType` when it should be `BarType`.
-        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
-        if discriminator and is_mapping(value):
-            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
-            if variant_value and isinstance(variant_value, str):
-                variant_type = discriminator.mapping.get(variant_value)
-                if variant_type:
-                    return construct_type(type_=variant_type, value=value)
+        if variant_type is not None:
+            return construct_type(type_=variant_type, value=value)
 
         # if the data is not valid, use the first variant that doesn't fail while deserializing
         for variant in args:
diff --git a/tests/test_models.py b/tests/test_models.py
index 195f2307..c8c16211 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -834,6 +834,123 @@ class B(BaseModel):
     assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
 
 
+def test_discriminated_union_fast_path_validates_single_variant(monkeypatch: pytest.MonkeyPatch) -> None:
+    # valid data for a discriminated union should validate only the matched variant,
+    # never the whole union (the union decode is materially more expensive). See #1649.
+    from anthropic import _models
+
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: str
+
+    class B(BaseModel):
+        type: Literal["b"]
+
+        data: int
+
+    validated: list[object] = []
+    real_validate = _models.validate_type
+
+    def spy(*, type_: Any, value: object) -> object:
+        validated.append(type_)
+        return real_validate(type_=type_, value=value)
+
+    monkeypatch.setattr(_models, "validate_type", spy)
+
+    m = construct_type(
+        value={"type": "b", "data": 7},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+    )
+    assert isinstance(m, B)
+    assert m.data == 7
+    # only variant B was validated -- the whole Union[A, B] was never handed to validate_type
+    assert validated == [B]
+
+
+def test_discriminated_union_fast_path_matches_full_union(monkeypatch: pytest.MonkeyPatch) -> None:
+    # the fast path must return an object identical to validating the whole union,
+    # for both clean data and data that only validates after lax coercion ("100" -> 100).
+    from anthropic import _models
+
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: str
+
+    class B(BaseModel):
+        type: Literal["b"]
+
+        data: int
+
+    union = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")])
+
+    for value in [{"type": "a", "data": "x"}, {"type": "b", "data": 100}, {"type": "b", "data": "100"}]:
+        fast = construct_type(value=value, type_=union)
+
+        # force the old whole-union path; the scoped context reverts only this one stub on exit
+        with monkeypatch.context() as patched:
+            patched.setattr(_models, "_build_discriminated_union_meta", lambda **_: None)
+            full = construct_type(value=value, type_=union)
+
+        assert type(fast) is type(full)
+        assert fast == full
+
+
+def test_discriminated_union_fast_path_falls_back_on_invalid_data(monkeypatch: pytest.MonkeyPatch) -> None:
+    # when the data does not validate against its discriminated variant we must fall
+    # through to the existing (unvalidated) construct path -- unchanged behavior.
+    from anthropic import _models
+
+    class A(BaseModel):
+        type: Literal["a"]
+
+        data: str
+
+    class B(BaseModel):
+        type: Literal["b"]
+
+        data: int
+
+    validated: list[object] = []
+    real_validate = _models.validate_type
+
+    def spy(*, type_: Any, value: object) -> object:
+        validated.append(type_)
+        return real_validate(type_=type_, value=value)
+
+    monkeypatch.setattr(_models, "validate_type", spy)
+
+    m = construct_type(
+        value={"type": "b", "data": "not-an-int"},
+        type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
+    )
+    # invalid int -> variant validation fails, falls back to .construct() keeping the raw value
+    assert isinstance(m, B)
+    assert m.data == "not-an-int"  # type: ignore[comparison-overlap]
+    # the fast path attempted variant B first, before any whole-union validation
+    assert validated[0] is B
+
+
+def test_non_discriminated_union_unaffected(monkeypatch: pytest.MonkeyPatch) -> None:
+    # a plain (non-discriminated) union has no fast path; the whole union is validated.
+    from anthropic import _models
+
+    validated: list[object] = []
+    real_validate = _models.validate_type
+
+    def spy(*, type_: Any, value: object) -> object:
+        validated.append(type_)
+        return real_validate(type_=type_, value=value)
+
+    monkeypatch.setattr(_models, "validate_type", spy)
+
+    m = construct_type(value=12, type_=cast(Any, Union[int, str]))
+    assert m == 12
+    # validated exactly once, with the union itself (no per-variant fast path)
+    assert validated == [Union[int, str]]
+
+
 @pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
 def test_type_alias_type() -> None:
     Alias = TypeAliasType("Alias", str)  # pyright: ignore