Skip to content

Commit

Permalink
Allow str as argument to Discriminator (#8047)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Nov 7, 2023
1 parent da2173e commit a9cebd4
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 69 deletions.
10 changes: 5 additions & 5 deletions docs/concepts/fields.md
Expand Up @@ -597,8 +597,8 @@ print(user)
## Discriminator

The parameter `discriminator` can be used to control the field that will be used to discriminate between different
models in a union. It takes either the name of a field or a `CallableDiscriminator` instance. The `CallableDiscriminator`
approach can be useful when the discriminator fields aren't the same for all of the models in the `Union`.
models in a union. It takes either the name of a field or a `Discriminator` instance. The `Discriminator`
approach can be useful when the discriminator fields aren't the same for all the models in the `Union`.

The following example shows how to use `discriminator` with a field name:

Expand Down Expand Up @@ -628,14 +628,14 @@ print(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})) # (1)!

1. See more about [Helper Functions] in the [Models] page.

The following example shows how to use `discriminator` with a `CallableDiscriminator` instance:
The following example shows how to use the `discriminator` keyword argument with a `Discriminator` instance:

```py requires="3.8"
from typing import Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Field, Tag
from pydantic import BaseModel, Discriminator, Field, Tag


class Cat(BaseModel):
Expand All @@ -656,7 +656,7 @@ def pet_discriminator(v):

class Model(BaseModel):
pet: Union[Annotated[Cat, Tag('cat')], Annotated[Dog, Tag('dog')]] = Field(
discriminator=CallableDiscriminator(pet_discriminator)
discriminator=Discriminator(pet_discriminator)
)


Expand Down
27 changes: 14 additions & 13 deletions docs/concepts/unions.md
Expand Up @@ -184,17 +184,18 @@ except ValidationError as e:
"""
```

### Discriminated Unions with `CallableDiscriminator` discriminators
### Discriminated Unions with callable `Discriminator`s

In the case of a `Union` with multiple models, sometimes there isn't a single uniform field
across all models that you can use as a discriminator. This is the perfect use case for the `CallableDiscriminator` approach.
across all models that you can use as a discriminator.
This is the perfect use case for a callable `Discriminator`.

```py requires="3.8"
from typing import Any, Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag
from pydantic import BaseModel, Discriminator, Tag


class Pie(BaseModel):
Expand Down Expand Up @@ -222,7 +223,7 @@ class ThanksgivingDinner(BaseModel):
Annotated[ApplePie, Tag('apple')],
Annotated[PumpkinPie, Tag('pumpkin')],
],
CallableDiscriminator(get_discriminator_value),
Discriminator(get_discriminator_value),
]


Expand All @@ -249,7 +250,7 @@ ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, fillin
"""
```

`CallableDiscriminators` can also be used to validate `Union` types with combinations of models and primitive types.
`Discriminator`s can also be used to validate `Union` types with combinations of models and primitive types.

For example:

Expand All @@ -258,7 +259,7 @@ from typing import Any, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag
from pydantic import BaseModel, Discriminator, Tag


def model_x_discriminator(v: Any) -> str:
Expand All @@ -274,7 +275,7 @@ class DiscriminatedModel(BaseModel):
Annotated[str, Tag('str')],
Annotated['DiscriminatedModel', Tag('model')],
],
CallableDiscriminator(
Discriminator(
model_x_discriminator,
custom_error_type='invalid_union_member',
custom_error_message='Invalid union member',
Expand Down Expand Up @@ -303,11 +304,11 @@ assert m.model_dump() == data
some_field: Annotated[Union[...], Field(discriminator='my_discriminator')]
```

For `CallableDiscriminator` discriminators:
For callable `Discriminator`s:
```
some_field: Union[...] = Field(discriminator=CallableDiscriminator(...))
some_field: Annotated[Union[...], CallableDiscriminator(...)]
some_field: Annotated[Union[...], Field(discriminator=CallableDiscriminator(...))]
some_field: Union[...] = Field(discriminator=Discriminator(...))
some_field: Annotated[Union[...], Discriminator(...)]
some_field: Annotated[Union[...], Field(discriminator=Discriminator(...))]
```

!!! warning
Expand Down Expand Up @@ -390,7 +391,7 @@ from typing import Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag, ValidationError
from pydantic import BaseModel, Discriminator, Tag, ValidationError


# Errors are quite verbose with a normal Union:
Expand Down Expand Up @@ -474,7 +475,7 @@ class DiscriminatedModel(BaseModel):
Annotated[str, Tag('str')],
Annotated['DiscriminatedModel', Tag('model')],
],
CallableDiscriminator(
Discriminator(
model_x_discriminator,
custom_error_type='invalid_union_member',
custom_error_message='Invalid union member',
Expand Down
10 changes: 5 additions & 5 deletions docs/errors/usage_errors.md
Expand Up @@ -367,14 +367,14 @@ assert Model(pet={'pet_type': 'kitten'}).pet.pet_type == 'cat'

## Callable discriminator case with no tag {#callable-discriminator-no-tag}

This error is raised when a `Union` that uses a `CallableDiscriminator` doesn't have `Tag` annotations for all cases.
This error is raised when a `Union` that uses a callable `Discriminator` doesn't have `Tag` annotations for all cases.

```py
from typing import Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, PydanticUserError, Tag
from pydantic import BaseModel, Discriminator, PydanticUserError, Tag


def model_x_discriminator(v):
Expand All @@ -390,7 +390,7 @@ try:
class DiscriminatedModel(BaseModel):
x: Annotated[
Union[str, 'DiscriminatedModel'],
CallableDiscriminator(model_x_discriminator),
Discriminator(model_x_discriminator),
]

except PydanticUserError as exc_info:
Expand All @@ -402,7 +402,7 @@ try:
class DiscriminatedModel(BaseModel):
x: Annotated[
Union[Annotated[str, Tag('str')], 'DiscriminatedModel'],
CallableDiscriminator(model_x_discriminator),
Discriminator(model_x_discriminator),
]

except PydanticUserError as exc_info:
Expand All @@ -414,7 +414,7 @@ try:
class DiscriminatedModel(BaseModel):
x: Annotated[
Union[str, Annotated['DiscriminatedModel', Tag('model')]],
CallableDiscriminator(model_x_discriminator),
Discriminator(model_x_discriminator),
]

except PydanticUserError as exc_info:
Expand Down
4 changes: 2 additions & 2 deletions pydantic/__init__.py
Expand Up @@ -182,7 +182,7 @@
'Base64UrlStr',
'GetPydanticSchema',
'Tag',
'CallableDiscriminator',
'Discriminator',
'JsonValue',
# type_adapter
'TypeAdapter',
Expand Down Expand Up @@ -324,7 +324,7 @@
'Base64UrlStr': (__package__, '.types'),
'GetPydanticSchema': (__package__, '.types'),
'Tag': (__package__, '.types'),
'CallableDiscriminator': (__package__, '.types'),
'Discriminator': (__package__, '.types'),
'JsonValue': (__package__, '.types'),
# type_adapter
'TypeAdapter': (__package__, '.type_adapter'),
Expand Down
13 changes: 8 additions & 5 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -14,7 +14,7 @@
)

if TYPE_CHECKING:
from ..types import CallableDiscriminator
from ..types import Discriminator

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'

Expand Down Expand Up @@ -62,7 +62,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem

def apply_discriminator(
schema: core_schema.CoreSchema,
discriminator: str | CallableDiscriminator,
discriminator: str | Discriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.
Expand All @@ -88,10 +88,13 @@ def apply_discriminator(
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
from ..types import CallableDiscriminator
from ..types import Discriminator

if isinstance(discriminator, CallableDiscriminator):
return discriminator._convert_schema(schema)
if isinstance(discriminator, Discriminator):
if isinstance(discriminator.discriminator, str):
discriminator = discriminator.discriminator
else:
return discriminator._convert_schema(schema)

return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)

Expand Down
4 changes: 2 additions & 2 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -82,7 +82,7 @@
if TYPE_CHECKING:
from ..fields import ComputedFieldInfo, FieldInfo
from ..main import BaseModel
from ..types import CallableDiscriminator
from ..types import Discriminator
from ..validators import FieldValidatorModes
from ._dataclasses import StandardDataclass
from ._schema_generation_shared import GetJsonSchemaFunction
Expand Down Expand Up @@ -374,7 +374,7 @@ def _unknown_type_schema(self, obj: Any) -> CoreSchema:
)

def _apply_discriminator_to_union(
self, schema: CoreSchema, discriminator: str | CallableDiscriminator | None
self, schema: CoreSchema, discriminator: str | Discriminator | None
) -> CoreSchema:
if discriminator is None:
return schema
Expand Down
10 changes: 5 additions & 5 deletions pydantic/fields.py
Expand Up @@ -64,7 +64,7 @@ class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
max_digits: int | None
decimal_places: int | None
union_mode: Literal['smart', 'left_to_right'] | None
discriminator: str | types.CallableDiscriminator | None
discriminator: str | types.Discriminator | None
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None
frozen: bool | None
validate_default: bool | None
Expand Down Expand Up @@ -101,7 +101,7 @@ class FieldInfo(_repr.Representation):
description: The description of the field.
examples: List of examples of the field.
exclude: Whether to exclude the field from the model serialization.
discriminator: Field name or CallableDiscriminator for discriminating the type in a tagged union.
discriminator: Field name or Discriminator for discriminating the type in a tagged union.
json_schema_extra: Dictionary of extra JSON schema properties.
frozen: Whether the field is frozen.
validate_default: Whether to validate the default value of the field.
Expand All @@ -122,7 +122,7 @@ class FieldInfo(_repr.Representation):
description: str | None
examples: list[Any] | None
exclude: bool | None
discriminator: str | types.CallableDiscriminator | None
discriminator: str | types.Discriminator | None
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None
frozen: bool | None
validate_default: bool | None
Expand Down Expand Up @@ -682,7 +682,7 @@ def Field( # noqa: C901
description: str | None = _Unset,
examples: list[Any] | None = _Unset,
exclude: bool | None = _Unset,
discriminator: str | types.CallableDiscriminator | None = _Unset,
discriminator: str | types.Discriminator | None = _Unset,
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = _Unset,
frozen: bool | None = _Unset,
validate_default: bool | None = _Unset,
Expand Down Expand Up @@ -727,7 +727,7 @@ def Field( # noqa: C901
description: Human-readable description.
examples: Example values for this field.
exclude: Whether to exclude the field from the model serialization.
discriminator: Field name or CallableDiscriminator for discriminating the type in a tagged union.
discriminator: Field name or Discriminator for discriminating the type in a tagged union.
json_schema_extra: Any additional JSON schema data for the schema property.
frozen: Whether the field is frozen.
validate_default: Run validation that isn't only checking existence of defaults. This can be set to `True` or `False`. If not set, it defaults to `None`.
Expand Down

0 comments on commit a9cebd4

Please sign in to comment.