Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix usage of AliasGenerator with computed_field decorator #8806

Merged
merged 8 commits into from Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 49 additions & 14 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -973,7 +973,7 @@ def _apply_alias_generator_to_field_info(
# Apply an alias_generator if
# 1. An alias is not specified
# 2. An alias is specified, but the priority is <= 1
if alias_generator and (
if (
field_info.alias_priority is None
or field_info.alias_priority <= 1
or field_info.alias is None
Expand Down Expand Up @@ -1009,6 +1009,49 @@ def _apply_alias_generator_to_field_info(
if field_info.validation_alias is None:
field_info.validation_alias = validation_alias or alias

@staticmethod
def _apply_alias_generator_to_computed_field_info(
alias_generator: Callable[[str], str] | AliasGenerator,
computed_field_info: ComputedFieldInfo,
computed_field_name: str,
):
"""Apply an alias_generator to alias on a ComputedFieldInfo instance if appropriate.

Args:
alias_generator: A callable that takes a string and returns a string, or an AliasGenerator instance.
computed_field_info: The ComputedFieldInfo instance to which the alias_generator is (maybe) applied.
computed_field_name: The name of the computed field from which to generate the alias.
"""
# Apply an alias_generator if
# 1. An alias is not specified
# 2. An alias is specified, but the priority is <= 1

if (
computed_field_info.alias_priority is None
or computed_field_info.alias_priority <= 1
or computed_field_info.alias is None
):
alias, validation_alias, serialization_alias = None, None, None

if isinstance(alias_generator, AliasGenerator):
alias, validation_alias, serialization_alias = alias_generator.generate_aliases(computed_field_name)
elif isinstance(alias_generator, Callable):
alias = alias_generator(computed_field_name)
if not isinstance(alias, str):
raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}')

# if priority is not set, we set to 1
# which supports the case where the alias_generator from a child class is used
# to generate an alias for a field in a parent class
if computed_field_info.alias_priority is None or computed_field_info.alias_priority <= 1:
computed_field_info.alias_priority = 1

# if the priority is 1, then we set the aliases to the generated alias
# note that we use the serialization_alias with priority over alias, as computed_field
# aliases are used for serialization only (not validation)
if computed_field_info.alias_priority == 1:
computed_field_info.alias = serialization_alias or alias
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would do

Suggested change
computed_field_info.alias = serialization_alias or alias
computed_field_info.alias = serialization_alias if serialization_alias is not None else alias

to handle the case of serialization_alias = '' and alias != ''. Probably worth a comment about why, though not a test.

I could be convinced not to do that but I do think there are tests for the alias = '' case elsewhere in the codebase.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do this, but in a different PR, to both the field_info logic and the computed_field_info logic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, doesn't seem like an urgent change 👍


def _common_field_schema( # C901
self, name: str, field_info: FieldInfo, decorators: DecoratorInfos
) -> _CommonField:
Expand Down Expand Up @@ -1659,20 +1702,12 @@ def _computed_field_schema(
filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name),
computed_field=True,
)
# Handle alias_generator using similar logic to that from
# pydantic._internal._generate_schema.GenerateSchema._common_field_schema,
# with field_info -> d.info and name -> d.cls_var_name

alias_generator = self._config_wrapper.alias_generator
if alias_generator and (d.info.alias_priority is None or d.info.alias_priority <= 1):
alias = None
if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None:
alias = alias_generator.alias(d.cls_var_name)
elif isinstance(alias_generator, Callable):
alias = alias_generator(d.cls_var_name)
if not isinstance(alias, str):
raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}')
d.info.alias = alias
d.info.alias_priority = 1
if alias_generator is not None:
self._apply_alias_generator_to_computed_field_info(
alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name
)

def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(schema)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_aliases.py
Expand Up @@ -702,3 +702,26 @@ class Foo(BaseModel):
assert f.a == 'a'
assert f.model_dump(by_alias=True) == {'a_ser_alias': 'a'}
assert f.model_dump(by_alias=False) == {'a': 'a'}


def test_alias_generator_with_computed_field_for_serialization() -> None:
"""Tests that the alias generator is used for computed fields, with serialization_alias taking precedence over alias."""

class Rectangle(BaseModel):
model_config = ConfigDict(
alias_generator=AliasGenerator(
validation_alias=lambda field_name: f'{field_name}_val_alias',
alias=lambda field_name: f'{field_name}_alias',
serialization_alias=lambda field_name: f'{field_name}_ser_alias',
)
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
)

width: int
height: int

@computed_field
def area(self) -> int:
return self.width * self.height

r = Rectangle(width_val_alias=10, height_val_alias=20)
assert r.model_dump(by_alias=True) == {'width_ser_alias': 10, 'height_ser_alias': 20, 'area_ser_alias': 200}