Skip to content

Commit

Permalink
Fix usage of AliasGenerator with computed_field decorator (#8806)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
  • Loading branch information
sydney-runkle and alexmojaki committed Feb 13, 2024
1 parent 071c2c8 commit b290b31
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 14 deletions.
63 changes: 49 additions & 14 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -973,7 +973,7 @@ def _apply_alias_generator_to_field_info(
# Apply an alias_generator if
# 1. An alias is not specified
# 2. An alias is specified, but the priority is <= 1
if alias_generator and (
if (
field_info.alias_priority is None
or field_info.alias_priority <= 1
or field_info.alias is None
Expand Down Expand Up @@ -1009,6 +1009,49 @@ def _apply_alias_generator_to_field_info(
if field_info.validation_alias is None:
field_info.validation_alias = validation_alias or alias

@staticmethod
def _apply_alias_generator_to_computed_field_info(
alias_generator: Callable[[str], str] | AliasGenerator,
computed_field_info: ComputedFieldInfo,
computed_field_name: str,
):
"""Apply an alias_generator to alias on a ComputedFieldInfo instance if appropriate.
Args:
alias_generator: A callable that takes a string and returns a string, or an AliasGenerator instance.
computed_field_info: The ComputedFieldInfo instance to which the alias_generator is (maybe) applied.
computed_field_name: The name of the computed field from which to generate the alias.
"""
# Apply an alias_generator if
# 1. An alias is not specified
# 2. An alias is specified, but the priority is <= 1

if (
computed_field_info.alias_priority is None
or computed_field_info.alias_priority <= 1
or computed_field_info.alias is None
):
alias, validation_alias, serialization_alias = None, None, None

if isinstance(alias_generator, AliasGenerator):
alias, validation_alias, serialization_alias = alias_generator.generate_aliases(computed_field_name)
elif isinstance(alias_generator, Callable):
alias = alias_generator(computed_field_name)
if not isinstance(alias, str):
raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}')

# if priority is not set, we set to 1
# which supports the case where the alias_generator from a child class is used
# to generate an alias for a field in a parent class
if computed_field_info.alias_priority is None or computed_field_info.alias_priority <= 1:
computed_field_info.alias_priority = 1

# if the priority is 1, then we set the aliases to the generated alias
# note that we use the serialization_alias with priority over alias, as computed_field
# aliases are used for serialization only (not validation)
if computed_field_info.alias_priority == 1:
computed_field_info.alias = serialization_alias or alias

def _common_field_schema( # C901
self, name: str, field_info: FieldInfo, decorators: DecoratorInfos
) -> _CommonField:
Expand Down Expand Up @@ -1659,20 +1702,12 @@ def _computed_field_schema(
filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name),
computed_field=True,
)
# Handle alias_generator using similar logic to that from
# pydantic._internal._generate_schema.GenerateSchema._common_field_schema,
# with field_info -> d.info and name -> d.cls_var_name

alias_generator = self._config_wrapper.alias_generator
if alias_generator and (d.info.alias_priority is None or d.info.alias_priority <= 1):
alias = None
if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None:
alias = alias_generator.alias(d.cls_var_name)
elif isinstance(alias_generator, Callable):
alias = alias_generator(d.cls_var_name)
if not isinstance(alias, str):
raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}')
d.info.alias = alias
d.info.alias_priority = 1
if alias_generator is not None:
self._apply_alias_generator_to_computed_field_info(
alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name
)

def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(schema)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_aliases.py
Expand Up @@ -702,3 +702,26 @@ class Foo(BaseModel):
assert f.a == 'a'
assert f.model_dump(by_alias=True) == {'a_ser_alias': 'a'}
assert f.model_dump(by_alias=False) == {'a': 'a'}


def test_alias_generator_with_computed_field_for_serialization() -> None:
"""Tests that the alias generator is used for computed fields, with serialization_alias taking precedence over alias."""

class Rectangle(BaseModel):
model_config = ConfigDict(
alias_generator=AliasGenerator(
validation_alias=lambda field_name: f'{field_name}_val_alias',
alias=lambda field_name: f'{field_name}_alias',
serialization_alias=lambda field_name: f'{field_name}_ser_alias',
)
)

width: int
height: int

@computed_field
def area(self) -> int:
return self.width * self.height

r = Rectangle(width_val_alias=10, height_val_alias=20)
assert r.model_dump(by_alias=True) == {'width_ser_alias': 10, 'height_ser_alias': 20, 'area_ser_alias': 200}

0 comments on commit b290b31

Please sign in to comment.