Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support __get_validators__ #7197

Merged
merged 2 commits into from Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 16 additions & 7 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -587,16 +587,25 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
else:
ref_mode = 'to-def'

schema: CoreSchema
get_schema = getattr(obj, '__get_pydantic_core_schema__', None)
if get_schema is None:
return None

schema: CoreSchema
if len(inspect.signature(get_schema).parameters) == 1:
# (source) -> CoreSchema
schema = get_schema(source)
validators = getattr(obj, '__get_validators__', None)
if validators is None:
return None
warn(
'`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.',
PydanticDeprecatedSince20,
)
schema = core_schema.chain_schema([core_schema.general_plain_validator_function(v) for v in validators()])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know whether general_plain_validator_function is the right choice

else:
schema = get_schema(source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode))
if len(inspect.signature(get_schema).parameters) == 1:
# (source) -> CoreSchema
schema = get_schema(source)
else:
schema = get_schema(
source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode)
)

schema = self._unpack_refs_defs(schema)

Expand Down
58 changes: 57 additions & 1 deletion tests/test_deprecated.py
@@ -1,7 +1,7 @@
import platform
import re
import sys
from datetime import timedelta
from datetime import date, timedelta
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, Iterable, List, Type
Expand Down Expand Up @@ -523,6 +523,62 @@ def __get_validators__(cls) -> Iterable[Any]:
assert ta.json_schema() == {'anyOf': [{'type': 'string'}, {'type': 'number'}]}


def test_v1_get_validators():
class CustomDate(date):
@classmethod
def __get_validators__(cls):
yield cls.validate1
yield cls.validate2

@classmethod
def validate1(cls, v, i):
print(v)

if v.year < 2000:
raise ValueError('Invalid year')
return v

@classmethod
def validate2(cls, v, i):
return date.today().replace(month=1, day=1)

with pytest.warns(
PydanticDeprecatedSince20,
match='^`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.',
):

class Model(BaseModel):
x: CustomDate

with pytest.raises(ValidationError, match='Value error, Invalid year'):
Model(x=date(1999, 1, 1))

m = Model(x=date.today())
assert m.x.day == 1


def test_v1_get_validators_invalid_validator():
class InvalidValidator:
@classmethod
def __get_validators__(cls):
yield cls.has_wrong_arguments

@classmethod
def has_wrong_arguments(cls):
pass

with pytest.warns(
PydanticDeprecatedSince20,
match='^`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.',
):

class InvalidValidatorModel(BaseModel):
x: InvalidValidator

with pytest.raises(TypeError, match='takes 1 positional argument but 3 were given'):
InvalidValidatorModel(x=1)


def test_field_extra_arguments():
m = re.escape(
'Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. '
Expand Down
7 changes: 0 additions & 7 deletions tests/test_edge_cases.py
Expand Up @@ -1504,13 +1504,6 @@ def __init__(self, t1: T1, t2: T2):
self.t1 = t1
self.t2 = t2

@classmethod
def __get_validators__(cls):
def validator(v):
return v

yield validator


@pytest.mark.parametrize(
'type_,expected',
Expand Down