Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/anthropic/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,26 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]
args = get_args(type_)

if is_union(origin):
# For a discriminated union we can resolve the matching variant up-front from
# the discriminator value. That lets us validate just that one variant instead
# of the whole union, which is materially cheaper -- validating a union forces
# every member to be considered -- while returning an identical object. This is
# the hot path for streaming, where every event is decoded against the
# `RawMessageStreamEvent` union (#1649). When the data doesn't validate against
# the resolved variant we fall through to the existing handling below, so the
# behaviour for invalid / non-discriminated data is unchanged.
variant_type: type | None = None
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type is not None:
try:
return validate_type(type_=cast("type[object]", variant_type), value=value)
except Exception:
pass

try:
return validate_type(type_=cast("type[object]", original_type or type_), value=value)
except Exception:
Expand All @@ -626,13 +646,8 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)
if variant_type is not None:
return construct_type(type_=variant_type, value=value)

# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
Expand Down
117 changes: 117 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,123 @@ class B(BaseModel):
assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator


def test_discriminated_union_fast_path_validates_single_variant(monkeypatch: pytest.MonkeyPatch) -> None:
# valid data for a discriminated union should validate only the matched variant,
# never the whole union (the union decode is materially more expensive). See #1649.
from anthropic import _models

class A(BaseModel):
type: Literal["a"]

data: str

class B(BaseModel):
type: Literal["b"]

data: int

validated: list[object] = []
real_validate = _models.validate_type

def spy(*, type_: Any, value: object) -> object:
validated.append(type_)
return real_validate(type_=type_, value=value)

monkeypatch.setattr(_models, "validate_type", spy)

m = construct_type(
value={"type": "b", "data": 7},
type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
)
assert isinstance(m, B)
assert m.data == 7
# only variant B was validated -- the whole Union[A, B] was never handed to validate_type
assert validated == [B]


def test_discriminated_union_fast_path_matches_full_union(monkeypatch: pytest.MonkeyPatch) -> None:
# the fast path must return an object identical to validating the whole union,
# for both clean data and data that only validates after discriminator selection.
from anthropic import _models

class A(BaseModel):
type: Literal["a"]

data: str

class B(BaseModel):
type: Literal["b"]

data: int

union = cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")])

for value in [{"type": "a", "data": "x"}, {"type": "b", "data": 100}]:
fast = construct_type(value=value, type_=union)

# force the old whole-union path by stubbing the discriminator lookup to None
monkeypatch.setattr(_models, "_build_discriminated_union_meta", lambda **_: None)
full = construct_type(value=value, type_=union)
monkeypatch.undo()

assert type(fast) is type(full)
assert fast == full


def test_discriminated_union_fast_path_falls_back_on_invalid_data(monkeypatch: pytest.MonkeyPatch) -> None:
# when the data does not validate against its discriminated variant we must fall
# through to the existing (unvalidated) construct path -- unchanged behavior.
from anthropic import _models

class A(BaseModel):
type: Literal["a"]

data: str

class B(BaseModel):
type: Literal["b"]

data: int

validated: list[object] = []
real_validate = _models.validate_type

def spy(*, type_: Any, value: object) -> object:
validated.append(type_)
return real_validate(type_=type_, value=value)

monkeypatch.setattr(_models, "validate_type", spy)

m = construct_type(
value={"type": "b", "data": "not-an-int"},
type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]),
)
# invalid int -> variant validation fails, falls back to .construct() keeping the raw value
assert isinstance(m, B)
assert m.data == "not-an-int" # type: ignore[comparison-overlap]
# the fast path attempted variant B first, before any whole-union validation
assert validated[0] is B


def test_non_discriminated_union_unaffected(monkeypatch: pytest.MonkeyPatch) -> None:
# a plain (non-discriminated) union has no fast path; the whole union is validated.
from anthropic import _models

validated: list[object] = []
real_validate = _models.validate_type

def spy(*, type_: Any, value: object) -> object:
validated.append(type_)
return real_validate(type_=type_, value=value)

monkeypatch.setattr(_models, "validate_type", spy)

m = construct_type(value=12, type_=cast(Any, Union[int, str]))
assert m == 12
# validated exactly once, with the union itself (no per-variant fast path)
assert len(validated) == 1


@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
def test_type_alias_type() -> None:
Alias = TypeAliasType("Alias", str) # pyright: ignore
Expand Down