Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tweaks to _core_utils #7040

Merged
merged 3 commits into from Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
124 changes: 59 additions & 65 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -32,14 +32,6 @@
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}


def is_definition_ref_schema(s: core_schema.CoreSchema) -> TypeGuard[core_schema.DefinitionReferenceSchema]:
return s['type'] == 'definition-ref'


def is_definitions_schema(s: core_schema.CoreSchema) -> TypeGuard[core_schema.DefinitionsSchema]:
return s['type'] == 'definitions'


def is_core_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchema]:
Expand Down Expand Up @@ -192,14 +184,15 @@ def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchem

def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
schema = self._schema_type_to_method[schema['type']](schema, f)
ser_schema: core_schema.SerSchema | None = schema.get('serialization', None) # type: ignore
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
if ser_schema:
schema['serialization'] = self._handle_ser_schemas(ser_schema.copy(), f)
return schema

def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
if 'schema' in schema:
schema['schema'] = self.walk(schema['schema'], f) # type: ignore
sub_schema = schema.get('schema', None)
if sub_schema is not None:
schema['schema'] = self.walk(sub_schema, f) # type: ignore
return schema

def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
Expand Down Expand Up @@ -232,47 +225,55 @@ def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Wa
return new_schema

def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
return schema

def handle_tuple_variable_schema(
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
) -> core_schema.CoreSchema:
schema = cast(core_schema.TupleVariableSchema, schema)
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_tuple_positional_schema(
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
) -> core_schema.CoreSchema:
schema = cast(core_schema.TuplePositionalSchema, schema)
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
if 'extra_schema' in schema:
schema['extra_schema'] = self.walk(schema['extra_schema'], f)
extra_schema = schema.get('extra_schema')
if extra_schema is not None:
schema['extra_schema'] = self.walk(extra_schema, f)
return schema

def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
if 'keys_schema' in schema:
schema['keys_schema'] = self.walk(schema['keys_schema'], f)
if 'values_schema' in schema:
schema['values_schema'] = self.walk(schema['values_schema'], f)
keys_schema = schema.get('keys_schema')
if keys_schema is not None:
schema['keys_schema'] = self.walk(keys_schema, f)
values_schema = schema.get('values_schema')
if values_schema:
schema['values_schema'] = self.walk(values_schema, f)
return schema

def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
Expand Down Expand Up @@ -307,11 +308,12 @@ def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f
return schema

def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
if 'extra_validator' in schema:
schema['extra_validator'] = self.walk(schema['extra_validator'], f)
extra_validator = schema.get('extra_validator')
if extra_validator is not None:
schema['extra_validator'] = self.walk(extra_validator, f)
replaced_fields: dict[str, core_schema.ModelField] = {}
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', None) or ():
for computed_field in schema.get('computed_fields') or ():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for computed_field in schema.get('computed_fields') or ():
for computed_field in schema.get('computed_fields', ()):

replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
Expand All @@ -325,10 +327,11 @@ def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: W
return schema

def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
if 'extra_validator' in schema:
schema['extra_validator'] = self.walk(schema['extra_validator'], f)
extra_validator = schema.get('extra_validator')
if extra_validator is not None:
schema['extra_validator'] = self.walk(extra_validator, f)
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', None) or ():
for computed_field in schema.get('computed_fields') or ():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for computed_field in schema.get('computed_fields') or ():
for computed_field in schema.get('computed_fields', ()):

replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
Expand All @@ -345,7 +348,7 @@ def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk)
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_fields: list[core_schema.DataclassField] = []
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', None) or ():
for computed_field in schema.get('computed_fields') or ():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for computed_field in schema.get('computed_fields') or ():
for computed_field in schema.get('computed_fields', ()):

replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
Expand Down Expand Up @@ -395,12 +398,11 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor
Returns:
core_schema.CoreSchema: A processed CoreSchema.
"""
return f(schema.copy(), _dispatch)
return f(schema, _dispatch)


def _simplify_schema_references(schema: core_schema.CoreSchema, inline: bool) -> core_schema.CoreSchema: # noqa: C901
valid_defs: dict[str, core_schema.CoreSchema] = {}
invalid_defs: dict[str, core_schema.CoreSchema] = {}
all_defs: dict[str, core_schema.CoreSchema] = {}

def make_result(schema: core_schema.CoreSchema, defs: Iterable[core_schema.CoreSchema]) -> core_schema.CoreSchema:
definitions = list(defs)
Expand All @@ -413,42 +415,34 @@ def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Cor
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
def_schema = recurse(definition, collect_refs).copy()
if 'invalid' in def_schema.get('metadata', {}):
invalid_defs[ref] = def_schema
else:
valid_defs[ref] = def_schema
all_defs[ref] = recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
ref = get_ref(s)
if ref is not None:
if 'invalid' in s.get('metadata', {}):
invalid_defs[ref] = s
else:
valid_defs[ref] = s
return recurse(s, collect_refs)
else:
ref = get_ref(s)
if ref is not None:
all_defs[ref] = s
return recurse(s, collect_refs)

schema = walk_core_schema(schema, collect_refs)

all_defs = {**invalid_defs, **valid_defs}

def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if is_definitions_schema(s):
new: dict[str, Any] = dict(s)
if s['type'] == 'definitions':
# iterate ourselves, we don't want to flatten the actual defs!
definitions: list[CoreSchema] = new.pop('definitions')
schema = cast(CoreSchema, new.pop('schema'))
definitions: list[CoreSchema] = s.pop('definitions') # type: ignore
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
schema: CoreSchema = s.pop('schema') # type: ignore
# remaining keys are optional like 'serialization'
schema = cast(CoreSchema, {**schema, **new})
schema: CoreSchema = {**schema, **s} # type: ignore
s['schema'] = recurse(schema, flatten_refs)
for definition in definitions:
recurse(definition, flatten_refs) # don't re-assign here!
return schema
s = recurse(s, flatten_refs)
ref = get_ref(s)
if ref and ref in all_defs:
all_defs[ref] = s
return core_schema.definition_reference_schema(schema_ref=ref)
return s
else:
s = recurse(s, flatten_refs)
ref = get_ref(s)
if ref and ref in all_defs:
all_defs[ref] = s
return core_schema.definition_reference_schema(schema_ref=ref)
return s

schema = walk_core_schema(schema, flatten_refs)

Expand All @@ -458,12 +452,12 @@ def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Cor
if not inline:
return make_result(schema, all_defs.values())

ref_counts: dict[str, int] = defaultdict(int)
ref_counts: defaultdict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: dict[str, int] = defaultdict(int)
current_recursion_ref_count: defaultdict[str, int] = defaultdict(int)

def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if not is_definition_ref_schema(s):
if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
ref_counts[ref] += 1
Expand Down
11 changes: 9 additions & 2 deletions tests/benchmarks/test_fastapi_startup.py
Expand Up @@ -122,11 +122,18 @@ def bench():


if __name__ == '__main__':
# run with `python tests/benchmarks/test_fastapi_startup.py`
# run with `pdm run tests/benchmarks/test_fastapi_startup.py`
import cProfile
import sys
import time

INNER_DATA_MODEL_COUNT = 50
OUTER_DATA_MODEL_COUNT = 50
print(f'Python version: {sys.version}')
cProfile.run('test_fastapi_startup_perf(lambda f: f())', sort='tottime')
if sys.argv[-1] == 'cProfile':
cProfile.run('test_fastapi_startup_perf(lambda f: f())', sort='tottime')
else:
start = time.perf_counter()
test_fastapi_startup_perf(lambda f: f())
end = time.perf_counter()
print(f'Time taken: {end - start:.2f}s')