Skip to content

Commit

Permalink
add union_mode='left_to_right'
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Aug 16, 2023
1 parent 0d4b77b commit 84d6f6d
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 9 deletions.
3 changes: 3 additions & 0 deletions docs/migration.md
Expand Up @@ -506,6 +506,9 @@ print(Model(x='1'))
In Pydantic V1, the printed result would have been `x=1`, since the value would pass validation as an `int`.
In Pydantic V2, we recognize that the value is an instance of one of the cases and short-circuit the standard union validation.

To revert to the non-short-circuiting left-to-right behavior of V1, annotate the union with `Field(union_mode='left_to_right')`.
See [Union Mode](./usage/types/unions.md#union-mode) for more details.

#### Required, optional, and nullable fields

Pydantic V2 changes some of the logic for specifying whether a field annotated as `Optional` is required
Expand Down
56 changes: 55 additions & 1 deletion docs/usage/types/unions.md
Expand Up @@ -43,7 +43,61 @@ print(user_03_uuid.int)

See more details in [Required fields](../models.md#required-fields).

#### Discriminated Unions (a.k.a. Tagged Unions)
#### Union Mode

By default `Union` validation will try to return the variant which is the best match for the input.

Consider for example the case of `Union[int, str]`. When [`strict` mode](../strict_mode.md) is not enabled
then `int` fields will accept `str` inputs. In the example below, the `id` field (which is `Union[int, str]`)
will accept the string `'123'` as an input, and preserve it as a string:

```py
from typing import Union

from pydantic import BaseModel


class User(BaseModel):
    # `id` may be either variant; `age` is a plain int field.
    id: Union[int, str]
    age: int


# Build the model once and inspect the same instance twice.
user = User(id='123', age='45')

print(user)
#> id='123' age=45

print(type(user.id))
#> <class 'str'>
```

This is known as `'smart'` mode for `Union` validation.

At present only one other `Union` validation mode exists, called `'left_to_right'` validation. In this mode
variants are attempted from left to right and the result of the first successful validation is returned.

Consider the same example, this time with `union_mode='left_to_right'` set as a [`Field`](../fields.md)
parameter on `id`. With this validation mode, the `int` variant will coerce strings of digits into `int`
values:

```py
from typing import Union

from pydantic import BaseModel, Field


class User(BaseModel):
    # `union_mode='left_to_right'` is set on the `id` field only.
    id: Union[int, str] = Field(..., union_mode='left_to_right')
    age: int


# Build the model once and inspect the same instance twice.
user = User(id='123', age='45')

print(user)
#> id=123 age=45

print(type(user.id))
#> <class 'int'>
```

### Discriminated Unions (a.k.a. Tagged Unions)

When `Union` is used with multiple submodels, you sometimes know exactly which submodel needs to
be checked and validated and want to enforce this.
Expand Down
8 changes: 7 additions & 1 deletion pydantic/_internal/_known_annotated_metadata.py
Expand Up @@ -35,6 +35,7 @@
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}

UNION_CONSTRAINTS = {'union_mode'}
URL_CONSTRAINTS = {
'max_length',
'allowed_schemes',
Expand Down Expand Up @@ -75,6 +76,8 @@
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
for constraint in UNION_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
for constraint in URL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
for constraint in BOOL_CONSTRAINTS:
Expand Down Expand Up @@ -147,7 +150,10 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]

if schema_type in allowed_schemas:
schema[constraint] = value
if constraint == 'union_mode' and schema_type == 'union':
schema['mode'] = value # type: ignore # schema is UnionSchema
else:
schema[constraint] = value
continue

if constraint == 'allow_inf_nan' and value is False:
Expand Down
8 changes: 7 additions & 1 deletion pydantic/fields.py
Expand Up @@ -13,7 +13,7 @@
import annotated_types
import typing_extensions
from pydantic_core import PydanticUndefined
from typing_extensions import Unpack
from typing_extensions import Literal, Unpack

from . import types
from ._internal import _decorators, _fields, _generics, _internal_dataclass, _repr, _typing_extra, _utils
Expand Down Expand Up @@ -56,6 +56,7 @@ class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
allow_inf_nan: bool | None
max_digits: int | None
decimal_places: int | None
union_mode: Literal['smart', 'left_to_right'] | None
discriminator: str | None
json_schema_extra: dict[str, Any] | typing.Callable[[dict[str, Any]], None] | None
frozen: bool | None
Expand Down Expand Up @@ -161,6 +162,7 @@ class FieldInfo(_repr.Representation):
'allow_inf_nan': None,
'max_digits': None,
'decimal_places': None,
'union_mode': None,
}

def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
Expand Down Expand Up @@ -692,6 +694,7 @@ def Field( # noqa: C901
decimal_places: int | None = _Unset,
min_length: int | None = _Unset,
max_length: int | None = _Unset,
union_mode: Literal['smart', 'left_to_right'] = _Unset,
**extra: Unpack[_EmptyKwargs],
) -> Any:
"""Usage docs: https://docs.pydantic.dev/dev-v2/usage/fields
Expand Down Expand Up @@ -737,6 +740,8 @@ def Field( # noqa: C901
allow_inf_nan: Allow `inf`, `-inf`, `nan`. Only applicable to numbers.
max_digits: Maximum number of allowed digits for strings.
decimal_places: Maximum number of decimal places allowed for numbers.
union_mode: The strategy to apply when validating a union. Can be `smart` (the default), or `left_to_right`.
See [Union Mode](../usage/types/unions.md#union-mode) for details.
extra: Include extra fields used by the JSON schema.
!!! warning Deprecated
Expand Down Expand Up @@ -838,6 +843,7 @@ def Field( # noqa: C901
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
union_mode=union_mode,
)


Expand Down
70 changes: 64 additions & 6 deletions tests/test_types.py
Expand Up @@ -4681,12 +4681,70 @@ class DefaultModel(BaseModel):
assert repr(DefaultModel(v=1).v) == '1'
assert repr(DefaultModel(v='1').v) == "'1'"

# assert DefaultModel.model_json_schema() == {
# 'title': 'DefaultModel',
# 'type': 'object',
# 'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in ('integer', 'boolean', 'string')]}},
# 'required': ['v'],
# }
assert DefaultModel.model_json_schema() == {
'title': 'DefaultModel',
'type': 'object',
'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in ('integer', 'boolean', 'string')]}},
'required': ['v'],
}


def test_default_union_types_left_to_right():
    """Check `Union[int, bool, str]` validation with `union_mode='left_to_right'`.

    Every input below is expected to come back as the integer `1`, and the
    JSON schema is expected to list all three variants under `anyOf`.
    """

    class DefaultModel(BaseModel):
        v: Annotated[Union[int, bool, str], Field(union_mode='left_to_right')]

    # int will coerce everything in left-to-right mode
    assert repr(DefaultModel(v=True).v) == '1'
    assert repr(DefaultModel(v=1).v) == '1'
    assert repr(DefaultModel(v='1').v) == '1'

    # The schema enumerates the variants in declaration order.
    assert DefaultModel.model_json_schema() == {
        'title': 'DefaultModel',
        'type': 'object',
        'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in ('integer', 'boolean', 'string')]}},
        'required': ['v'],
    }


def test_union_enum_int_left_to_right():
    """Compare default and left-to-right validation of `Union[BinaryEnum, int]` for the input `0`."""

    class BinaryEnum(IntEnum):
        ZERO = 0
        ONE = 1

    # Default mode: the plain int variant wins, so the enum member is not returned.
    default_adapter = TypeAdapter(Union[BinaryEnum, int])
    default_result = default_adapter.validate_python(0)
    assert default_result is not BinaryEnum.ZERO

    # Left-to-right mode: the enum is listed first, validates successfully and takes precedence.
    ltr_adapter = TypeAdapter(Annotated[Union[BinaryEnum, int], Field(union_mode='left_to_right')])
    ltr_result = ltr_adapter.validate_python(0)
    assert ltr_result is BinaryEnum.ZERO


def test_union_uuid_str_left_to_right():
    """Compare default and left-to-right validation of `Union[UUID, str]` for JSON and Python input."""
    uuid_text = 'f4fe10b4-e0c8-4232-ba26-4acd491c2414'
    uuid_json = '\"f4fe10b4-e0c8-4232-ba26-4acd491c2414\"'
    expected_uuid = UUID('f4fe10b4-e0c8-4232-ba26-4acd491c2414')

    IdOrSlug = Union[UUID, str]
    smart_adapter = TypeAdapter(IdOrSlug)

    # In smart mode JSON and Python are currently validated differently here:
    # the JSON string is returned as a UUID, the Python str stays a str.
    assert smart_adapter.validate_json(uuid_json) == expected_uuid
    assert smart_adapter.validate_python(uuid_text) == uuid_text

    IdOrSlugLTR = Annotated[Union[UUID, str], Field(union_mode='left_to_right')]
    ltr_adapter = TypeAdapter(IdOrSlugLTR)

    # In left-to-right mode both JSON and Python input are validated as UUID.
    assert ltr_adapter.validate_json(uuid_json) == expected_uuid
    assert ltr_adapter.validate_python(uuid_text) == expected_uuid


def test_default_union_class():
Expand Down

0 comments on commit 84d6f6d

Please sign in to comment.