Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix strict application to function-after with use_enum_values #9279

Merged
merged 7 commits into from Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions pydantic/_internal/_known_annotated_metadata.py
Expand Up @@ -195,6 +195,12 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
raise ValueError(f'Unknown constraint {constraint}')
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]

# if it becomes necessary to handle more than one constraint
# in this recursive case with function-after or function-wrap, we should refactor
if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
schema['schema'] = apply_known_metadata(annotation, schema['schema']) # type: ignore # schema is function-after schema
return schema

if schema_type in allowed_schemas:
if constraint == 'union_mode' and schema_type == 'union':
schema['mode'] = value # type: ignore # schema is UnionSchema
Expand Down
16 changes: 16 additions & 0 deletions tests/test_types.py
Expand Up @@ -6658,3 +6658,19 @@ def test_can_serialize_deque_passed_to_sequence() -> None:

assert ta.dump_python(my_dec) == my_dec
assert ta.dump_json(my_dec) == b'[1,2,3]'


def test_strict_enum_with_use_enum_values() -> None:
class SomeEnum(int, Enum):
SOME_KEY = 1

class Foo(BaseModel):
model_config = ConfigDict(strict=False, use_enum_values=True)
foo: Annotated[SomeEnum, Strict(strict=True)]

f = Foo(foo=SomeEnum.SOME_KEY)
assert f.foo == 1

# validation error raised bc foo field uses strict mode
with pytest.raises(ValidationError):
Foo(foo='1')