Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it harder to hit collisions with json schema defrefs #6566

Merged
merged 8 commits into from Jul 26, 2023
10 changes: 6 additions & 4 deletions pydantic/json_schema.py
Expand Up @@ -1748,7 +1748,9 @@ def normalize_name(self, name: str) -> str:
Returns:
The normalized name.
"""
return re.sub(r'[^a-zA-Z0-9.\-_]', '_', name).replace('.', '__')
normalized = re.sub(r'[^a-zA-Z0-9.\-_]', '_', name)
normalized = re.sub('(__+)', r'_\1', normalized) # convert any double-or-more underscores to be triple-or-more
return normalized.replace('.', '__') # use double underscores where periods would have been

def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef:
"""Override this method to change the way that definitions keys are generated from a core reference.
Expand All @@ -1775,17 +1777,17 @@ def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef:
# be generated for any other core_ref. Currently, this should be the case because we include
# the id of the source type in the core_ref
name = DefsRef(self.normalize_name(short_ref))
name_mode = DefsRef(self.normalize_name(short_ref + mode_title))
name_mode = DefsRef(self.normalize_name(short_ref) + f'__{mode_title}')
module_qualname = DefsRef(self.normalize_name(core_ref_no_id))
module_qualname_mode = DefsRef(module_qualname + mode_title)
module_qualname_mode = DefsRef(f'{module_qualname}__{mode_title}')
module_qualname_id = DefsRef(self.normalize_name(core_ref))
occurrence_index = self._collision_index.get(module_qualname_id)
if occurrence_index is None:
self._collision_counter[module_qualname] += 1
occurrence_index = self._collision_index[module_qualname_id] = self._collision_counter[module_qualname]

module_qualname_occurrence = DefsRef(f'{module_qualname}__{occurrence_index}')
module_qualname_occurrence_mode = DefsRef(f'{module_qualname}{mode_title}__{occurrence_index}')
module_qualname_occurrence_mode = DefsRef(f'{module_qualname_mode}__{occurrence_index}')

self._prioritized_defsref_choices[module_qualname_occurrence_mode] = [
name,
Expand Down
194 changes: 178 additions & 16 deletions tests/test_json_schema.py
Expand Up @@ -49,6 +49,7 @@
ValidationError,
WithJsonSchema,
computed_field,
field_serializer,
field_validator,
)
from pydantic._internal._core_metadata import CoreMetadataHandler, build_metadata_dict
Expand Down Expand Up @@ -2667,16 +2668,16 @@ class NestedModel(BaseModel):
)
model_names = set(schema['$defs'].keys())
expected_model_names = {
'ModelOneInput',
'ModelOneOutput',
'ModelTwoInput',
'ModelTwoOutput',
f'{module.__name__}__ModelOne__NestedModelInput',
f'{module.__name__}__ModelOne__NestedModelOutput',
f'{module.__name__}__ModelTwo__NestedModelInput',
f'{module.__name__}__ModelTwo__NestedModelOutput',
f'{module.__name__}__NestedModelInput',
f'{module.__name__}__NestedModelOutput',
'ModelOne__Input',
'ModelOne__Output',
'ModelTwo__Input',
'ModelTwo__Output',
f'{module.__name__}__ModelOne__NestedModel__Input',
f'{module.__name__}__ModelOne__NestedModel__Output',
f'{module.__name__}__ModelTwo__NestedModel__Input',
f'{module.__name__}__ModelTwo__NestedModel__Output',
f'{module.__name__}__NestedModel__Input',
f'{module.__name__}__NestedModel__Output',
}
assert model_names == expected_model_names

Expand Down Expand Up @@ -2733,6 +2734,167 @@ class Model(BaseModel):
}


def test_mode_name_causes_no_conflict():
    """Check that model names ending in ``Input``/``Output`` keep their own refs.

    Three sibling models are named ``Organization``, ``OrganizationInput`` and
    ``OrganizationOutput``. The test expects every one of them to appear under
    its plain class name in ``$defs`` for both the validation and the
    serialization schema of the enclosing model.
    """

    class Organization(BaseModel):
        pass

    class OrganizationInput(BaseModel):
        pass

    class OrganizationOutput(BaseModel):
        pass

    class Model(BaseModel):
        # Ensure the validation and serialization schemas are different:
        x: Organization = Field(validation_alias='x_validation', serialization_alias='x_serialization')
        y: OrganizationInput
        z: OrganizationOutput

    def empty_object(title):
        # JSON schema expected for a model that declares no fields.
        return {'properties': {}, 'title': title, 'type': 'object'}

    def expected_schema(x_key):
        # Only the property name used for field ``x`` differs between the two modes.
        return {
            '$defs': {
                'Organization': empty_object('Organization'),
                'OrganizationInput': empty_object('OrganizationInput'),
                'OrganizationOutput': empty_object('OrganizationOutput'),
            },
            'properties': {
                x_key: {'$ref': '#/$defs/Organization'},
                'y': {'$ref': '#/$defs/OrganizationInput'},
                'z': {'$ref': '#/$defs/OrganizationOutput'},
            },
            'required': [x_key, 'y', 'z'],
            'title': 'Model',
            'type': 'object',
        }

    assert Model.model_json_schema(mode='validation') == expected_schema('x_validation')
    assert Model.model_json_schema(mode='serialization') == expected_schema('x_serialization')


def test_ref_conflict_resolution_without_mode_difference():
    """Check ref naming when a model is requested in both modes with no serializer.

    ``Organization`` is requested for validation and serialization, next to a
    separate ``OrganizationInput`` model. The test expects both ``Organization``
    entries to point at one unsuffixed ``Organization`` definition, and
    ``OrganizationInput`` to keep its own name.
    """

    class OrganizationInput(BaseModel):
        pass

    class Organization(BaseModel):
        x: int

    requested = [
        (Organization, 'validation', Organization.__pydantic_core_schema__),
        (Organization, 'serialization', Organization.__pydantic_core_schema__),
        (OrganizationInput, 'validation', OrganizationInput.__pydantic_core_schema__),
    ]
    refs_by_key, definitions = GenerateJsonSchema().generate_definitions(requested)

    expected_refs = {
        (Organization, 'serialization'): {'$ref': '#/$defs/Organization'},
        (Organization, 'validation'): {'$ref': '#/$defs/Organization'},
        (OrganizationInput, 'validation'): {'$ref': '#/$defs/OrganizationInput'},
    }
    assert refs_by_key == expected_refs

    expected_definitions = {
        'OrganizationInput': {'properties': {}, 'title': 'OrganizationInput', 'type': 'object'},
        'Organization': {
            'properties': {'x': {'title': 'X', 'type': 'integer'}},
            'required': ['x'],
            'title': 'Organization',
            'type': 'object',
        },
    }
    assert definitions == expected_definitions


def test_ref_conflict_resolution_with_mode_difference():
    """Check ref naming when a model's validation and serialization schemas differ.

    ``Organization.x`` is an ``int`` that a field serializer turns into a ``str``.
    The test expects two definitions, ``Organization__Input`` and
    ``Organization__Output``, while the unrelated ``OrganizationInput`` model
    keeps its plain name.
    """

    class OrganizationInput(BaseModel):
        pass

    class Organization(BaseModel):
        x: int

        @field_serializer('x')
        def serialize_x(self, v: int) -> str:
            return str(v)

    def organization_schema(x_type):
        # Expected ``Organization`` schema; only the JSON type of ``x`` varies per mode.
        return {
            'properties': {'x': {'title': 'X', 'type': x_type}},
            'required': ['x'],
            'title': 'Organization',
            'type': 'object',
        }

    requested = [
        (Organization, 'validation', Organization.__pydantic_core_schema__),
        (Organization, 'serialization', Organization.__pydantic_core_schema__),
        (OrganizationInput, 'validation', OrganizationInput.__pydantic_core_schema__),
    ]
    refs_by_key, definitions = GenerateJsonSchema().generate_definitions(requested)

    expected_refs = {
        (Organization, 'serialization'): {'$ref': '#/$defs/Organization__Output'},
        (Organization, 'validation'): {'$ref': '#/$defs/Organization__Input'},
        (OrganizationInput, 'validation'): {'$ref': '#/$defs/OrganizationInput'},
    }
    assert refs_by_key == expected_refs

    expected_definitions = {
        'OrganizationInput': {'properties': {}, 'title': 'OrganizationInput', 'type': 'object'},
        'Organization__Input': organization_schema('integer'),
        'Organization__Output': organization_schema('string'),
    }
    assert definitions == expected_definitions


def test_conflicting_names():
    """Check ref naming when a class name matches a mode-suffixed ref.

    A model literally named ``Organization__Input`` sits next to an
    ``Organization`` model whose validation and serialization schemas differ.
    The test expects ``Organization`` to take ``Organization__Input`` and
    ``Organization__Output``, and the ``Organization__Input`` class to be
    emitted under ``Organization___Input`` (three underscores).
    """

    class Organization__Input(BaseModel):
        pass

    class Organization(BaseModel):
        x: int

        @field_serializer('x')
        def serialize_x(self, v: int) -> str:
            return str(v)

    def organization_schema(x_type):
        # Expected ``Organization`` schema; only the JSON type of ``x`` varies per mode.
        return {
            'properties': {'x': {'title': 'X', 'type': x_type}},
            'required': ['x'],
            'title': 'Organization',
            'type': 'object',
        }

    requested = [
        (Organization, 'validation', Organization.__pydantic_core_schema__),
        (Organization, 'serialization', Organization.__pydantic_core_schema__),
        (Organization__Input, 'validation', Organization__Input.__pydantic_core_schema__),
    ]
    refs_by_key, definitions = GenerateJsonSchema().generate_definitions(requested)

    expected_refs = {
        (Organization, 'serialization'): {'$ref': '#/$defs/Organization__Output'},
        (Organization, 'validation'): {'$ref': '#/$defs/Organization__Input'},
        (Organization__Input, 'validation'): {'$ref': '#/$defs/Organization___Input'},
    }
    assert refs_by_key == expected_refs

    expected_definitions = {
        'Organization___Input': {'properties': {}, 'title': 'Organization__Input', 'type': 'object'},
        'Organization__Input': organization_schema('integer'),
        'Organization__Output': organization_schema('string'),
    }
    assert definitions == expected_definitions


def test_schema_for_generic_field():
T = TypeVar('T')

Expand Down Expand Up @@ -4343,26 +4505,26 @@ class Outer(BaseModel):
_, vs_schema = models_json_schema([(Outer, 'validation'), (Outer, 'serialization')])
assert vs_schema == {
'$defs': {
'InnerInput': {
'Inner__Input': {
'properties': {'x': {'format': 'json-string', 'title': 'X', 'type': 'string'}},
'required': ['x'],
'title': 'Inner',
'type': 'object',
},
'InnerOutput': {
'Inner__Output': {
'properties': {'x': {'title': 'X', 'type': 'integer'}},
'required': ['x'],
'title': 'Inner',
'type': 'object',
},
'OuterInput': {
'properties': {'inner': {'$ref': '#/$defs/InnerInput'}},
'Outer__Input': {
'properties': {'inner': {'$ref': '#/$defs/Inner__Input'}},
'required': ['inner'],
'title': 'Outer',
'type': 'object',
},
'OuterOutput': {
'properties': {'inner': {'$ref': '#/$defs/InnerOutput'}},
'Outer__Output': {
'properties': {'inner': {'$ref': '#/$defs/Inner__Output'}},
'required': ['inner'],
'title': 'Outer',
'type': 'object',
Expand Down