Skip to content

Commit

Permalink
Fix models_json_schema for generic models (#7654)
Browse files Browse the repository at this point in the history
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
Co-authored-by: Hasan Ramezani <hasan.r67@gmail.com>
  • Loading branch information
3 people committed Sep 27, 2023
1 parent 97c0199 commit ea9aa13
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pydantic/_internal/_repr.py
Expand Up @@ -101,7 +101,10 @@ def display_as_type(obj: Any) -> str:
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
return f'Union[{args}]'
elif isinstance(obj, _typing_extra.WithArgsTypes):
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
if typing_extensions.get_origin(obj) == typing_extensions.Literal:
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
else:
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
return f'{obj.__qualname__}[{args}]'
elif isinstance(obj, type):
return obj.__qualname__
Expand Down
4 changes: 4 additions & 0 deletions pydantic/json_schema.py
Expand Up @@ -2180,6 +2180,10 @@ def models_json_schema(
- The second element is a JSON schema containing all definitions referenced in the first returned
element, along with the optional title and description keys.
"""
for cls, _ in models:
if isinstance(cls.__pydantic_validator__, _mock_val_ser.MockValSer):
cls.__pydantic_validator__.rebuild()

instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
inputs = [(m, mode, m.__pydantic_core_schema__) for m, mode in models]
json_schemas_map, definitions = instance.generate_definitions(inputs)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_json_schema.py
Expand Up @@ -5629,6 +5629,46 @@ class SerializationModel(Model):
assert SerializationModel.model_json_schema(mode='validation') == Model.model_json_schema(mode='serialization')


def test_models_json_schema_generics() -> None:
class G(BaseModel, Generic[T]):
foo: T

class M(BaseModel):
foo: Literal['a', 'b']

GLiteral = G[Literal['a', 'b']]

assert models_json_schema(
[
(GLiteral, 'serialization'),
(GLiteral, 'validation'),
(M, 'validation'),
]
) == (
{
(GLiteral, 'serialization'): {'$ref': '#/$defs/G_Literal__a____b___'},
(GLiteral, 'validation'): {'$ref': '#/$defs/G_Literal__a____b___'},
(M, 'validation'): {'$ref': '#/$defs/M'},
},
{
'$defs': {
'G_Literal__a____b___': {
'properties': {'foo': {'enum': ['a', 'b'], 'title': 'Foo', 'type': 'string'}},
'required': ['foo'],
'title': "G[Literal['a', 'b']]",
'type': 'object',
},
'M': {
'properties': {'foo': {'enum': ['a', 'b'], 'title': 'Foo', 'type': 'string'}},
'required': ['foo'],
'title': 'M',
'type': 'object',
},
}
},
)


def test_recursive_non_generic_model() -> None:
class Foo(BaseModel):
maybe_bar: Union[None, 'Bar']
Expand Down

0 comments on commit ea9aa13

Please sign in to comment.