Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow str as argument to Discriminator #8047

Merged
merged 1 commit into from Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/concepts/fields.md
Expand Up @@ -597,8 +597,8 @@ print(user)
## Discriminator

The parameter `discriminator` can be used to control the field that will be used to discriminate between different
models in a union. It takes either the name of a field or a `CallableDiscriminator` instance. The `CallableDiscriminator`
approach can be useful when the discriminator fields aren't the same for all of the models in the `Union`.
models in a union. It takes either the name of a field or a `Discriminator` instance. The `Discriminator`
approach can be useful when the discriminator fields aren't the same for all the models in the `Union`.

The following example shows how to use `discriminator` with a field name:

Expand Down Expand Up @@ -628,14 +628,14 @@ print(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})) # (1)!

1. See more about [Helper Functions] in the [Models] page.

The following example shows how to use `discriminator` with a `CallableDiscriminator` instance:
The following example shows how to use the `discriminator` keyword argument with a `Discriminator` instance:

```py requires="3.8"
from typing import Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Field, Tag
from pydantic import BaseModel, Discriminator, Field, Tag


class Cat(BaseModel):
Expand All @@ -656,7 +656,7 @@ def pet_discriminator(v):

class Model(BaseModel):
pet: Union[Annotated[Cat, Tag('cat')], Annotated[Dog, Tag('dog')]] = Field(
discriminator=CallableDiscriminator(pet_discriminator)
discriminator=Discriminator(pet_discriminator)
)


Expand Down
27 changes: 14 additions & 13 deletions docs/concepts/unions.md
Expand Up @@ -184,17 +184,18 @@ except ValidationError as e:
"""
```

### Discriminated Unions with `CallableDiscriminator` discriminators
### Discriminated Unions with callable `Discriminator`s

In the case of a `Union` with multiple models, sometimes there isn't a single uniform field
across all models that you can use as a discriminator. This is the perfect use case for the `CallableDiscriminator` approach.
across all models that you can use as a discriminator.
This is the perfect use case for a callable `Discriminator`.

```py requires="3.8"
from typing import Any, Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag
from pydantic import BaseModel, Discriminator, Tag


class Pie(BaseModel):
Expand Down Expand Up @@ -222,7 +223,7 @@ class ThanksgivingDinner(BaseModel):
Annotated[ApplePie, Tag('apple')],
Annotated[PumpkinPie, Tag('pumpkin')],
],
CallableDiscriminator(get_discriminator_value),
Discriminator(get_discriminator_value),
]


Expand All @@ -249,7 +250,7 @@ ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, fillin
"""
```

`CallableDiscriminators` can also be used to validate `Union` types with combinations of models and primitive types.
`Discriminator`s can also be used to validate `Union` types with combinations of models and primitive types.

For example:

Expand All @@ -258,7 +259,7 @@ from typing import Any, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag
from pydantic import BaseModel, Discriminator, Tag


def model_x_discriminator(v: Any) -> str:
Expand All @@ -274,7 +275,7 @@ class DiscriminatedModel(BaseModel):
Annotated[str, Tag('str')],
Annotated['DiscriminatedModel', Tag('model')],
],
CallableDiscriminator(
Discriminator(
model_x_discriminator,
custom_error_type='invalid_union_member',
custom_error_message='Invalid union member',
Expand Down Expand Up @@ -303,11 +304,11 @@ assert m.model_dump() == data
some_field: Annotated[Union[...], Field(discriminator='my_discriminator')]
```

For `CallableDiscriminator` discriminators:
For callable `Discriminator`s:
```
some_field: Union[...] = Field(discriminator=CallableDiscriminator(...))
some_field: Annotated[Union[...], CallableDiscriminator(...)]
some_field: Annotated[Union[...], Field(discriminator=CallableDiscriminator(...))]
some_field: Union[...] = Field(discriminator=Discriminator(...))
some_field: Annotated[Union[...], Discriminator(...)]
some_field: Annotated[Union[...], Field(discriminator=Discriminator(...))]
```

!!! warning
Expand Down Expand Up @@ -390,7 +391,7 @@ from typing import Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Tag, ValidationError
from pydantic import BaseModel, Discriminator, Tag, ValidationError


# Errors are quite verbose with a normal Union:
Expand Down Expand Up @@ -474,7 +475,7 @@ class DiscriminatedModel(BaseModel):
Annotated[str, Tag('str')],
Annotated['DiscriminatedModel', Tag('model')],
],
CallableDiscriminator(
Discriminator(
model_x_discriminator,
custom_error_type='invalid_union_member',
custom_error_message='Invalid union member',
Expand Down
10 changes: 5 additions & 5 deletions docs/errors/usage_errors.md
Expand Up @@ -367,14 +367,14 @@ assert Model(pet={'pet_type': 'kitten'}).pet.pet_type == 'cat'

## Callable discriminator case with no tag {#callable-discriminator-no-tag}

This error is raised when a `Union` that uses a `CallableDiscriminator` doesn't have `Tag` annotations for all cases.
This error is raised when a `Union` that uses a callable `Discriminator` doesn't have `Tag` annotations for all cases.

```py
from typing import Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, PydanticUserError, Tag
from pydantic import BaseModel, Discriminator, PydanticUserError, Tag


def model_x_discriminator(v):
Expand All @@ -390,7 +390,7 @@ try:
class DiscriminatedModel(BaseModel):
x: Annotated[
Union[str, 'DiscriminatedModel'],
CallableDiscriminator(model_x_discriminator),
Discriminator(model_x_discriminator),
]

except PydanticUserError as exc_info:
Expand All @@ -402,7 +402,7 @@ try:
class DiscriminatedModel(BaseModel):
x: Annotated[
Union[Annotated[str, Tag('str')], 'DiscriminatedModel'],
CallableDiscriminator(model_x_discriminator),
Discriminator(model_x_discriminator),
]

except PydanticUserError as exc_info:
Expand All @@ -414,7 +414,7 @@ try:
class DiscriminatedModel(BaseModel):
x: Annotated[
Union[str, Annotated['DiscriminatedModel', Tag('model')]],
CallableDiscriminator(model_x_discriminator),
Discriminator(model_x_discriminator),
]

except PydanticUserError as exc_info:
Expand Down
4 changes: 2 additions & 2 deletions pydantic/__init__.py
Expand Up @@ -182,7 +182,7 @@
'Base64UrlStr',
'GetPydanticSchema',
'Tag',
'CallableDiscriminator',
'Discriminator',
'JsonValue',
# type_adapter
'TypeAdapter',
Expand Down Expand Up @@ -324,7 +324,7 @@
'Base64UrlStr': (__package__, '.types'),
'GetPydanticSchema': (__package__, '.types'),
'Tag': (__package__, '.types'),
'CallableDiscriminator': (__package__, '.types'),
'Discriminator': (__package__, '.types'),
'JsonValue': (__package__, '.types'),
# type_adapter
'TypeAdapter': (__package__, '.type_adapter'),
Expand Down
13 changes: 8 additions & 5 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -14,7 +14,7 @@
)

if TYPE_CHECKING:
from ..types import CallableDiscriminator
from ..types import Discriminator

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'

Expand Down Expand Up @@ -62,7 +62,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem

def apply_discriminator(
schema: core_schema.CoreSchema,
discriminator: str | CallableDiscriminator,
discriminator: str | Discriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.
Expand All @@ -88,10 +88,13 @@ def apply_discriminator(
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
from ..types import CallableDiscriminator
from ..types import Discriminator

if isinstance(discriminator, CallableDiscriminator):
return discriminator._convert_schema(schema)
if isinstance(discriminator, Discriminator):
if isinstance(discriminator.discriminator, str):
discriminator = discriminator.discriminator
else:
return discriminator._convert_schema(schema)

return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)

Expand Down
4 changes: 2 additions & 2 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -82,7 +82,7 @@
if TYPE_CHECKING:
from ..fields import ComputedFieldInfo, FieldInfo
from ..main import BaseModel
from ..types import CallableDiscriminator
from ..types import Discriminator
from ..validators import FieldValidatorModes
from ._dataclasses import StandardDataclass
from ._schema_generation_shared import GetJsonSchemaFunction
Expand Down Expand Up @@ -374,7 +374,7 @@ def _unknown_type_schema(self, obj: Any) -> CoreSchema:
)

def _apply_discriminator_to_union(
self, schema: CoreSchema, discriminator: str | CallableDiscriminator | None
self, schema: CoreSchema, discriminator: str | Discriminator | None
) -> CoreSchema:
if discriminator is None:
return schema
Expand Down
10 changes: 5 additions & 5 deletions pydantic/fields.py
Expand Up @@ -64,7 +64,7 @@ class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
max_digits: int | None
decimal_places: int | None
union_mode: Literal['smart', 'left_to_right'] | None
discriminator: str | types.CallableDiscriminator | None
discriminator: str | types.Discriminator | None
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None
frozen: bool | None
validate_default: bool | None
Expand Down Expand Up @@ -101,7 +101,7 @@ class FieldInfo(_repr.Representation):
description: The description of the field.
examples: List of examples of the field.
exclude: Whether to exclude the field from the model serialization.
discriminator: Field name or CallableDiscriminator for discriminating the type in a tagged union.
discriminator: Field name or Discriminator for discriminating the type in a tagged union.
json_schema_extra: Dictionary of extra JSON schema properties.
frozen: Whether the field is frozen.
validate_default: Whether to validate the default value of the field.
Expand All @@ -122,7 +122,7 @@ class FieldInfo(_repr.Representation):
description: str | None
examples: list[Any] | None
exclude: bool | None
discriminator: str | types.CallableDiscriminator | None
discriminator: str | types.Discriminator | None
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None
frozen: bool | None
validate_default: bool | None
Expand Down Expand Up @@ -682,7 +682,7 @@ def Field( # noqa: C901
description: str | None = _Unset,
examples: list[Any] | None = _Unset,
exclude: bool | None = _Unset,
discriminator: str | types.CallableDiscriminator | None = _Unset,
discriminator: str | types.Discriminator | None = _Unset,
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = _Unset,
frozen: bool | None = _Unset,
validate_default: bool | None = _Unset,
Expand Down Expand Up @@ -727,7 +727,7 @@ def Field( # noqa: C901
description: Human-readable description.
examples: Example values for this field.
exclude: Whether to exclude the field from the model serialization.
discriminator: Field name or CallableDiscriminator for discriminating the type in a tagged union.
discriminator: Field name or Discriminator for discriminating the type in a tagged union.
json_schema_extra: Any additional JSON schema data for the schema property.
frozen: Whether the field is frozen.
validate_default: Run validation that isn't only checking existence of defaults. This can be set to `True` or `False`. If not set, it defaults to `None`.
Expand Down