Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add union_mode='left_to_right' #7151

Merged
merged 1 commit into from Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/migration.md
Expand Up @@ -506,6 +506,9 @@ print(Model(x='1'))
In Pydantic V1, the printed result would have been `x=1`, since the value would pass validation as an `int`.
In Pydantic V2, we recognize that the value is an instance of one of the cases and short-circuit the standard union validation.

To revert to the non-short-circuiting left-to-right behavior of V1, annotate the union with `Field(union_mode='left_to_right')`.
See [Union Mode](./usage/types/unions.md#union-mode) for more details.
Comment on lines +509 to +510
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not an expert on V1's union validation, this statement may not be true. I thought it might be worth calling out this or something similar here in the migration guide.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's right


#### Required, optional, and nullable fields

Pydantic V2 changes some of the logic for specifying whether a field annotated as `Optional` is required
Expand Down
56 changes: 55 additions & 1 deletion docs/usage/types/unions.md
Expand Up @@ -43,7 +43,61 @@ print(user_03_uuid.int)

See more details in [Required fields](../models.md#required-fields).

#### Discriminated Unions (a.k.a. Tagged Unions)
#### Union Mode

By default `Union` validation will try to return the variant which is the best match for the input.

Consider for example the case of `Union[int, str]`. When [`strict` mode](../strict_mode.md) is not enabled
then `int` fields will accept `str` inputs. In the example below, the `id` field (which is `Union[int, str]`)
will accept the string `'123'` as an input, and preserve it as a string:

```py
from typing import Union

from pydantic import BaseModel


class User(BaseModel):
id: Union[int, str]
age: int


print(User(id='123', age='45'))
#> id='123' age=45

print(type(User(id='123', age='45').id))
#> <class 'str'>
```

This is known as `'smart'` mode for `Union` validation.

At present only one other `Union` validation mode exists, called `'left_to_right'` validation. In this mode
variants are attempted from left to right and the first successful validation is accepted as input.

Consider the same example, this time with `union_mode='left_to_right'` set as a [`Field`](../fields.md)
parameter on `id`. With this validation mode, the `int` variant will coerce strings of digits into `int`
values:

```py
from typing import Union

from pydantic import BaseModel, Field


class User(BaseModel):
id: Union[int, str] = Field(..., union_mode='left_to_right')
age: int


print(User(id='123', age='45'))
#> id=123 age=45


print(type(User(id='123', age='45').id))
#> <class 'int'>
```

### Discriminated Unions (a.k.a. Tagged Unions)

When `Union` is used with multiple submodels, you sometimes know exactly which submodel needs to
be checked and validated and want to enforce this.
Expand Down
8 changes: 7 additions & 1 deletion pydantic/_internal/_known_annotated_metadata.py
Expand Up @@ -35,6 +35,7 @@
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}

UNION_CONSTRAINTS = {'union_mode'}
URL_CONSTRAINTS = {
'max_length',
'allowed_schemes',
Expand Down Expand Up @@ -75,6 +76,8 @@
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
for constraint in UNION_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
for constraint in URL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
for constraint in BOOL_CONSTRAINTS:
Expand Down Expand Up @@ -147,7 +150,10 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]

if schema_type in allowed_schemas:
schema[constraint] = value
if constraint == 'union_mode' and schema_type == 'union':
schema['mode'] = value # type: ignore # schema is UnionSchema
else:
schema[constraint] = value
continue

if constraint == 'allow_inf_nan' and value is False:
Expand Down
8 changes: 7 additions & 1 deletion pydantic/fields.py
Expand Up @@ -13,7 +13,7 @@
import annotated_types
import typing_extensions
from pydantic_core import PydanticUndefined
from typing_extensions import Unpack
from typing_extensions import Literal, Unpack

from . import types
from ._internal import _decorators, _fields, _generics, _internal_dataclass, _repr, _typing_extra, _utils
Expand Down Expand Up @@ -56,6 +56,7 @@ class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
allow_inf_nan: bool | None
max_digits: int | None
decimal_places: int | None
union_mode: Literal['smart', 'left_to_right'] | None
discriminator: str | None
json_schema_extra: dict[str, Any] | typing.Callable[[dict[str, Any]], None] | None
frozen: bool | None
Expand Down Expand Up @@ -161,6 +162,7 @@ class FieldInfo(_repr.Representation):
'allow_inf_nan': None,
'max_digits': None,
'decimal_places': None,
'union_mode': None,
}

def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
Expand Down Expand Up @@ -692,6 +694,7 @@ def Field( # noqa: C901
decimal_places: int | None = _Unset,
min_length: int | None = _Unset,
max_length: int | None = _Unset,
union_mode: Literal['smart', 'left_to_right'] = _Unset,
**extra: Unpack[_EmptyKwargs],
) -> Any:
"""Usage docs: https://docs.pydantic.dev/dev-v2/usage/fields
Expand Down Expand Up @@ -737,6 +740,8 @@ def Field( # noqa: C901
allow_inf_nan: Allow `inf`, `-inf`, `nan`. Only applicable to numbers.
max_digits: Maximum number of allow digits for strings.
decimal_places: Maximum number of decimal places allowed for numbers.
union_mode: The strategy to apply when validating a union. Can be `smart` (the default), or `left_to_right`.
See [Union Mode](../usage/types/unions.md#union-mode) for details.
extra: Include extra fields used by the JSON schema.

!!! warning Deprecated
Expand Down Expand Up @@ -838,6 +843,7 @@ def Field( # noqa: C901
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
union_mode=union_mode,
)


Expand Down
70 changes: 64 additions & 6 deletions tests/test_types.py
Expand Up @@ -4681,12 +4681,70 @@ class DefaultModel(BaseModel):
assert repr(DefaultModel(v=1).v) == '1'
assert repr(DefaultModel(v='1').v) == "'1'"

# assert DefaultModel.model_json_schema() == {
# 'title': 'DefaultModel',
# 'type': 'object',
# 'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in ('integer', 'boolean', 'string')]}},
# 'required': ['v'],
# }
assert DefaultModel.model_json_schema() == {
'title': 'DefaultModel',
'type': 'object',
'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in ('integer', 'boolean', 'string')]}},
'required': ['v'],
}


def test_default_union_types_left_to_right():
class DefaultModel(BaseModel):
v: Annotated[Union[int, bool, str], Field(union_mode='left_to_right')]

print(DefaultModel.__pydantic_core_schema__)

# int will coerce everything in left-to-right mode
assert repr(DefaultModel(v=True).v) == '1'
assert repr(DefaultModel(v=1).v) == '1'
assert repr(DefaultModel(v='1').v) == '1'

assert DefaultModel.model_json_schema() == {
'title': 'DefaultModel',
'type': 'object',
'properties': {'v': {'title': 'V', 'anyOf': [{'type': t} for t in ('integer', 'boolean', 'string')]}},
'required': ['v'],
}


def test_union_enum_int_left_to_right():
class BinaryEnum(IntEnum):
ZERO = 0
ONE = 1

# int will win over enum in this case
assert TypeAdapter(Union[BinaryEnum, int]).validate_python(0) is not BinaryEnum.ZERO

# in left to right mode, enum will validate successfully and take precedence
assert (
TypeAdapter(Annotated[Union[BinaryEnum, int], Field(union_mode='left_to_right')]).validate_python(0)
is BinaryEnum.ZERO
)


def test_union_uuid_str_left_to_right():
IdOrSlug = Union[UUID, str]

# in smart mode JSON and python are currently validated differently in this
# case, because in Python this is a str but in JSON a str is also a UUID
assert TypeAdapter(IdOrSlug).validate_json('\"f4fe10b4-e0c8-4232-ba26-4acd491c2414\"') == UUID(
'f4fe10b4-e0c8-4232-ba26-4acd491c2414'
)
assert (
TypeAdapter(IdOrSlug).validate_python('f4fe10b4-e0c8-4232-ba26-4acd491c2414')
== 'f4fe10b4-e0c8-4232-ba26-4acd491c2414'
)

IdOrSlugLTR = Annotated[Union[UUID, str], Field(union_mode='left_to_right')]

# in left to right mode both JSON and python are validated as UUID
assert TypeAdapter(IdOrSlugLTR).validate_json('\"f4fe10b4-e0c8-4232-ba26-4acd491c2414\"') == UUID(
'f4fe10b4-e0c8-4232-ba26-4acd491c2414'
)
assert TypeAdapter(IdOrSlugLTR).validate_python('f4fe10b4-e0c8-4232-ba26-4acd491c2414') == UUID(
'f4fe10b4-e0c8-4232-ba26-4acd491c2414'
)


def test_default_union_class():
Expand Down