Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make union case tags affect union error messages #8001

Merged
merged 6 commits into from Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
85 changes: 82 additions & 3 deletions docs/concepts/unions.md
Expand Up @@ -11,7 +11,9 @@ To solve these problems, Pydantic supports three fundamental approaches to valid
2. [smart mode](#smart-mode) - as with "left to right mode" all members are tried, but strict validation is used to try to find the best match
3. [discriminated unions](#discriminated-unions) - only one member of the union is tried, based on a discriminator

## Left to Right Mode
## Union Modes

### Left to Right Mode

!!! note
Because this mode often leads to unexpected validation results, it is not the default in Pydantic >=2; instead, `union_mode='smart'` is the default.
Expand Down Expand Up @@ -72,7 +74,7 @@ print(User(id='456')) # (2)
2. We're in lax mode and the numeric string `'123'` is valid as input to the first member of the union, `int`.
Since that is tried first, we get the surprising result of `id` being an `int` instead of a `str`.

## Smart Mode
### Smart Mode

Because of the surprising side effects of `union_mode='left_to_right'`, in Pydantic >=2 the default mode for `Union` validation is `union_mode='smart'`.

Expand Down Expand Up @@ -513,4 +515,81 @@ assert m == DiscriminatedModel(
assert m.model_dump() == data
```

You can also simplify error messages with a custom error, like this:
You can also simplify error messages by labeling each case with a `Tag`. This is especially useful
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
when you have complex types like those in this example:

```py
from typing import Dict, List, Union

from typing_extensions import Annotated

from pydantic import AfterValidator, Tag, TypeAdapter, ValidationError

DoubledList = Annotated[List[int], AfterValidator(lambda x: x * 2)]
StringsMap = Dict[str, str]


# Without a `Tag` on each union case, the error title and the `loc` entries
# use the auto-generated schema names, which are not nice to look at.
adapter = TypeAdapter(Union[DoubledList, StringsMap])

try:
    adapter.validate_python(['a'])
except ValidationError as exc_info:
    assert (
        '2 validation errors for union[function-after[<lambda>(), list[int]],dict[str,str]]'
        in str(exc_info)
    )

    # the locs are bad here: each one starts with the generated schema name
    assert exc_info.errors() == [
        {
            'input': 'a',
            'loc': ('function-after[<lambda>(), list[int]]', 0),
            'msg': 'Input should be a valid integer, unable to parse string as an integer',
            'type': 'int_parsing',
            'url': 'https://errors.pydantic.dev/2.4/v/int_parsing',
        },
        {
            'input': ['a'],
            'loc': ('dict[str,str]',),
            'msg': 'Input should be a valid dictionary',
            'type': 'dict_type',
            'url': 'https://errors.pydantic.dev/2.4/v/dict_type',
        },
    ]


# Labeling each union case with a `Tag` makes that label appear in the error
# title and as the first element of each `loc`.
tag_adapter = TypeAdapter(
    Union[
        Annotated[DoubledList, Tag('DoubledList')],
        Annotated[StringsMap, Tag('StringsMap')],
    ]
)

try:
    tag_adapter.validate_python(['a'])
except ValidationError as exc_info:
    assert '2 validation errors for union[DoubledList,StringsMap]' in str(
        exc_info
    )

    # the locs are good here: each one starts with the tag we chose
    assert exc_info.errors() == [
        {
            'input': 'a',
            'loc': ('DoubledList', 0),
            'msg': 'Input should be a valid integer, unable to parse string as an integer',
            'type': 'int_parsing',
            'url': 'https://errors.pydantic.dev/2.4/v/int_parsing',
        },
        {
            'input': ['a'],
            'loc': ('StringsMap',),
            'msg': 'Input should be a valid dictionary',
            'type': 'dict_type',
            'url': 'https://errors.pydantic.dev/2.4/v/dict_type',
        },
    ]
```
2 changes: 1 addition & 1 deletion docs/migration.md
Expand Up @@ -521,7 +521,7 @@ In Pydantic V1, the printed result would have been `x=1`, since the value would
In Pydantic V2, we recognize that the value is an instance of one of the cases and short-circuit the standard union validation.

To revert to the non-short-circuiting left-to-right behavior of V1, annotate the union with `Field(union_mode='left_to_right')`.
See [Union Mode](./api/standard_library_types.md#union-mode) for more details.
See [Union Mode](./concepts/unions.md#union-modes) for more details.

#### Required, optional, and nullable fields

Expand Down
18 changes: 13 additions & 5 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -42,7 +42,7 @@
from ..json_schema import JsonSchemaValue
from ..version import version_short
from ..warnings import PydanticDeprecatedSince20
from . import _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra
from . import _core_utils, _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra
from ._config import ConfigWrapper, ConfigWrapperStack
from ._core_metadata import CoreMetadataHandler, build_metadata_dict
from ._core_utils import (
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def json_schema_update_func(schema: CoreSchemaOrField, handler: GetJsonSchemaHan
def _union_schema(self, union_type: Any) -> core_schema.CoreSchema:
"""Generate schema for a Union."""
args = self._get_args_resolving_forward_refs(union_type, required=True)
choices: list[CoreSchema | tuple[CoreSchema, str]] = []
choices: list[CoreSchema] = []
nullable = False
for arg in args:
if arg is None or arg is _typing_extra.NoneType:
Expand All @@ -1042,10 +1042,18 @@ def _union_schema(self, union_type: Any) -> core_schema.CoreSchema:
choices.append(self.generate_schema(arg))

if len(choices) == 1:
first_choice = choices[0]
s = first_choice[0] if isinstance(first_choice, tuple) else first_choice
s = choices[0]
else:
s = core_schema.union_schema(choices)
choices_with_tags: list[CoreSchema | tuple[CoreSchema, str]] = []
for choice in choices:
metadata = choice.get('metadata')
if isinstance(metadata, dict):
tag = metadata.get(_core_utils.TAGGED_UNION_TAG_KEY)
if tag is not None:
choices_with_tags.append((choice, tag))
else:
choices_with_tags.append(choice)
s = core_schema.union_schema(choices_with_tags)

if nullable:
s = core_schema.nullable_schema(s)
Expand Down
6 changes: 4 additions & 2 deletions pydantic/types.py
Expand Up @@ -2442,6 +2442,8 @@ def __getattr__(self, item: str) -> Any:
class Tag:
"""Provides a way to specify the expected tag to use for a case with a callable discriminated union.
Also provides a way to label a union case in error messages.
When using a `CallableDiscriminator`, attach a `Tag` to each case in the `Union` to specify the tag that
should be used to identify that case. For example, in the below example, the `Tag` is used to specify that
if `get_discriminator_value` returns `'apple'`, the input should be validated as an `ApplePie`, and if it
Expand Down Expand Up @@ -2509,7 +2511,7 @@ class ThanksgivingDinner(BaseModel):
Failing to do so will result in a `PydanticUserError` with code
[`callable-discriminator-no-tag`](../errors/usage_errors.md#callable-discriminator-no-tag).
See the [Discriminated Unions](../api/standard_library_types.md#discriminated-unions-aka-tagged-unions)
See the [Discriminated Unions](../concepts/unions.md#discriminated-unions)
docs for more details on how to use `Tag`s.
"""

Expand Down Expand Up @@ -2590,7 +2592,7 @@ class ThanksgivingDinner(BaseModel):
'''
```
See the [Discriminated Unions](../api/standard_library_types.md#discriminated-unions-aka-tagged-unions)
See the [Discriminated Unions](../concepts/unions.md#discriminated-unions)
docs for more details on how to use `CallableDiscriminator`s.
"""

Expand Down
56 changes: 56 additions & 0 deletions tests/test_types.py
Expand Up @@ -47,6 +47,7 @@
UUID3,
UUID4,
UUID5,
AfterValidator,
AwareDatetime,
Base64Bytes,
Base64Str,
Expand Down Expand Up @@ -88,6 +89,7 @@
StrictFloat,
StrictInt,
StrictStr,
Tag,
TypeAdapter,
ValidationError,
conbytes,
Expand Down Expand Up @@ -6028,3 +6030,57 @@ class Model(BaseModel):
value: str

assert Model.model_validate_json(f'{{"value": {number}}}').model_dump() == {'value': expected_str}


def test_union_tags_in_errors():
    """Check that `Tag` labels on union cases are used in validation errors.

    The same invalid input (`['a']`) is validated twice: once against an
    untagged union, where the error title and `loc` entries carry the
    auto-generated schema names, and once against the same union with each
    case wrapped in `Annotated[..., Tag(...)]`, where the tag strings are
    expected in their place.
    """
    DoubledList = Annotated[List[int], AfterValidator(lambda x: x * 2)]
    StringsMap = Dict[str, str]

    adapter = TypeAdapter(Union[DoubledList, StringsMap])

    with pytest.raises(ValidationError) as exc_info:
        adapter.validate_python(['a'])

    # Check the exception's own message (`exc_info.value`), not the string
    # form of pytest's `ExceptionInfo` wrapper object.
    assert '2 validation errors for union[function-after[<lambda>(), list[int]],dict[str,str]]' in str(
        exc_info.value
    )  # yuck
    # the locs are bad here:
    # NOTE(review): the expected URLs embed '2.4' (presumably the pydantic
    # version at the time of writing) -- confirm how version bumps are handled.
    assert exc_info.value.errors() == [
        {
            'input': 'a',
            'loc': ('function-after[<lambda>(), list[int]]', 0),
            'msg': 'Input should be a valid integer, unable to parse string as an integer',
            'type': 'int_parsing',
            'url': 'https://errors.pydantic.dev/2.4/v/int_parsing',
        },
        {
            'input': ['a'],
            'loc': ('dict[str,str]',),
            'msg': 'Input should be a valid dictionary',
            'type': 'dict_type',
            'url': 'https://errors.pydantic.dev/2.4/v/dict_type',
        },
    ]

    tag_adapter = TypeAdapter(
        Union[Annotated[DoubledList, Tag('DoubledList')], Annotated[StringsMap, Tag('StringsMap')]]
    )
    with pytest.raises(ValidationError) as exc_info:
        tag_adapter.validate_python(['a'])

    assert '2 validation errors for union[DoubledList,StringsMap]' in str(exc_info.value)  # nice
    # the locs are good here:
    assert exc_info.value.errors() == [
        {
            'input': 'a',
            'loc': ('DoubledList', 0),
            'msg': 'Input should be a valid integer, unable to parse string as an integer',
            'type': 'int_parsing',
            'url': 'https://errors.pydantic.dev/2.4/v/int_parsing',
        },
        {
            'input': ['a'],
            'loc': ('StringsMap',),
            'msg': 'Input should be a valid dictionary',
            'type': 'dict_type',
            'url': 'https://errors.pydantic.dev/2.4/v/dict_type',
        },
    ]