Add CallableDiscriminator and Tag (#7983)
Co-authored-by: sydney-runkle <sydneymarierunkle@gmail.com>
Co-authored-by: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com>
3 people committed Nov 1, 2023
1 parent 60c5db6 commit 9868b45
Showing 13 changed files with 839 additions and 19 deletions.
141 changes: 135 additions & 6 deletions docs/api/standard_library_types.md
@@ -871,18 +871,26 @@ print(type(User(id='123', age='45').id))

### Discriminated Unions (a.k.a. Tagged Unions)

When `Union` is used with multiple submodels, you sometimes know exactly which submodel needs to
be checked and validated and want to enforce this.
To do that you can set the same field - let's call it `my_discriminator` - in each of the submodels
with a discriminated value, which is one (or many) `Literal` value(s).
For your `Union`, you can set the discriminator in its value: `Field(discriminator='my_discriminator')`.
We can use discriminated unions to more efficiently validate `Union` types.
Discriminated unions can be used to validate `Union` types with multiple models, or combinations of
models and primitive types.

Setting a discriminated union has many benefits:

- validation is faster since it is only attempted against one model
- only one explicit error is raised in case of failure
- the generated JSON schema implements the [associated OpenAPI specification](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#discriminator-object)

#### Discriminated Unions with `str` discriminators
Frequently, in the case of a `Union` with multiple models,
there is a common field to all members of the union that can be used to distinguish
which union case the data should be validated against; this is referred to as the "discriminator" in
[OpenAPI](https://swagger.io/docs/specification/data-models/inheritance-and-polymorphism/).
To validate models based on that information you can set the same field - let's call it `my_discriminator` -
in each of the models with a discriminated value, which is one (or many) `Literal` value(s).
For your `Union`, you can set the discriminator in its value: `Field(discriminator='my_discriminator')`.


```py requires="3.8"
from typing import Literal, Union

@@ -922,9 +930,130 @@ except ValidationError as e:
"""
```
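
The dispatch idea behind a `str` discriminator can be sketched in plain Python. This is a minimal illustration using only the standard library, not pydantic's implementation: the dataclasses, `PET_MODELS`, and `validate_pet` are hypothetical names introduced here. It shows why only one model is attempted and why a failure yields a single error.

```python
from dataclasses import dataclass
from typing import Any, Dict, Type


@dataclass
class Cat:
    pet_type: str
    age: int


@dataclass
class Dog:
    pet_type: str
    age: int


# Hypothetical lookup table: discriminator value -> class to build.
PET_MODELS: Dict[str, Type[Any]] = {'cat': Cat, 'dog': Dog}


def validate_pet(data: Dict[str, Any], discriminator: str = 'pet_type') -> Any:
    """Read the discriminator field, pick exactly one class, and build it."""
    if discriminator not in data:
        raise ValueError(f'Missing discriminator field {discriminator!r}')
    tag = data[discriminator]
    if tag not in PET_MODELS:
        # One clear error instead of one error per union member.
        raise ValueError(f'Tag {tag!r} does not match any of {sorted(PET_MODELS)}')
    return PET_MODELS[tag](**data)


print(validate_pet({'pet_type': 'cat', 'age': 12}))
#> Cat(pet_type='cat', age=12)
```

Because the tag selects the class up front, the other union members are never tried, which is the source of both the speed and the simpler error reporting described above.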

#### Discriminated Unions with `CallableDiscriminator` discriminators
In the case of a `Union` with multiple models, sometimes there isn't a single uniform field
across all models that you can use as a discriminator. This is the perfect use case for the `CallableDiscriminator` approach.

```py requires="3.8"
from typing import Any, Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


class Pie(BaseModel):
    time_to_cook: int
    num_ingredients: int


class ApplePie(Pie):
    fruit: Literal['apple'] = 'apple'


class PumpkinPie(Pie):
    filling: Literal['pumpkin'] = 'pumpkin'


def get_discriminator_value(v: Any) -> str:
    if isinstance(v, dict):
        return v.get('fruit', v.get('filling'))
    return getattr(v, 'fruit', getattr(v, 'filling', None))


class ThanksgivingDinner(BaseModel):
    dessert: Annotated[
        Union[
            Annotated[ApplePie, Tag('apple')],
            Annotated[PumpkinPie, Tag('pumpkin')],
        ],
        CallableDiscriminator(get_discriminator_value),
    ]


apple_variation = ThanksgivingDinner.model_validate(
    {'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}}
)
print(repr(apple_variation))
"""
ThanksgivingDinner(dessert=ApplePie(time_to_cook=60, num_ingredients=8, fruit='apple'))
"""

pumpkin_variation = ThanksgivingDinner.model_validate(
    {
        'dessert': {
            'filling': 'pumpkin',
            'time_to_cook': 40,
            'num_ingredients': 6,
        }
    }
)
print(repr(pumpkin_variation))
"""
ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin'))
"""
"""
```
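
To make the mechanism concrete, here is a rough standard-library sketch of what a callable discriminator does conceptually: a function computes a tag from the raw input, and the tag selects which union member to validate against. It is an illustration under assumptions, not pydantic's code; `TAGGED_CHOICES` and `validate_tagged` are hypothetical names, and plain dataclasses stand in for the models.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Type


@dataclass
class ApplePie:
    time_to_cook: int
    num_ingredients: int
    fruit: str = 'apple'


@dataclass
class PumpkinPie:
    time_to_cook: int
    num_ingredients: int
    filling: str = 'pumpkin'


def get_discriminator_value(v: Any) -> Optional[str]:
    """Return the tag for dict input or for an already-built instance."""
    if isinstance(v, dict):
        return v.get('fruit', v.get('filling'))
    return getattr(v, 'fruit', getattr(v, 'filling', None))


# Hypothetical equivalent of the `Tag(...)` annotations: tag -> union member.
TAGGED_CHOICES: Dict[str, Type[Any]] = {'apple': ApplePie, 'pumpkin': PumpkinPie}


def validate_tagged(
    v: Any,
    discriminator: Callable[[Any], Optional[str]] = get_discriminator_value,
    choices: Dict[str, Type[Any]] = TAGGED_CHOICES,
) -> Any:
    """Call the discriminator, then validate against the one matching choice."""
    tag = discriminator(v)
    if tag not in choices:
        raise ValueError(f'Unable to match tag {tag!r} to any of {sorted(choices)}')
    model = choices[tag]
    if isinstance(v, model):
        return v  # already an instance of the selected member
    if not isinstance(v, dict):
        raise ValueError(f'Cannot build {model.__name__} from {type(v).__name__}')
    return model(**v)


print(validate_tagged({'filling': 'pumpkin', 'time_to_cook': 40, 'num_ingredients': 6}))
#> PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin')
```

The discriminator function handles both dicts and instances because, as in the example above, the same union may be validated from raw data or from objects that were already constructed.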

`CallableDiscriminator` can also be used to validate `Union` types with combinations of models and primitive types.

For example:

```py requires="3.8"
from typing import Any, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


def model_x_discriminator(v: Any) -> str:
    if isinstance(v, str):
        return 'str'
    if isinstance(v, (dict, BaseModel)):
        return 'model'


class DiscriminatedModel(BaseModel):
    x: Annotated[
        Union[
            Annotated[str, Tag('str')],
            Annotated['DiscriminatedModel', Tag('model')],
        ],
        CallableDiscriminator(
            model_x_discriminator,
            custom_error_type='invalid_union_member',
            custom_error_message='Invalid union member',
            custom_error_context={'discriminator': 'str_or_model'},
        ),
    ]


data = {'x': {'x': {'x': 'a'}}}
m = DiscriminatedModel.model_validate(data)
assert m == (
    DiscriminatedModel(x=DiscriminatedModel(x=DiscriminatedModel(x='a')))
)
assert m.model_dump() == data
```

!!! note
    Using the [`typing.Annotated` fields syntax](../concepts/json_schema.md#typingannotated-fields) can be handy to regroup
    the `Union` and `discriminator` information. See below for an example!
    the `Union` and `discriminator` information. See the next example for more details.

There are a few ways to set a discriminator for a field, all varying slightly in syntax.

For `str` discriminators:
```
some_field: Union[...] = Field(discriminator='my_discriminator')
some_field: Annotated[Union[...], Field(discriminator='my_discriminator')]
```

For `CallableDiscriminator` discriminators:
```
some_field: Union[...] = Field(discriminator=CallableDiscriminator(...))
some_field: Annotated[Union[...], CallableDiscriminator(...)]
some_field: Annotated[Union[...], Field(discriminator=CallableDiscriminator(...))]
```

!!! warning
    Discriminated unions cannot be used with only a single variant, such as `Union[Cat]`.
47 changes: 45 additions & 2 deletions docs/concepts/fields.md
@@ -597,7 +597,10 @@ print(user)
## Discriminator

The parameter `discriminator` can be used to control the field that will be used to discriminate between different
models in a union.
models in a union. It takes either the name of a field or a `CallableDiscriminator` instance. The `CallableDiscriminator`
approach can be useful when the discriminator fields aren't the same for all of the models in the `Union`.

The following example shows how to use `discriminator` with a field name:

```py requires="3.8"
from typing import Literal, Union
@@ -625,7 +628,47 @@ print(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})) # (1)!

1. See more about [Helper Functions] in the [Models] page.

See the [Discriminated Unions] for more details.
The following example shows how to use `discriminator` with a `CallableDiscriminator` instance:

```py requires="3.8"
from typing import Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Field, Tag


class Cat(BaseModel):
    pet_type: Literal['cat']
    age: int


class Dog(BaseModel):
    pet_kind: Literal['dog']
    age: int


def pet_discriminator(v):
    if isinstance(v, dict):
        return v.get('pet_type', v.get('pet_kind'))
    return getattr(v, 'pet_type', getattr(v, 'pet_kind', None))


class Model(BaseModel):
    pet: Union[Annotated[Cat, Tag('cat')], Annotated[Dog, Tag('dog')]] = Field(
        discriminator=CallableDiscriminator(pet_discriminator)
    )


print(repr(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})))
#> Model(pet=Cat(pet_type='cat', age=12))

print(repr(Model.model_validate({'pet': {'pet_kind': 'dog', 'age': 12}})))
#> Model(pet=Dog(pet_kind='dog', age=12))
```

You can also take advantage of `Annotated` to define your discriminated unions.
See the [Discriminated Unions] docs for more details.

## Strict Mode

4 changes: 4 additions & 0 deletions docs/concepts/performance.md
@@ -132,6 +132,8 @@ class Html(BaseModel):
)
```

See [Discriminated Unions] for more details.

## Use `Literal` not `Enum`

Instead of using `Enum`, use `Literal` to define the structure of the data.
@@ -207,3 +209,5 @@ Instead of using nested models, use `TypedDict` to define the structure of the d
## Avoid wrap validators if you really care about performance

<!-- TODO: I need help on this one. -->

[Discriminated Unions]: ../api/standard_library_types.md#discriminated-unions-aka-tagged-unions
3 changes: 3 additions & 0 deletions docs/concepts/validators.md
@@ -392,6 +392,9 @@ However, in most cases where you want to perform validation using multiple field

## Model validators

??? api "API Documentation"
    [`pydantic.functional_validators.model_validator`][pydantic.functional_validators.model_validator]<br>

Validation can also be performed on the entire model's data using `@model_validator`.

```py
56 changes: 56 additions & 0 deletions docs/errors/usage_errors.md
@@ -365,6 +365,62 @@ class Model(BaseModel):
assert Model(pet={'pet_type': 'kitten'}).pet.pet_type == 'cat'
```

## Callable discriminator case with no tag {#callable-discriminator-no-tag}

This error is raised when a `Union` that uses a `CallableDiscriminator` doesn't have `Tag` annotations for all cases.

```py
from typing import Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, PydanticUserError, Tag


def model_x_discriminator(v):
    if isinstance(v, str):
        return 'str'
    if isinstance(v, (dict, BaseModel)):
        return 'model'


# tag missing for both union choices
try:

    class DiscriminatedModel(BaseModel):
        x: Annotated[
            Union[str, 'DiscriminatedModel'],
            CallableDiscriminator(model_x_discriminator),
        ]

except PydanticUserError as exc_info:
    assert exc_info.code == 'callable-discriminator-no-tag'

# tag missing for `'DiscriminatedModel'` union choice
try:

    class DiscriminatedModel(BaseModel):
        x: Annotated[
            Union[Annotated[str, Tag('str')], 'DiscriminatedModel'],
            CallableDiscriminator(model_x_discriminator),
        ]

except PydanticUserError as exc_info:
    assert exc_info.code == 'callable-discriminator-no-tag'

# tag missing for `str` union choice
try:

    class DiscriminatedModel(BaseModel):
        x: Annotated[
            Union[str, Annotated['DiscriminatedModel', Tag('model')]],
            CallableDiscriminator(model_x_discriminator),
        ]

except PydanticUserError as exc_info:
    assert exc_info.code == 'callable-discriminator-no-tag'
```
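
The rule behind this error, that every union member must carry a tag, can be approximated with the standard library's `typing` introspection helpers. The sketch below is only an illustration of the check, not how pydantic performs it: the `Tag` class is a hypothetical stand-in for pydantic's marker, and `untagged_members` is a name introduced here.

```python
from typing import Annotated, Any, List, Union, get_args, get_origin


class Tag:
    """Hypothetical stand-in for pydantic's `Tag` marker."""

    def __init__(self, tag: str) -> None:
        self.tag = tag


def untagged_members(union_type: Any) -> List[Any]:
    """Return the union members that have no `Tag` in their `Annotated` metadata."""
    missing = []
    for member in get_args(union_type):
        if get_origin(member) is Annotated:
            # For Annotated[T, m1, m2], get_args gives (T, m1, m2).
            metadata = get_args(member)[1:]
            if any(isinstance(m, Tag) for m in metadata):
                continue
        missing.append(member)
    return missing


partially_tagged = Union[Annotated[str, Tag('str')], int]
print(untagged_members(partially_tagged))
#> [<class 'int'>]
```

A non-empty result corresponds to the situation in which pydantic raises `PydanticUserError` with the code `callable-discriminator-no-tag`.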


## `TypedDict` version {#typed-dict-version}

4 changes: 4 additions & 0 deletions pydantic/__init__.py
@@ -181,6 +181,8 @@
'Base64UrlBytes',
'Base64UrlStr',
'GetPydanticSchema',
'Tag',
'CallableDiscriminator',
# type_adapter
'TypeAdapter',
# version
@@ -320,6 +322,8 @@
'Base64UrlBytes': (__package__, '.types'),
'Base64UrlStr': (__package__, '.types'),
'GetPydanticSchema': (__package__, '.types'),
'Tag': (__package__, '.types'),
'CallableDiscriminator': (__package__, '.types'),
# type_adapter
'TypeAdapter': (__package__, '.type_adapter'),
# warnings
4 changes: 4 additions & 0 deletions pydantic/_internal/_core_utils.py
@@ -48,6 +48,10 @@
schema building because one of its members refers to a definition that was not yet defined when the union
was first encountered.
"""
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
"""
Used in a `Tag` schema to specify the tag used for a discriminated union.
"""
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
schema was first encountered.
14 changes: 12 additions & 2 deletions pydantic/_internal/_discriminated_union.py
@@ -1,6 +1,6 @@
from __future__ import annotations as _annotations

from typing import Any, Hashable, Sequence
from typing import TYPE_CHECKING, Any, Hashable, Sequence

from pydantic_core import CoreSchema, core_schema

@@ -13,6 +13,9 @@
simplify_schema_references,
)

if TYPE_CHECKING:
from ..types import CallableDiscriminator

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'


@@ -58,7 +61,9 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem


def apply_discriminator(
schema: core_schema.CoreSchema, discriminator: str, definitions: dict[str, core_schema.CoreSchema] | None = None
schema: core_schema.CoreSchema,
discriminator: str | CallableDiscriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.
@@ -83,6 +88,11 @@ def apply_discriminator(
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
from ..types import CallableDiscriminator

if isinstance(discriminator, CallableDiscriminator):
return discriminator._convert_schema(schema)

return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)


