Skip to content

Commit

Permalink
Add source_type
Browse files Browse the repository at this point in the history
  • Loading branch information
hramezani committed Oct 20, 2023
1 parent cbd8a7d commit 62ec577
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pydantic/_internal/_dataclasses.py
Expand Up @@ -170,7 +170,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -

cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

Expand Down
1 change: 1 addition & 0 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -502,6 +502,7 @@ def complete_model_class(

cls.__pydantic_validator__ = create_schema_validator(
schema,
cls,
cls_module if cls_module else cls.__module__,
cls.__qualname__,
'create_model' if cls_module else 'BaseModel',
Expand Down
18 changes: 16 additions & 2 deletions pydantic/_internal/_validate_call.py
Expand Up @@ -46,12 +46,14 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
self.__signature__ = inspect.signature(function)
if isinstance(function, partial):
func = function.func
source_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__annotations__ = func.__annotations__
self.__module__ = func.__module__
self.__doc__ = func.__doc__
else:
source_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__annotations__ = function.__annotations__
Expand All @@ -66,7 +68,13 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
core_config = config_wrapper.core_config(self)

self.__pydantic_validator__ = create_schema_validator(
schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings
schema,
source_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)

if self._validate_return:
Expand All @@ -79,7 +87,13 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(
schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings
schema,
source_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(self.raw_function):

Expand Down
2 changes: 2 additions & 0 deletions pydantic/plugin/__init__.py
Expand Up @@ -27,6 +27,7 @@ class PydanticPluginProtocol(Protocol):
def new_schema_validator(
self,
schema: CoreSchema,
source_type: str,
type_path: str,
item_type: str,
config: CoreConfig | None,
Expand All @@ -42,6 +43,7 @@ def new_schema_validator(
Args:
schema: The schema to validate against.
source_type: The item to validate against.
type_path: The path of item to validate against.
item_type: The type of item to validate against.
config: The config to use for validation.
Expand Down
6 changes: 4 additions & 2 deletions pydantic/plugin/_schema_validator.py
Expand Up @@ -23,6 +23,7 @@ def build_type_path(module: str, name: str) -> str:

def create_schema_validator(
schema: CoreSchema,
source_type: Any,
module: str,
type_name: str,
item_type: str,
Expand All @@ -39,7 +40,7 @@ def create_schema_validator(
plugins = get_plugins()
if plugins:
type_path = build_type_path(module, type_name)
return PluggableSchemaValidator(schema, type_path, item_type, config, plugins, plugin_settings or {}) # type: ignore
return PluggableSchemaValidator(schema, source_type, type_path, item_type, config, plugins, plugin_settings or {}) # type: ignore
else:
return SchemaValidator(schema, config)

Expand All @@ -52,6 +53,7 @@ class PluggableSchemaValidator:
def __init__(
self,
schema: CoreSchema,
source_type: Any,
type_path: str,
item_type: str,
config: CoreConfig | None,
Expand All @@ -64,7 +66,7 @@ def __init__(
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
p, j, s = plugin.new_schema_validator(schema, type_path, item_type, config, plugin_settings)
p, j, s = plugin.new_schema_validator(schema, source_type, type_path, item_type, config, plugin_settings)
if p is not None:
python_event_handlers.append(p)
if j is not None:
Expand Down
2 changes: 1 addition & 1 deletion pydantic/type_adapter.py
Expand Up @@ -254,7 +254,7 @@ def __init__(
if module is None:
f = sys._getframe(1)
module = f.f_globals['__name__']
validator = create_schema_validator(core_schema, module, type, 'type_adapter', core_config, config_wrapper.plugin_settings) # type: ignore
validator = create_schema_validator(core_schema, type, module, str(type), 'type_adapter', core_config, config_wrapper.plugin_settings) # type: ignore

serializer: SchemaSerializer
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/plugin/example_plugin.py
Expand Up @@ -24,7 +24,7 @@ def on_error(self, error) -> None:


class Plugin:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
return ValidatePythonHandler(), None, None


Expand Down
46 changes: 28 additions & 18 deletions tests/test_plugins.py
Expand Up @@ -44,9 +44,10 @@ def on_success(self, result: Any) -> None:
assert isinstance(result, Model)

class CustomPlugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
assert source_type.__name__ == 'Model'
assert type_path == 'tests.test_plugins:test_on_validate_json_on_success.<locals>.Model'
assert item_type == 'BaseModel'
return None, CustomOnValidateJson(), None
Expand Down Expand Up @@ -88,7 +89,7 @@ def on_error(self, error: ValidationError) -> None:
]

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
return None, CustomOnValidateJson(), None
Expand Down Expand Up @@ -124,9 +125,10 @@ def on_success(self, result: Any) -> None:
assert isinstance(result, Model)

class Plugin:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
assert source_type.__name__ == 'Model'
assert item_type == 'BaseModel'
return CustomOnValidatePython(), None, None

Expand Down Expand Up @@ -168,9 +170,10 @@ def on_error(self, error: ValidationError) -> None:
]

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
assert source_type.__name__ == 'Model'
assert item_type == 'BaseModel'
return CustomOnValidatePython(), None, None

Expand Down Expand Up @@ -210,7 +213,7 @@ def on_exception(self, exception: Exception) -> None:
stack.pop()

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand Down Expand Up @@ -273,7 +276,7 @@ def on_error(self, error: ValidationError) -> None:
log.append(f'strings error error={error}')

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
return Python(), Json(), Strings()

plugin = Plugin()
Expand Down Expand Up @@ -306,7 +309,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert source_type.__name__ == 'Bar'
assert type_path == 'tests.test_plugins:test_plugin_path_dataclass.<locals>.Bar'
assert item_type == 'dataclass'
return CustomOnValidatePython(), None, None
Expand All @@ -324,7 +328,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert str(source_type) == 'typing.List[str]'
assert type_path == 'tests.test_plugins:typing.List[str]'
assert item_type == 'type_adapter'
return CustomOnValidatePython(), None, None
Expand All @@ -339,7 +344,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert str(source_type) == 'typing.List[str]'
assert type_path == 'provided_module_by_type_adapter:typing.List[str]'
assert item_type == 'type_adapter'
return CustomOnValidatePython(), None, None
Expand All @@ -354,7 +360,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin1:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert source_type.__name__ == 'foo'
assert type_path == 'tests.test_plugins:test_plugin_path_validate_call.<locals>.foo'
assert item_type == 'validate_call'
return CustomOnValidatePython(), None, None
Expand All @@ -367,7 +374,8 @@ def foo(a: int):
return a

class Plugin2:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert source_type.__name__ == 'my_wrapped_function'
assert (
type_path == 'tests.test_plugins:partial(test_plugin_path_validate_call.<locals>.my_wrapped_function)'
)
Expand All @@ -389,7 +397,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
assert source_type.__name__ == 'FooModel'
assert list(source_type.model_fields.keys()) == ['foo', 'bar']
assert type_path == 'tests.test_plugins:FooModel'
assert item_type == 'create_model'
return CustomOnValidatePython(), None, None
Expand All @@ -406,25 +416,25 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
paths.append((type_path, item_type))
def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings):
paths.append((source_type.__name__, type_path, item_type))
return CustomOnValidatePython(), None, None

plugin = Plugin()
with install_plugin(plugin):

def foo():
class Model(BaseModel):
class Model1(BaseModel):
pass

def bar():
class Model(BaseModel):
class Model2(BaseModel):
pass

foo()
bar()

assert paths == [
('tests.test_plugins:test_plugin_path_complex.<locals>.foo.<locals>.Model', 'BaseModel'),
('tests.test_plugins:test_plugin_path_complex.<locals>.bar.<locals>.Model', 'BaseModel'),
('Model1', 'tests.test_plugins:test_plugin_path_complex.<locals>.foo.<locals>.Model1', 'BaseModel'),
('Model2', 'tests.test_plugins:test_plugin_path_complex.<locals>.bar.<locals>.Model2', 'BaseModel'),
]

0 comments on commit 62ec577

Please sign in to comment.