Add CallableDiscriminator and Tag #7983

Merged
merged 9 commits into from Nov 1, 2023
141 changes: 135 additions & 6 deletions docs/api/standard_library_types.md
@@ -871,18 +871,26 @@ print(type(User(id='123', age='45').id))

### Discriminated Unions (a.k.a. Tagged Unions)

When `Union` is used with multiple submodels, you sometimes know exactly which submodel needs to
be checked and validated and want to enforce this.
To do that you can set the same field - let's call it `my_discriminator` - in each of the submodels
with a discriminated value, which is one (or many) `Literal` value(s).
For your `Union`, you can set the discriminator in its value: `Field(discriminator='my_discriminator')`.
We can use discriminated unions to more efficiently validate `Union` types.
Discriminated unions can be used to validate `Union` types with multiple models, or combinations of
models and primitive types.

Setting a discriminated union has many benefits:

- validation is faster since it is only attempted against one model
- only one explicit error is raised in case of failure
- the generated JSON schema implements the [associated OpenAPI specification](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#discriminator-object)
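
The first two benefits can be illustrated with a small, library-free sketch. This is not how pydantic implements discriminated unions; the validator functions, the `pet_type` tag, and the `VALIDATORS_BY_TAG` table are hypothetical names chosen for illustration only. It contrasts trying every union member in turn with looking up a single validator by tag.

```python
from typing import Any, Callable, Dict

Validator = Callable[[Dict[str, Any]], Dict[str, Any]]


def validate_cat(data: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical validator for the "cat" member of the union.
    if not isinstance(data.get('lives'), int):
        raise ValueError('cat: `lives` must be an int')
    return {'pet_type': 'cat', 'lives': data['lives']}


def validate_dog(data: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical validator for the "dog" member of the union.
    if not isinstance(data.get('breed'), str):
        raise ValueError('dog: `breed` must be a str')
    return {'pet_type': 'dog', 'breed': data['breed']}


VALIDATORS_BY_TAG: Dict[str, Validator] = {'cat': validate_cat, 'dog': validate_dog}


def validate_untagged(data: Dict[str, Any]) -> Dict[str, Any]:
    """Plain union: try every member and collect one error per member."""
    errors = []
    for validator in VALIDATORS_BY_TAG.values():
        try:
            return validator(data)
        except ValueError as exc:
            errors.append(str(exc))
    raise ValueError('; '.join(errors))


def validate_tagged(data: Dict[str, Any]) -> Dict[str, Any]:
    """Discriminated union: read the tag, then run exactly one validator."""
    tag = data.get('pet_type')
    if tag not in VALIDATORS_BY_TAG:
        raise ValueError(f'unknown tag {tag!r}')
    return VALIDATORS_BY_TAG[tag](data)


print(validate_tagged({'pet_type': 'cat', 'lives': 9}))
#> {'pet_type': 'cat', 'lives': 9}
```

With invalid input such as `{'pet_type': 'dog', 'breed': 42}`, the untagged version reports a failure for every member, while the tagged version reports only the `dog` failure.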

#### Discriminated Unions with `str` discriminators
Frequently, in the case of a `Union` with multiple models,
there is a common field to all members of the union that can be used to distinguish
which union case the data should be validated against; this is referred to as the "discriminator" in
[OpenAPI](https://swagger.io/docs/specification/data-models/inheritance-and-polymorphism/).
To validate models based on that information you can set the same field - let's call it `my_discriminator` -
in each of the models with a discriminated value, which is one (or many) `Literal` value(s).
For your `Union`, you can set the discriminator in its value: `Field(discriminator='my_discriminator')`.


```py requires="3.8"
from typing import Literal, Union

@@ -922,9 +930,130 @@ except ValidationError as e:
"""
```

#### Discriminated Unions with `CallableDiscriminator` discriminators
In the case of a `Union` with multiple models, sometimes there isn't a single uniform field
across all models that you can use as a discriminator. This is the perfect use case for the `CallableDiscriminator` approach.

```py requires="3.8"
from typing import Any, Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


class Pie(BaseModel):
    time_to_cook: int
    num_ingredients: int


class ApplePie(Pie):
    fruit: Literal['apple'] = 'apple'


class PumpkinPie(Pie):
    filling: Literal['pumpkin'] = 'pumpkin'


def get_discriminator_value(v: Any) -> str:
    if isinstance(v, dict):
        return v.get('fruit', v.get('filling'))
    return getattr(v, 'fruit', getattr(v, 'filling', None))


class ThanksgivingDinner(BaseModel):
    dessert: Annotated[
        Union[
            Annotated[ApplePie, Tag('apple')],
            Annotated[PumpkinPie, Tag('pumpkin')],
        ],
        CallableDiscriminator(get_discriminator_value),
    ]


apple_variation = ThanksgivingDinner.model_validate(
    {'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}}
)
print(repr(apple_variation))
"""
ThanksgivingDinner(dessert=ApplePie(time_to_cook=60, num_ingredients=8, fruit='apple'))
"""

pumpkin_variation = ThanksgivingDinner.model_validate(
    {
        'dessert': {
            'filling': 'pumpkin',
            'time_to_cook': 40,
            'num_ingredients': 6,
        }
    }
)
print(repr(pumpkin_variation))
"""
ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin'))
"""
```

A `CallableDiscriminator` can also be used to validate `Union` types with combinations of models and primitive types.

For example:

```py requires="3.8"
from typing import Any, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


def model_x_discriminator(v: Any) -> str:
    if isinstance(v, str):
        return 'str'
    if isinstance(v, (dict, BaseModel)):
        return 'model'


class DiscriminatedModel(BaseModel):
    x: Annotated[
        Union[
            Annotated[str, Tag('str')],
            Annotated['DiscriminatedModel', Tag('model')],
        ],
        CallableDiscriminator(
            model_x_discriminator,
            custom_error_type='invalid_union_member',
            custom_error_message='Invalid union member',
            custom_error_context={'discriminator': 'str_or_model'},
        ),
    ]


data = {'x': {'x': {'x': 'a'}}}
m = DiscriminatedModel.model_validate(data)
assert m == (
    DiscriminatedModel(x=DiscriminatedModel(x=DiscriminatedModel(x='a')))
)
assert m.model_dump() == data
```

!!! note
    Using the [`typing.Annotated` fields syntax](../concepts/json_schema.md#typingannotated-fields) can be handy to regroup
    the `Union` and `discriminator` information. See below for an example!
    the `Union` and `discriminator` information. See the next example for more details.

There are a few ways to set a discriminator for a field, all varying slightly in syntax.

For `str` discriminators:
```
some_field: Union[...] = Field(discriminator='my_discriminator')
some_field: Annotated[Union[...], Field(discriminator='my_discriminator')]
```

For `CallableDiscriminator` discriminators:
```
some_field: Union[...] = Field(discriminator=CallableDiscriminator(...))
some_field: Annotated[Union[...], CallableDiscriminator(...)]
some_field: Annotated[Union[...], Field(discriminator=CallableDiscriminator(...))]
```

!!! warning
    Discriminated unions cannot be used with only a single variant, such as `Union[Cat]`.
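
As background for this restriction, it may help to recall that Python's `typing` module collapses a single-member `Union` into the member itself, so no union remains to discriminate on. The sketch below uses only the standard library; `Cat` is a hypothetical placeholder class, not a pydantic model.

```python
from typing import Union


class Cat:  # hypothetical placeholder, not a pydantic model
    pass


# A one-member Union is simplified to the member type itself.
print(Union[Cat] is Cat)
#> True
```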
47 changes: 45 additions & 2 deletions docs/concepts/fields.md
@@ -597,7 +597,10 @@ print(user)
## Discriminator

The parameter `discriminator` can be used to control the field that will be used to discriminate between different
models in a union.
models in a union. It takes either the name of a field or a `CallableDiscriminator` instance. The `CallableDiscriminator`
approach can be useful when the discriminator fields aren't the same for all of the models in the `Union`.

The following example shows how to use `discriminator` with a field name:

```py requires="3.8"
from typing import Literal, Union
@@ -625,7 +628,47 @@ print(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})) # (1)!

1. See more about [Helper Functions] in the [Models] page.

See the [Discriminated Unions] for more details.
The following example shows how to use `discriminator` with a `CallableDiscriminator` instance:

```py requires="3.8"
from typing import Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Field, Tag


class Cat(BaseModel):
    pet_type: Literal['cat']
    age: int


class Dog(BaseModel):
    pet_kind: Literal['dog']
    age: int


def pet_discriminator(v):
    if isinstance(v, dict):
        return v.get('pet_type', v.get('pet_kind'))
    return getattr(v, 'pet_type', getattr(v, 'pet_kind', None))


class Model(BaseModel):
    pet: Union[Annotated[Cat, Tag('cat')], Annotated[Dog, Tag('dog')]] = Field(
        discriminator=CallableDiscriminator(pet_discriminator)
    )


print(repr(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})))
#> Model(pet=Cat(pet_type='cat', age=12))

print(repr(Model.model_validate({'pet': {'pet_kind': 'dog', 'age': 12}})))
#> Model(pet=Dog(pet_kind='dog', age=12))
```

You can also take advantage of `Annotated` to define your discriminated unions.
See the [Discriminated Unions] docs for more details.

## Strict Mode

4 changes: 4 additions & 0 deletions docs/concepts/performance.md
@@ -132,6 +132,8 @@ class Html(BaseModel):
)
```

See [Discriminated Unions] for more details.

## Use `Literal` not `Enum`

Instead of using `Enum`, use `Literal` to define the structure of the data.
@@ -207,3 +209,5 @@ Instead of using nested models, use `TypedDict` to define the structure of the d
## Avoid wrap validators if you really care about performance

<!-- TODO: I need help on this one. -->

[Discriminated Unions]: ../api/standard_library_types.md#discriminated-unions-aka-tagged-unions
3 changes: 3 additions & 0 deletions docs/concepts/validators.md
@@ -392,6 +392,9 @@ However, in most cases where you want to perform validation using multiple field

## Model validators

??? api "API Documentation"
    [`pydantic.functional_validators.model_validator`][pydantic.functional_validators.model_validator]<br>

Validation can also be performed on the entire model's data using `@model_validator`.

```py
56 changes: 56 additions & 0 deletions docs/errors/usage_errors.md
@@ -365,6 +365,62 @@ class Model(BaseModel):
assert Model(pet={'pet_type': 'kitten'}).pet.pet_type == 'cat'
```

## Callable discriminator case with no tag {#callable-discriminator-no-tag}

This error is raised when a `Union` that uses a `CallableDiscriminator` doesn't have `Tag` annotations for all cases.

```py
from typing import Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, PydanticUserError, Tag


def model_x_discriminator(v):
    if isinstance(v, str):
        return 'str'
    if isinstance(v, (dict, BaseModel)):
        return 'model'


# tag missing for both union choices
try:

    class DiscriminatedModel(BaseModel):
        x: Annotated[
            Union[str, 'DiscriminatedModel'],
            CallableDiscriminator(model_x_discriminator),
        ]

except PydanticUserError as exc_info:
    assert exc_info.code == 'callable-discriminator-no-tag'

# tag missing for `'DiscriminatedModel'` union choice
try:

    class DiscriminatedModel(BaseModel):
        x: Annotated[
            Union[Annotated[str, Tag('str')], 'DiscriminatedModel'],
            CallableDiscriminator(model_x_discriminator),
        ]

except PydanticUserError as exc_info:
    assert exc_info.code == 'callable-discriminator-no-tag'

# tag missing for `str` union choice
try:

    class DiscriminatedModel(BaseModel):
        x: Annotated[
            Union[str, Annotated['DiscriminatedModel', Tag('model')]],
            CallableDiscriminator(model_x_discriminator),
        ]

except PydanticUserError as exc_info:
    assert exc_info.code == 'callable-discriminator-no-tag'
```
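
The rule that every union member needs a tag can be sketched with standard-library type introspection. This is an illustrative check under stated assumptions, not pydantic's actual implementation: the `Tag` dataclass here is a hypothetical stand-in for `pydantic.Tag`, `untagged_members` is an invented helper name, and the code assumes Python 3.9+ so that `Annotated` can be imported from `typing`.

```python
from dataclasses import dataclass
from typing import Annotated, Union, get_args, get_origin


@dataclass(frozen=True)
class Tag:  # hypothetical stand-in for pydantic's Tag
    tag: str


def untagged_members(union_type):
    """Return the union members that carry no Tag annotation."""
    missing = []
    for member in get_args(union_type):
        # For Annotated[T, *metadata], get_args returns (T, *metadata).
        metadata = get_args(member)[1:] if get_origin(member) is Annotated else ()
        if not any(isinstance(item, Tag) for item in metadata):
            missing.append(member)
    return missing


print(untagged_members(Union[Annotated[str, Tag('str')], int]))
#> [<class 'int'>]
```

An empty result would mean every member is tagged, which is the condition the error above enforces.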


## `TypedDict` version {#typed-dict-version}

4 changes: 4 additions & 0 deletions pydantic/__init__.py
@@ -181,6 +181,8 @@
    'Base64UrlBytes',
    'Base64UrlStr',
    'GetPydanticSchema',
    'Tag',
    'CallableDiscriminator',
    # type_adapter
    'TypeAdapter',
    # version
@@ -320,6 +322,8 @@
    'Base64UrlBytes': (__package__, '.types'),
    'Base64UrlStr': (__package__, '.types'),
    'GetPydanticSchema': (__package__, '.types'),
    'Tag': (__package__, '.types'),
    'CallableDiscriminator': (__package__, '.types'),
    # type_adapter
    'TypeAdapter': (__package__, '.type_adapter'),
    # warnings
4 changes: 4 additions & 0 deletions pydantic/_internal/_core_utils.py
@@ -48,6 +48,10 @@
schema building because one of its members refers to a definition that was not yet defined when the union
was first encountered.
"""
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
"""
Used in a `Tag` schema to specify the tag used for a discriminated union.
"""
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
schema was first encountered.
14 changes: 12 additions & 2 deletions pydantic/_internal/_discriminated_union.py
@@ -1,6 +1,6 @@
from __future__ import annotations as _annotations

from typing import Any, Hashable, Sequence
from typing import TYPE_CHECKING, Any, Hashable, Sequence

from pydantic_core import CoreSchema, core_schema

@@ -13,6 +13,9 @@
    simplify_schema_references,
)

if TYPE_CHECKING:
    from ..types import CallableDiscriminator

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'


@@ -58,7 +61,9 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem


def apply_discriminator(
    schema: core_schema.CoreSchema, discriminator: str, definitions: dict[str, core_schema.CoreSchema] | None = None
    schema: core_schema.CoreSchema,
    discriminator: str | CallableDiscriminator,
    definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
    """Applies the discriminator and returns a new core schema.

@@ -83,6 +88,11 @@ def apply_discriminator(
            - If discriminator fields have different aliases.
            - If discriminator field not of type `Literal`.
    """
    from ..types import CallableDiscriminator

    if isinstance(discriminator, CallableDiscriminator):
        return discriminator._convert_schema(schema)

    return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)

