Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CallableDiscriminator and Tag #7983

Merged
merged 9 commits into from Nov 1, 2023
44 changes: 43 additions & 1 deletion docs/concepts/fields.md
Expand Up @@ -597,7 +597,10 @@ print(user)
## Discriminator

The parameter `discriminator` can be used to control the field that will be used to discriminate between different
models in a union.
models in a union. It takes either the name of a field or a `CallableDiscriminator` instance. The `CallableDiscriminator`
approach can be useful when the discriminator fields aren't the same for all of the models in the `Union`.

The following example shows how to use `discriminator` with a field name:

```py requires="3.8"
from typing import Literal, Union
Expand Down Expand Up @@ -625,6 +628,45 @@ print(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})) # (1)!

1. See more about [Helper Functions] in the [Models] page.

The following example shows how to use `discriminator` with a `CallableDiscriminator` instance:

```py requires="3.8"
from typing import Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, CallableDiscriminator, Field, Tag


class Cat(BaseModel):
pet_type: Literal['cat']
age: int


class Dog(BaseModel):
pet_kind: Literal['dog']
age: int


def pet_discriminator(v):
if isinstance(v, dict):
return v.get('pet_type', v.get('pet_kind'))
return getattr(v, 'pet_type', getattr(v, 'pet_kind', None))


class Model(BaseModel):
pet: Union[Annotated[Cat, Tag('cat')], Annotated[Dog, Tag('dog')]] = Field(
discriminator=CallableDiscriminator(pet_discriminator)
)


print(repr(Model.model_validate({'pet': {'pet_type': 'cat', 'age': 12}})))
#> Model(pet=Cat(pet_type='cat', age=12))

print(repr(Model.model_validate({'pet': {'pet_kind': 'dog', 'age': 12}})))
#> Model(pet=Dog(pet_kind='dog', age=12))
```

See the [Discriminated Unions] for more details.

## Strict Mode
Expand Down
4 changes: 4 additions & 0 deletions pydantic/__init__.py
Expand Up @@ -181,6 +181,8 @@
'Base64UrlBytes',
'Base64UrlStr',
'GetPydanticSchema',
'Tag',
'CallableDiscriminator',
# type_adapter
'TypeAdapter',
# version
Expand Down Expand Up @@ -320,6 +322,8 @@
'Base64UrlBytes': (__package__, '.types'),
'Base64UrlStr': (__package__, '.types'),
'GetPydanticSchema': (__package__, '.types'),
'Tag': (__package__, '.types'),
'CallableDiscriminator': (__package__, '.types'),
# type_adapter
'TypeAdapter': (__package__, '.type_adapter'),
# warnings
Expand Down
14 changes: 12 additions & 2 deletions pydantic/_internal/_discriminated_union.py
@@ -1,6 +1,6 @@
from __future__ import annotations as _annotations

from typing import Any, Hashable, Sequence
from typing import TYPE_CHECKING, Any, Hashable, Sequence

from pydantic_core import CoreSchema, core_schema

Expand All @@ -13,6 +13,9 @@
simplify_schema_references,
)

if TYPE_CHECKING:
from ..types import CallableDiscriminator

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'


Expand Down Expand Up @@ -58,7 +61,9 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem


def apply_discriminator(
schema: core_schema.CoreSchema, discriminator: str, definitions: dict[str, core_schema.CoreSchema] | None = None
schema: core_schema.CoreSchema,
discriminator: str | CallableDiscriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.

Expand All @@ -83,6 +88,11 @@ def apply_discriminator(
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
from ..types import CallableDiscriminator

if isinstance(discriminator, CallableDiscriminator):
return discriminator._convert_schema(schema)

dmontagu marked this conversation as resolved.
Show resolved Hide resolved
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)


Expand Down
7 changes: 6 additions & 1 deletion pydantic/_internal/_generate_schema.py
Expand Up @@ -82,6 +82,7 @@
if TYPE_CHECKING:
from ..fields import ComputedFieldInfo, FieldInfo
from ..main import BaseModel
from ..types import CallableDiscriminator
from ..validators import FieldValidatorModes
from ._dataclasses import StandardDataclass
from ._schema_generation_shared import GetJsonSchemaFunction
Expand Down Expand Up @@ -372,7 +373,11 @@ def _unknown_type_schema(self, obj: Any) -> CoreSchema:
' `__get_pydantic_core_schema__` on `<some type>` otherwise to avoid infinite recursion.'
)

def _apply_discriminator_to_union(self, schema: CoreSchema, discriminator: Any) -> CoreSchema:
def _apply_discriminator_to_union(
self, schema: CoreSchema, discriminator: str | CallableDiscriminator | None
) -> CoreSchema:
if discriminator is None:
return schema
try:
return _discriminated_union.apply_discriminator(
schema,
Expand Down
10 changes: 5 additions & 5 deletions pydantic/fields.py
Expand Up @@ -64,7 +64,7 @@ class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
max_digits: int | None
decimal_places: int | None
union_mode: Literal['smart', 'left_to_right'] | None
discriminator: str | None
discriminator: str | types.CallableDiscriminator | None
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None
frozen: bool | None
validate_default: bool | None
Expand Down Expand Up @@ -101,7 +101,7 @@ class FieldInfo(_repr.Representation):
description: The description of the field.
examples: List of examples of the field.
exclude: Whether to exclude the field from the model serialization.
discriminator: Field name for discriminating the type in a tagged union.
discriminator: Field name or CallableDiscriminator for discriminating the type in a tagged union.
json_schema_extra: Dictionary of extra JSON schema properties.
frozen: Whether the field is frozen.
validate_default: Whether to validate the default value of the field.
Expand All @@ -122,7 +122,7 @@ class FieldInfo(_repr.Representation):
description: str | None
examples: list[Any] | None
exclude: bool | None
discriminator: str | None
discriminator: str | types.CallableDiscriminator | None
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None
frozen: bool | None
validate_default: bool | None
Expand Down Expand Up @@ -682,7 +682,7 @@ def Field( # noqa: C901
description: str | None = _Unset,
examples: list[Any] | None = _Unset,
exclude: bool | None = _Unset,
discriminator: str | None = _Unset,
discriminator: str | types.CallableDiscriminator | None = _Unset,
json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = _Unset,
frozen: bool | None = _Unset,
validate_default: bool | None = _Unset,
Expand Down Expand Up @@ -727,7 +727,7 @@ def Field( # noqa: C901
description: Human-readable description.
examples: Example values for this field.
exclude: Whether to exclude the field from the model serialization.
discriminator: Field name for discriminating the type in a tagged union.
discriminator: Field name or CallableDiscriminator for discriminating the type in a tagged union.
json_schema_extra: Any additional JSON schema data for the schema property.
frozen: Whether the field is frozen.
validate_default: Run validation that isn't only checking existence of defaults. This can be set to `True` or `False`. If not set, it defaults to `None`.
Expand Down
99 changes: 98 additions & 1 deletion pydantic/types.py
Expand Up @@ -30,7 +30,13 @@
from pydantic_core import CoreSchema, PydanticCustomError, core_schema
from typing_extensions import Annotated, Literal, Protocol, deprecated

from ._internal import _fields, _internal_dataclass, _utils, _validators
from ._internal import (
_fields,
_internal_dataclass,
_typing_extra,
_utils,
_validators,
)
from ._migration import getattr_migration
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from .errors import PydanticUserError
Expand Down Expand Up @@ -92,6 +98,8 @@
'Base64UrlStr',
'GetPydanticSchema',
'StringConstraints',
'Tag',
'CallableDiscriminator',
)


Expand Down Expand Up @@ -2429,3 +2437,92 @@ def __getattr__(self, item: str) -> Any:
return object.__getattribute__(self, item)

__hash__ = object.__hash__


_TAGGED_UNION_TAG_KEY = 'pydantic-tagged-union-tag'
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved


@_dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
class Tag:
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
"""Provides a way to specify the expected tag to use for a case with a callable discriminated union.

TODO: Need to add more thorough docs..
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
"""

tag: str

def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
schema = handler(source_type)
metadata = schema.setdefault('metadata', {})
assert isinstance(metadata, dict)
metadata[_TAGGED_UNION_TAG_KEY] = self.tag
return schema


@_dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
class CallableDiscriminator:
"""Provides a way to use a custom callable as the way to extract the value of a union discriminator.

This allows you to get validation behavior like you'd get from `Field(discriminator=<field_name>)`,
but without needing to have a single shared field across all the union choices. This also makes it
possible to handle unions of models and primitive types with discriminated-union-style validation errors.
"""

discriminator: Callable[[Any], Hashable]
custom_error_type: str | None = None
custom_error_message: str | None = None
custom_error_context: dict[str, int | str | float] | None = None

def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
origin = _typing_extra.get_origin(source_type)
if not origin or not _typing_extra.origin_is_union(origin):
raise TypeError(f'{type(self).__name__} must be used with a Union type, not {source_type}')

original_schema = handler.generate_schema(source_type)
return self._convert_schema(original_schema)

def _convert_schema(self, original_schema: core_schema.CoreSchema) -> core_schema.TaggedUnionSchema:
if original_schema['type'] != 'union':
# This likely indicates that the schema was a single-item union that was simplified.
# In this case, we do the same thing we do in
# `pydantic._internal._discriminated_union._ApplyInferredDiscriminator._apply_to_root`, namely,
# package the generated schema back into a single-item union.
original_schema = core_schema.union_schema([original_schema])

tagged_union_choices = {}
for i, choice in enumerate(original_schema['choices']):
tag = f'case-{i}'
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(choice, tuple):
choice, tag = choice
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
metadata = choice.get('metadata')
if metadata is not None:
metadata_tag = metadata.get(_TAGGED_UNION_TAG_KEY)
if metadata_tag is not None:
tag = metadata_tag
tagged_union_choices[tag] = choice

# Have to do these verbose checks to ensure falsy values ('' and {}) don't get ignored
custom_error_type = self.custom_error_type
if custom_error_type is None:
custom_error_type = original_schema.get('custom_error_type')

custom_error_message = self.custom_error_message
if custom_error_message is None:
custom_error_message = original_schema.get('custom_error_message')

custom_error_context = self.custom_error_context
if custom_error_context is None:
custom_error_context = original_schema.get('custom_error_context')

custom_error_type = original_schema.get('custom_error_type') if custom_error_type is None else custom_error_type
return core_schema.tagged_union_schema(
tagged_union_choices,
self.discriminator,
custom_error_type=custom_error_type,
custom_error_message=custom_error_message,
custom_error_context=custom_error_context,
strict=original_schema.get('strict'),
ref=original_schema.get('ref'),
metadata=original_schema.get('metadata'),
serialization=original_schema.get('serialization'),
)