Skip to content

Commit

Permalink
Enum validator improvements (#9045)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Mar 19, 2024
1 parent 243b8da commit b0dfc54
Showing 1 changed file with 21 additions and 35 deletions.
56 changes: 21 additions & 35 deletions pydantic/_internal/_std_types_schema.py
Expand Up @@ -14,6 +14,7 @@
from enum import Enum
from functools import partial
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from operator import attrgetter
from typing import Any, Callable, Iterable, TypeVar

import typing_extensions
Expand Down Expand Up @@ -63,13 +64,13 @@ def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchem
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
updates = {'title': enum_type.__name__, 'description': description}
updates = {k: v for k, v in updates.items() if v is not None}
js_updates = {'title': enum_type.__name__, 'description': description}
js_updates = {k: v for k, v in js_updates.items() if v is not None}

def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
update_json_schema(original_schema, updates)
update_json_schema(original_schema, js_updates)
return json_schema

if not cases:
Expand All @@ -90,47 +91,32 @@ def to_enum(input_value: Any, /) -> Enum:
try:
return enum_type(input_value)
except ValueError:
# The type: ignore on the next line is to ignore the requirement of LiteralString
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore
raise PydanticCustomError('enum', 'Input should be {expected}', {'expected': expected})

strict_python_schema = core_schema.is_instance_schema(enum_type)

to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
if issubclass(enum_type, int):
# this handles `IntEnum`, and also `Foobar(int, Enum)`
updates['type'] = 'integer'
lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator])
# Disallow float from JSON due to strict mode
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
python_schema=strict_python_schema,
)
js_updates['type'] = 'integer'
lax_schema = core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema())
elif issubclass(enum_type, str):
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
updates['type'] = 'string'
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
python_schema=strict_python_schema,
)
js_updates['type'] = 'string'
lax_schema = core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema())
elif issubclass(enum_type, float):
updates['type'] = 'numeric'
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
python_schema=strict_python_schema,
)
js_updates['type'] = 'numeric'
lax_schema = core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema())
else:
lax = to_enum_validator
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
lax_schema = core_schema.no_info_plain_validator_function(to_enum)

enum_schema = core_schema.lax_or_strict_schema(
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
lax_schema=lax_schema,
strict_schema=core_schema.json_or_python_schema(
json_schema=lax_schema, python_schema=core_schema.is_instance_schema(enum_type)
),
ref=enum_ref,
metadata={'pydantic_js_functions': [get_json_schema]},
)
use_enum_values = config.get('use_enum_values', False)
if use_enum_values:
enum_schema = core_schema.chain_schema(
[enum_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
)
if config.get('use_enum_values', False):
enum_schema = core_schema.no_info_after_validator_function(attrgetter('value'), enum_schema)

return enum_schema

Expand Down

0 comments on commit b0dfc54

Please sign in to comment.