Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix json schema generation for recursive models #7653

Merged
merged 3 commits into from Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -557,6 +557,19 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
self.defs.definitions[model_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(model_ref)

def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
"""Unpack all 'definitions' schemas into `GenerateSchema.defs.definitions`
and return the inner schema.
"""

def get_ref(s: CoreSchema) -> str:
return s['ref'] # type: ignore

if schema['type'] == 'definitions':
self.defs.definitions.update({get_ref(s): s for s in schema['definitions']})
schema = schema['schema']
return schema

def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None:
"""Try to generate schema from either the `__get_pydantic_core_schema__` function or
`__pydantic_core_schema__` property.
Expand Down Expand Up @@ -593,6 +606,8 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode)
)

schema = self._unpack_refs_defs(schema)

ref = get_ref(schema)
if ref:
self.defs.definitions[ref] = self._post_process_generated_schema(schema)
Expand Down
12 changes: 2 additions & 10 deletions tests/test_edge_cases.py
Expand Up @@ -2663,7 +2663,7 @@ class Outer(BaseModel):
b: Annotated[Union[Root1, Root2], Field(discriminator='kind')]

validated = Outer.model_validate({'a': {'kind': '1', 'two': None}, 'b': {'kind': '2', 'one': None}})
assert validated == Outer(a=Root1(Model1(two=None)), b=Root2(Model2(one=None)))
assert validated == Outer(a=Root1(root=Model1(two=None)), b=Root2(root=Model2(one=None)))

assert Outer.model_json_schema() == {
'$defs': {
Expand All @@ -2686,15 +2686,7 @@ class Outer(BaseModel):
'type': 'object',
},
'Root1': {'allOf': [{'$ref': '#/$defs/Model1'}], 'title': 'Root1'},
'Root2': {
'properties': {
'kind': {'const': '2', 'default': '2', 'title': 'Kind'},
'one': {'anyOf': [{'$ref': '#/$defs/Model1'}, {'type': 'null'}]},
},
'required': ['one'],
'title': 'Model2',
'type': 'object',
},
'Root2': {'allOf': [{'$ref': '#/$defs/Model2'}], 'title': 'Root2'},
},
'properties': {
'a': {
Expand Down
29 changes: 29 additions & 0 deletions tests/test_json_schema.py
Expand Up @@ -5627,3 +5627,32 @@ class SerializationModel(Model):
# Ensure the submodels' JSON schemas match the expected mode even when the opposite value is specified:
assert ValidationModel.model_json_schema(mode='serialization') == Model.model_json_schema(mode='validation')
assert SerializationModel.model_json_schema(mode='validation') == Model.model_json_schema(mode='serialization')


def test_recursive_non_generic_model() -> None:
class Foo(BaseModel):
maybe_bar: Union[None, 'Bar']

class Bar(BaseModel):
foo: Foo

# insert_assert(Bar(foo=Foo(maybe_bar=None)).model_dump())
assert Bar.model_validate({'foo': {'maybe_bar': None}}).model_dump() == {'foo': {'maybe_bar': None}}
# insert_assert(Bar.model_json_schema())
assert Bar.model_json_schema() == {
'$defs': {
'Bar': {
'properties': {'foo': {'$ref': '#/$defs/Foo'}},
'required': ['foo'],
'title': 'Bar',
'type': 'object',
},
'Foo': {
'properties': {'maybe_bar': {'anyOf': [{'$ref': '#/$defs/Bar'}, {'type': 'null'}]}},
'required': ['maybe_bar'],
'title': 'Foo',
'type': 'object',
},
},
'allOf': [{'$ref': '#/$defs/Bar'}],
}