Skip to content

Commit

Permalink
Fix json schema generation for recursive models (#7653)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 27, 2023
1 parent 861cfe3 commit 97c0199
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
15 changes: 15 additions & 0 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -557,6 +557,19 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
self.defs.definitions[model_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(model_ref)

def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
"""Unpack all 'definitions' schemas into `GenerateSchema.defs.definitions`
and return the inner schema.
"""

def get_ref(s: CoreSchema) -> str:
return s['ref'] # type: ignore

if schema['type'] == 'definitions':
self.defs.definitions.update({get_ref(s): s for s in schema['definitions']})
schema = schema['schema']
return schema

def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None:
"""Try to generate schema from either the `__get_pydantic_core_schema__` function or
`__pydantic_core_schema__` property.
Expand Down Expand Up @@ -593,6 +606,8 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode)
)

schema = self._unpack_refs_defs(schema)

ref = get_ref(schema)
if ref:
self.defs.definitions[ref] = self._post_process_generated_schema(schema)
Expand Down
12 changes: 2 additions & 10 deletions tests/test_edge_cases.py
Expand Up @@ -2663,7 +2663,7 @@ class Outer(BaseModel):
b: Annotated[Union[Root1, Root2], Field(discriminator='kind')]

validated = Outer.model_validate({'a': {'kind': '1', 'two': None}, 'b': {'kind': '2', 'one': None}})
assert validated == Outer(a=Root1(Model1(two=None)), b=Root2(Model2(one=None)))
assert validated == Outer(a=Root1(root=Model1(two=None)), b=Root2(root=Model2(one=None)))

assert Outer.model_json_schema() == {
'$defs': {
Expand All @@ -2686,15 +2686,7 @@ class Outer(BaseModel):
'type': 'object',
},
'Root1': {'allOf': [{'$ref': '#/$defs/Model1'}], 'title': 'Root1'},
'Root2': {
'properties': {
'kind': {'const': '2', 'default': '2', 'title': 'Kind'},
'one': {'anyOf': [{'$ref': '#/$defs/Model1'}, {'type': 'null'}]},
},
'required': ['one'],
'title': 'Model2',
'type': 'object',
},
'Root2': {'allOf': [{'$ref': '#/$defs/Model2'}], 'title': 'Root2'},
},
'properties': {
'a': {
Expand Down
29 changes: 29 additions & 0 deletions tests/test_json_schema.py
Expand Up @@ -5627,3 +5627,32 @@ class SerializationModel(Model):
# Ensure the submodels' JSON schemas match the expected mode even when the opposite value is specified:
assert ValidationModel.model_json_schema(mode='serialization') == Model.model_json_schema(mode='validation')
assert SerializationModel.model_json_schema(mode='validation') == Model.model_json_schema(mode='serialization')


def test_recursive_non_generic_model() -> None:
class Foo(BaseModel):
maybe_bar: Union[None, 'Bar']

class Bar(BaseModel):
foo: Foo

# insert_assert(Bar(foo=Foo(maybe_bar=None)).model_dump())
assert Bar.model_validate({'foo': {'maybe_bar': None}}).model_dump() == {'foo': {'maybe_bar': None}}
# insert_assert(Bar.model_json_schema())
assert Bar.model_json_schema() == {
'$defs': {
'Bar': {
'properties': {'foo': {'$ref': '#/$defs/Foo'}},
'required': ['foo'],
'title': 'Bar',
'type': 'object',
},
'Foo': {
'properties': {'maybe_bar': {'anyOf': [{'$ref': '#/$defs/Bar'}, {'type': 'null'}]}},
'required': ['maybe_bar'],
'title': 'Foo',
'type': 'object',
},
},
'allOf': [{'$ref': '#/$defs/Bar'}],
}

0 comments on commit 97c0199

Please sign in to comment.