Skip to content

Commit

Permalink
tweaks to _core_utils (#7040)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Aug 15, 2023
1 parent bd6b723 commit 2009d29
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 67 deletions.
124 changes: 59 additions & 65 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -32,14 +32,6 @@
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}


def is_definition_ref_schema(s: core_schema.CoreSchema) -> TypeGuard[core_schema.DefinitionReferenceSchema]:
return s['type'] == 'definition-ref'


def is_definitions_schema(s: core_schema.CoreSchema) -> TypeGuard[core_schema.DefinitionsSchema]:
return s['type'] == 'definitions'


def is_core_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchema]:
Expand Down Expand Up @@ -192,14 +184,15 @@ def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchem

def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
schema = self._schema_type_to_method[schema['type']](schema, f)
ser_schema: core_schema.SerSchema | None = schema.get('serialization', None) # type: ignore
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
if ser_schema:
schema['serialization'] = self._handle_ser_schemas(ser_schema.copy(), f)
return schema

def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
if 'schema' in schema:
schema['schema'] = self.walk(schema['schema'], f) # type: ignore
sub_schema = schema.get('schema', None)
if sub_schema is not None:
schema['schema'] = self.walk(sub_schema, f) # type: ignore
return schema

def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
Expand Down Expand Up @@ -232,47 +225,55 @@ def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Wa
return new_schema

def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_tuple_variable_schema(
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
) -> core_schema.CoreSchema:
schema = cast(core_schema.TupleVariableSchema, schema)
if 'items_schema' in schema:
schema['items_schema'] = self.walk(schema['items_schema'], f)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema

def handle_tuple_positional_schema(
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
) -> core_schema.CoreSchema:
schema = cast(core_schema.TuplePositionalSchema, schema)
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
if 'extra_schema' in schema:
schema['extra_schema'] = self.walk(schema['extra_schema'], f)
extra_schema = schema.get('extra_schema')
if extra_schema is not None:
schema['extra_schema'] = self.walk(extra_schema, f)
return schema

def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
if 'keys_schema' in schema:
schema['keys_schema'] = self.walk(schema['keys_schema'], f)
if 'values_schema' in schema:
schema['values_schema'] = self.walk(schema['values_schema'], f)
keys_schema = schema.get('keys_schema')
if keys_schema is not None:
schema['keys_schema'] = self.walk(keys_schema, f)
values_schema = schema.get('values_schema')
if values_schema:
schema['values_schema'] = self.walk(values_schema, f)
return schema

def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
Expand Down Expand Up @@ -307,11 +308,12 @@ def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f
return schema

def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
if 'extra_validator' in schema:
schema['extra_validator'] = self.walk(schema['extra_validator'], f)
extra_validator = schema.get('extra_validator')
if extra_validator is not None:
schema['extra_validator'] = self.walk(extra_validator, f)
replaced_fields: dict[str, core_schema.ModelField] = {}
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', None) or ():
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
Expand All @@ -325,10 +327,11 @@ def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: W
return schema

def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
if 'extra_validator' in schema:
schema['extra_validator'] = self.walk(schema['extra_validator'], f)
extra_validator = schema.get('extra_validator')
if extra_validator is not None:
schema['extra_validator'] = self.walk(extra_validator, f)
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', None) or ():
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
Expand All @@ -345,7 +348,7 @@ def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk)
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_fields: list[core_schema.DataclassField] = []
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', None) or ():
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
Expand Down Expand Up @@ -395,12 +398,11 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor
Returns:
core_schema.CoreSchema: A processed CoreSchema.
"""
return f(schema.copy(), _dispatch)
return f(schema, _dispatch)


def _simplify_schema_references(schema: core_schema.CoreSchema, inline: bool) -> core_schema.CoreSchema: # noqa: C901
valid_defs: dict[str, core_schema.CoreSchema] = {}
invalid_defs: dict[str, core_schema.CoreSchema] = {}
all_defs: dict[str, core_schema.CoreSchema] = {}

def make_result(schema: core_schema.CoreSchema, defs: Iterable[core_schema.CoreSchema]) -> core_schema.CoreSchema:
definitions = list(defs)
Expand All @@ -413,42 +415,34 @@ def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Cor
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
def_schema = recurse(definition, collect_refs).copy()
if 'invalid' in def_schema.get('metadata', {}):
invalid_defs[ref] = def_schema
else:
valid_defs[ref] = def_schema
all_defs[ref] = recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
ref = get_ref(s)
if ref is not None:
if 'invalid' in s.get('metadata', {}):
invalid_defs[ref] = s
else:
valid_defs[ref] = s
return recurse(s, collect_refs)
else:
ref = get_ref(s)
if ref is not None:
all_defs[ref] = s
return recurse(s, collect_refs)

schema = walk_core_schema(schema, collect_refs)

all_defs = {**invalid_defs, **valid_defs}

def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if is_definitions_schema(s):
new: dict[str, Any] = dict(s)
if s['type'] == 'definitions':
# iterate ourselves, we don't want to flatten the actual defs!
definitions: list[CoreSchema] = new.pop('definitions')
schema = cast(CoreSchema, new.pop('schema'))
definitions: list[CoreSchema] = s.pop('definitions') # type: ignore
schema: CoreSchema = s.pop('schema') # type: ignore
# remaining keys are optional like 'serialization'
schema = cast(CoreSchema, {**schema, **new})
schema: CoreSchema = {**schema, **s} # type: ignore
s['schema'] = recurse(schema, flatten_refs)
for definition in definitions:
recurse(definition, flatten_refs) # don't re-assign here!
return schema
s = recurse(s, flatten_refs)
ref = get_ref(s)
if ref and ref in all_defs:
all_defs[ref] = s
return core_schema.definition_reference_schema(schema_ref=ref)
return s
else:
s = recurse(s, flatten_refs)
ref = get_ref(s)
if ref and ref in all_defs:
all_defs[ref] = s
return core_schema.definition_reference_schema(schema_ref=ref)
return s

schema = walk_core_schema(schema, flatten_refs)

Expand All @@ -458,12 +452,12 @@ def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Cor
if not inline:
return make_result(schema, all_defs.values())

ref_counts: dict[str, int] = defaultdict(int)
ref_counts: defaultdict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: dict[str, int] = defaultdict(int)
current_recursion_ref_count: defaultdict[str, int] = defaultdict(int)

def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if not is_definition_ref_schema(s):
if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
ref_counts[ref] += 1
Expand Down
11 changes: 9 additions & 2 deletions tests/benchmarks/test_fastapi_startup.py
Expand Up @@ -122,11 +122,18 @@ def bench():


if __name__ == '__main__':
# run with `python tests/benchmarks/test_fastapi_startup.py`
# run with `pdm run tests/benchmarks/test_fastapi_startup.py`
import cProfile
import sys
import time

INNER_DATA_MODEL_COUNT = 50
OUTER_DATA_MODEL_COUNT = 50
print(f'Python version: {sys.version}')
cProfile.run('test_fastapi_startup_perf(lambda f: f())', sort='tottime')
if sys.argv[-1] == 'cProfile':
cProfile.run('test_fastapi_startup_perf(lambda f: f())', sort='tottime')
else:
start = time.perf_counter()
test_fastapi_startup_perf(lambda f: f())
end = time.perf_counter()
print(f'Time taken: {end - start:.2f}s')

0 comments on commit 2009d29

Please sign in to comment.