Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CallableDiscriminator and Tag #7983

Merged
merged 9 commits into from Nov 1, 2023
134 changes: 128 additions & 6 deletions docs/api/standard_library_types.md
Expand Up @@ -871,18 +871,24 @@ print(type(User(id='123', age='45').id))

### Discriminated Unions (a.k.a. Tagged Unions)

When `Union` is used with multiple submodels, you sometimes know exactly which submodel needs to
be checked and validated and want to enforce this.
To do that you can set the same field - let's call it `my_discriminator` - in each of the submodels
with a discriminated value, which is one (or many) `Literal` value(s).
For your `Union`, you can set the discriminator in its value: `Field(discriminator='my_discriminator')`.
We can use discriminated unions to more efficiently validate `Union` types.
Discriminated unions can be used to validate `Union` types with multiple models, or combinations of
models and primitive types.

Setting a discriminated union has many benefits:

- validation is faster since it is only attempted against one model
- only one explicit error is raised in case of failure
- the generated JSON schema implements the [associated OpenAPI specification](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#discriminator-object)

#### Discriminated Unions with `str` discriminators
In the case of a `Union` with multiple models, you sometimes know exactly which field of each
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
model needs to be checked and validated and want to enforce this.
To do that you can set the same field - let's call it `my_discriminator` - in each of the models
with a discriminated value, which is one (or many) `Literal` value(s).
For your `Union`, you can set the discriminator in its value: `Field(discriminator='my_discriminator')`.


```py requires="3.8"
from typing import Literal, Union

Expand Down Expand Up @@ -922,9 +928,125 @@ except ValidationError as e:
"""
```

#### Discriminated Unions with `CallableDiscriminator` discriminators
In the case of a `Union` with multiple models, sometimes there isn't a single uniform field
across all models that you can use as a discriminator. This is the perfect use case for the `CallableDiscriminator` approach.

```py requires="3.8"
from typing import Any, Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


class Pie(BaseModel):
time_to_cook: int
num_ingredients: int


class ApplePie(Pie):
fruit: Literal['apple'] = 'apple'


class PumpkinPie(Pie):
filling: Literal['pumpkin'] = 'pumpkin'


def get_discriminator_value(v: Any) -> str:
if isinstance(v, dict):
return v.get('fruit', v.get('filling'))
return getattr(v, 'fruit', getattr(v, 'filling', None))


class ThanksgivingDinner(BaseModel):
dessert: Annotated[
Union[
Annotated[ApplePie, Tag('apple')],
Annotated[PumpkinPie, Tag('pumpkin')],
],
CallableDiscriminator(get_discriminator_value),
]


apple_variation = ThanksgivingDinner.model_validate(
{'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}}
)
print(repr(apple_variation))
"""
ThanksgivingDinner(dessert=ApplePie(time_to_cook=60, num_ingredients=8, fruit='apple'))
"""

pumpkin_variation = ThanksgivingDinner.model_validate(
{
'dessert': {
'filling': 'pumpkin',
'time_to_cook': 40,
'num_ingredients': 6,
}
}
)
print(repr(pumpkin_variation))
"""
ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin'))
"""
```

`CallableDiscriminators` can also be used to validate `Union` types with combinations of models and primitive types.

For example:

```py requires="3.8"
from typing import Any, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


def model_x_discriminator(v: Any) -> str:
if isinstance(v, str):
return 'str'
if isinstance(v, (dict, BaseModel)):
return 'model'


class DiscriminatedModel(BaseModel):
x: Annotated[
Union[
Annotated[str, Tag('str')],
Annotated['DiscriminatedModel', Tag('model')],
],
CallableDiscriminator(
model_x_discriminator,
custom_error_type='invalid_union_member',
custom_error_message='Invalid union member',
custom_error_context={'discriminator': 'str_or_model'},
),
]


data = {'x': {'x': {'x': 'a'}}}
m = DiscriminatedModel.model_validate(data)
assert m == (
DiscriminatedModel(x=DiscriminatedModel(x=DiscriminatedModel(x='a')))
)
assert m.model_dump() == data
```

!!! note
Using the [`typing.Annotated` fields syntax](../concepts/json_schema.md#typingannotated-fields) can be handy to regroup
the `Union` and `discriminator` information. See below for an example!
the `Union` and `discriminator` information. See the next example for more details.

There are a few ways to set a discriminator for a field, all varying slightly in syntax.

For `str` discriminators:
* `some_field: Union[...] = Field(discriminator='my_discriminator'`
* `some_field: Annotated[Union[...], Field(discriminator='my_discriminator')]`
For `CallableDiscriminator` discriminators:
* `some_field: Union[...] = Field(discriminator=CallableDiscriminator(...))`
* `some_field: Annotated[Union[...], CallableDiscriminator(...)]`
* `some_field: Annotated[Union[...], Field(discriminator=CallableDiscriminator(...))]`

!!! warning
Discriminated unions cannot be used with only a single variant, such as `Union[Cat]`.
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
3 changes: 2 additions & 1 deletion docs/concepts/fields.md
Expand Up @@ -667,7 +667,8 @@ print(repr(Model.model_validate({'pet': {'pet_kind': 'dog', 'age': 12}})))
#> Model(pet=Dog(pet_kind='dog', age=12))
```

See the [Discriminated Unions] for more details.
You can also take advantage of `Annotated` to define your discriminated unions.
See the [Discriminated Unions] docs for more details.
dmontagu marked this conversation as resolved.
Show resolved Hide resolved

## Strict Mode

Expand Down
4 changes: 4 additions & 0 deletions docs/concepts/performance.md
Expand Up @@ -132,6 +132,8 @@ class Html(BaseModel):
)
```

See [Discriminated Unions] for more details.

## Use `Literal` not `Enum`

Instead of using `Enum`, use `Literal` to define the structure of the data.
Expand Down Expand Up @@ -207,3 +209,5 @@ Instead of using nested models, use `TypedDict` to define the structure of the d
## Avoid wrap validators if you really care about performance

<!-- TODO: I need help on this one. -->
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved

[Discriminated Unions]: ../api/standard_library_types.md#discriminated-unions-aka-tagged-unions
118 changes: 117 additions & 1 deletion pydantic/types.py
Expand Up @@ -2444,7 +2444,66 @@ def __getattr__(self, item: str) -> Any:
class Tag:
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
"""Provides a way to specify the expected tag to use for a case with a callable discriminated union.

TODO: Need to add more thorough docs..
When using a `CallableDiscriminator`, attach a `Tag` to each case in the `Union` to specify the tag that
should be used to identify that case. For example, in the below example, the `Tag` is used to specify that
if `get_discriminator_value` returns `'apple'`, the input should be validated as an `ApplePie`, and if it
returns `'pumpkin'`, the input should be validated as a `PumpkinPie`.

The primary role of the `Tag` here is to map the return value from the `CallableDiscriminator` function to
the appropriate member of the `Union` in question.

```
from typing import Any, Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


class Pie(BaseModel):
time_to_cook: int
num_ingredients: int


class ApplePie(Pie):
fruit: Literal['apple'] = 'apple'


class PumpkinPie(Pie):
filling: Literal['pumpkin'] = 'pumpkin'


def get_discriminator_value(v: Any) -> str:
if isinstance(v, dict):
return v.get('fruit', v.get('filling'))
return getattr(v, 'fruit', getattr(v, 'filling', None))


class ThanksgivingDinner(BaseModel):
dessert: Annotated[
Union[
Annotated[ApplePie, Tag("apple")],
Annotated[PumpkinPie, Tag("pumpkin")],
],
CallableDiscriminator(get_discriminator_value),
]


apple_variation = ThanksgivingDinner.model_validate(
{'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}}
)
print(repr(apple_variation))
# > ThanksgivingDinner(dessert=ApplePie(fruit='apple', time_to_cook=60, num_ingredients=8))

pumpkin_variation = ThanksgivingDinner.model_validate(
{'dessert': {'filling': 'pumpkin', 'time_to_cook': 40, 'num_ingredients': 6}}
)
print(repr(pumpkin_variation))
# > ThanksgivingDinner(dessert=PumpkinPie(filling='pumpkin', time_to_cook=40, num_ingredients=6))
```

See the [Discriminated Unions](../api/standard_library_types.md#discriminated-unions-aka-tagged-unions)
docs for more details on how to use `Tag`s.
"""

tag: str
Expand All @@ -2464,6 +2523,63 @@ class CallableDiscriminator:
This allows you to get validation behavior like you'd get from `Field(discriminator=<field_name>)`,
but without needing to have a single shared field across all the union choices. This also makes it
possible to handle unions of models and primitive types with discriminated-union-style validation errors.
Finally, this allows you to use a custom callable as the way to identify which member of a union a value
belongs to, while still seeing all the performance benefits of a discriminated union.

Consider this example, which is much more performant with the use of `CallableDiscriminator` and thus a `TaggedUnion`
than it would be as a normal `Union`.
```
from typing import Any, Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag


class Pie(BaseModel):
time_to_cook: int
num_ingredients: int


class ApplePie(Pie):
fruit: Literal['apple'] = 'apple'


class PumpkinPie(Pie):
filling: Literal['pumpkin'] = 'pumpkin'


def get_discriminator_value(v: Any) -> str:
if isinstance(v, dict):
return v.get('fruit', v.get('filling'))
return getattr(v, 'fruit', getattr(v, 'filling', None))


class ThanksgivingDinner(BaseModel):
dessert: Annotated[
Union[
Annotated[ApplePie, Tag("apple")],
Annotated[PumpkinPie, Tag("pumpkin")],
],
CallableDiscriminator(get_discriminator_value),
]


apple_variation = ThanksgivingDinner.model_validate(
{'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}}
)
print(repr(apple_variation))
# > ThanksgivingDinner(dessert=ApplePie(fruit='apple', time_to_cook=60, num_ingredients=8))

pumpkin_variation = ThanksgivingDinner.model_validate(
{'dessert': {'filling': 'pumpkin', 'time_to_cook': 40, 'num_ingredients': 6}}
)
print(repr(pumpkin_variation))
# > ThanksgivingDinner(dessert=PumpkinPie(filling='pumpkin', time_to_cook=40, num_ingredients=6))
```

See the [Discriminated Unions](../api/standard_library_types.md#discriminated-unions-aka-tagged-unions)
docs for more details on how to use `CallableDiscriminator`s.
"""

discriminator: Callable[[Any], Hashable]
Expand Down