Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make union case tags affect union error messages #8001

Merged
merged 6 commits into from Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 13 additions & 5 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -42,7 +42,7 @@
from ..json_schema import JsonSchemaValue
from ..version import version_short
from ..warnings import PydanticDeprecatedSince20
from . import _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra
from . import _core_utils, _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra
from ._config import ConfigWrapper, ConfigWrapperStack
from ._core_metadata import CoreMetadataHandler, build_metadata_dict
from ._core_utils import (
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def json_schema_update_func(schema: CoreSchemaOrField, handler: GetJsonSchemaHan
def _union_schema(self, union_type: Any) -> core_schema.CoreSchema:
"""Generate schema for a Union."""
args = self._get_args_resolving_forward_refs(union_type, required=True)
choices: list[CoreSchema | tuple[CoreSchema, str]] = []
choices: list[CoreSchema] = []
nullable = False
for arg in args:
if arg is None or arg is _typing_extra.NoneType:
Expand All @@ -1042,10 +1042,18 @@ def _union_schema(self, union_type: Any) -> core_schema.CoreSchema:
choices.append(self.generate_schema(arg))

if len(choices) == 1:
first_choice = choices[0]
s = first_choice[0] if isinstance(first_choice, tuple) else first_choice
s = choices[0]
else:
s = core_schema.union_schema(choices)
choices_with_tags: list[CoreSchema | tuple[CoreSchema, str]] = []
for choice in choices:
metadata = choice.get('metadata')
if isinstance(metadata, dict):
tag = metadata.get(_core_utils.TAGGED_UNION_TAG_KEY)
if tag is not None:
choices_with_tags.append((choice, tag))
else:
choices_with_tags.append(choice)
s = core_schema.union_schema(choices_with_tags)

if nullable:
s = core_schema.nullable_schema(s)
Expand Down
2 changes: 2 additions & 0 deletions pydantic/types.py
Expand Up @@ -2444,6 +2444,8 @@ def __getattr__(self, item: str) -> Any:
class Tag:
"""Provides a way to specify the expected tag to use for a case with a callable discriminated union.

Also provides a way to label a union case in error messages.

When using a `CallableDiscriminator`, attach a `Tag` to each case in the `Union` to specify the tag that
should be used to identify that case. For example, in the below example, the `Tag` is used to specify that
if `get_discriminator_value` returns `'apple'`, the input should be validated as an `ApplePie`, and if it
Expand Down
57 changes: 57 additions & 0 deletions tests/test_types.py
Expand Up @@ -15,6 +15,7 @@
from numbers import Number
from pathlib import Path
from typing import (
Annotated,
Any,
Callable,
Counter,
Expand Down Expand Up @@ -47,6 +48,7 @@
UUID3,
UUID4,
UUID5,
AfterValidator,
AwareDatetime,
Base64Bytes,
Base64Str,
Expand Down Expand Up @@ -88,6 +90,7 @@
StrictFloat,
StrictInt,
StrictStr,
Tag,
TypeAdapter,
ValidationError,
conbytes,
Expand Down Expand Up @@ -6025,3 +6028,57 @@ class Model(BaseModel):
value: str

assert Model.model_validate_json(f'{{"value": {number}}}').model_dump() == {'value': expected_str}


def test_union_tags_in_errors():
DoubledList = Annotated[list[int], AfterValidator(lambda x: x * 2)]
StringsMap = dict[str, str]

adapter = TypeAdapter(Union[DoubledList, StringsMap])

with pytest.raises(ValidationError) as exc_info:
adapter.validate_python(['a'])

assert '2 validation errors for union[function-after[<lambda>(), list[int]],dict[str,str]]' in str(exc_info) # yuck
# the loc's are bad here:
assert exc_info.value.errors() == [
{
'input': 'a',
'loc': ('function-after[<lambda>(), list[int]]', 0),
'msg': 'Input should be a valid integer, unable to parse string as an ' 'integer',
'type': 'int_parsing',
'url': 'https://errors.pydantic.dev/2.4/v/int_parsing',
},
{
'input': ['a'],
'loc': ('dict[str,str]',),
'msg': 'Input should be a valid dictionary',
'type': 'dict_type',
'url': 'https://errors.pydantic.dev/2.4/v/dict_type',
},
]

tag_adapter = TypeAdapter(
Union[Annotated[DoubledList, Tag('DoubledList')], Annotated[StringsMap, Tag('StringsMap')]]
)
with pytest.raises(ValidationError) as exc_info:
tag_adapter.validate_python(['a'])

assert '2 validation errors for union[DoubledList,StringsMap]' in str(exc_info) # nice
# the loc's are good here:
assert exc_info.value.errors() == [
{
'input': 'a',
'loc': ('DoubledList', 0),
'msg': 'Input should be a valid integer, unable to parse string as an ' 'integer',
'type': 'int_parsing',
'url': 'https://errors.pydantic.dev/2.4/v/int_parsing',
},
{
'input': ['a'],
'loc': ('StringsMap',),
'msg': 'Input should be a valid dictionary',
'type': 'dict_type',
'url': 'https://errors.pydantic.dev/2.4/v/dict_type',
},
]