Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add field_serializer to computed_field #6965

Merged
merged 5 commits into from Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 10 additions & 1 deletion pydantic/_internal/_decorators.py
Expand Up @@ -527,7 +527,9 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
)


def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
def inspect_field_serializer(
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
) -> tuple[bool, bool]:
"""Look at a field serializer function and determine if it is a field serializer,
and whether it takes an info argument.

Expand All @@ -536,6 +538,8 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
Args:
serializer: The serializer function to inspect.
mode: The serializer mode, either 'plain' or 'wrap'.
computed_field: When serializer is applied on computed_field. It doesn't require
info signature.

Returns:
Tuple of (is_field_serializer, info_arg).
Expand All @@ -557,6 +561,11 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
if info_arg and computed_field:
raise PydanticUserError(
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
)

else:
return is_field_serializer, info_arg

Expand Down
41 changes: 32 additions & 9 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -464,13 +464,14 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:

fields = cls.model_fields
decorators = cls.__pydantic_decorators__
computed_fields = decorators.computed_fields
check_decorator_fields_exist(
chain(
decorators.field_validators.values(),
decorators.field_serializers.values(),
decorators.validators.values(),
),
fields.keys(),
{*fields.keys(), *computed_fields.keys()},
)
config_wrapper = ConfigWrapper(cls.model_config, check=False)
core_config = config_wrapper.core_config(cls)
Expand Down Expand Up @@ -515,11 +516,13 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
else:
fields_schema: core_schema.CoreSchema = core_schema.model_fields_schema(
{k: self._generate_md_field_schema(k, v, decorators) for k, v in fields.items()},
computed_fields=[self._computed_field_schema(d) for d in decorators.computed_fields.values()],
computed_fields=[
self._computed_field_schema(d, decorators.field_serializers)
for d in computed_fields.values()
],
extra_validator=extra_validator,
model_name=cls.__name__,
)

inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None)
inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs())
inner_schema = apply_model_validators(inner_schema, model_validators, 'inner')
Expand Down Expand Up @@ -1083,7 +1086,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

td_schema = core_schema.typed_dict_schema(
fields,
computed_fields=[self._computed_field_schema(d) for d in decorators.computed_fields.values()],
computed_fields=[
self._computed_field_schema(d, decorators.field_serializers)
for d in decorators.computed_fields.values()
],
ref=typed_dict_ref,
metadata=metadata,
config=core_config,
Expand Down Expand Up @@ -1316,7 +1322,10 @@ def _dataclass_schema(
args_schema = core_schema.dataclass_args_schema(
dataclass.__name__,
args,
computed_fields=[self._computed_field_schema(d) for d in decorators.computed_fields.values()],
computed_fields=[
self._computed_field_schema(d, decorators.field_serializers)
for d in decorators.computed_fields.values()
],
collect_init_only=has_post_init,
)

Expand Down Expand Up @@ -1402,7 +1411,11 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.
else:
return core_schema.any_schema()

def _computed_field_schema(self, d: Decorator[ComputedFieldInfo]) -> core_schema.ComputedField:
def _computed_field_schema(
self,
d: Decorator[ComputedFieldInfo],
field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]],
) -> core_schema.ComputedField:
try:
return_type = _decorators.get_function_return_type(d.func, d.info.return_type, self._types_namespace)
except NameError as e:
Expand All @@ -1415,7 +1428,12 @@ def _computed_field_schema(self, d: Decorator[ComputedFieldInfo]) -> core_schema
)

return_type_schema = self.generate_schema(return_type)

# Apply serializers to computed field if there exist
return_type_schema = self._apply_field_serializers(
return_type_schema,
filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name),
computed_field=True,
)
# Handle alias_generator using similar logic to that from
# pydantic._internal._generate_schema.GenerateSchema._common_field_schema,
# with field_info -> d.info and name -> d.cls_var_name
Expand Down Expand Up @@ -1677,7 +1695,10 @@ def new_handler(source: Any) -> core_schema.CoreSchema:
return CallbackGetCoreSchemaHandler(new_handler, self)

def _apply_field_serializers(
self, schema: core_schema.CoreSchema, serializers: list[Decorator[FieldSerializerDecoratorInfo]]
self,
schema: core_schema.CoreSchema,
serializers: list[Decorator[FieldSerializerDecoratorInfo]],
computed_field: bool = False,
) -> core_schema.CoreSchema:
"""Apply field serializers to a schema."""
if serializers:
Expand All @@ -1693,7 +1714,9 @@ def _apply_field_serializers(

# use the last serializer to make it easy to override a serializer set on a parent model
serializer = serializers[-1]
is_field_serializer, info_arg = inspect_field_serializer(serializer.func, serializer.info.mode)
is_field_serializer, info_arg = inspect_field_serializer(
serializer.func, serializer.info.mode, computed_field=computed_field
)

try:
return_type = _decorators.get_function_return_type(
Expand Down
14 changes: 12 additions & 2 deletions tests/test_computed_fields.py
Expand Up @@ -15,6 +15,7 @@
TypeAdapter,
computed_field,
dataclasses,
field_serializer,
field_validator,
)
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -119,14 +120,23 @@ class Square(BaseModel):
def area(self) -> float:
return self.side**2

@computed_field
@property
def area_string(self) -> str:
return f'{self.area} square units'

@field_serializer('area_string')
def serialize_area_string(self, area_string):
return area_string.upper()

@area.setter
def area(self, new_area: int):
self.side = new_area**0.5

s = Square(side=10)
assert s.model_dump() == {'side': 10.0, 'area': 100.0}
assert s.model_dump() == {'side': 10.0, 'area': 100.0, 'area_string': '100.0 SQUARE UNITS'}
s.area = 64
assert s.model_dump() == {'side': 8.0, 'area': 64.0}
assert s.model_dump() == {'side': 8.0, 'area': 64.0, 'area_string': '64.0 SQUARE UNITS'}


def test_computed_fields_del():
Expand Down
93 changes: 93 additions & 0 deletions tests/test_serialize.py
Expand Up @@ -15,6 +15,7 @@
BaseModel,
Field,
FieldSerializationInfo,
PydanticUserError,
SecretStr,
SerializationInfo,
SerializeAsAny,
Expand Down Expand Up @@ -985,6 +986,98 @@ def two_x(self) -> 'IntAlias': # noqa F821
assert Model(x=1).model_dump() == {'two_x': 2, 'x': 1}


def test_computed_field_custom_serializer():
class Model(BaseModel):
x: int

@computed_field
@property
def two_x(self) -> int:
return self.x * 2

@field_serializer('two_x', when_used='json')
def ser_two_x(self, v):
return f'The double of x is {v}'

m = Model(x=1)

assert m.model_dump() == {'two_x': 2, 'x': 1}
assert json.loads(m.model_dump_json()) == {'two_x': 'The double of x is 2', 'x': 1}


def test_annotated_computed_field_custom_serializer():
class Model(BaseModel):
x: int

@computed_field
@property
def two_x(self) -> Annotated[int, PlainSerializer(lambda v: f'The double of x is {v}', return_type=str)]:
return self.x * 2

@computed_field
@property
def triple_x(self) -> Annotated[int, PlainSerializer(lambda v: f'The triple of x is {v}', return_type=str)]:
return self.two_x * 3

@computed_field
@property
def quadruple_x_plus_one(self) -> Annotated[int, PlainSerializer(lambda v: v + 1, return_type=int)]:
return self.two_x * 2

m = Model(x=1)
assert m.x == 1
assert m.two_x == 2
assert m.triple_x == 6
assert m.quadruple_x_plus_one == 4

# insert_assert(m.model_dump())
assert m.model_dump() == {
'x': 1,
'two_x': 'The double of x is 2',
'triple_x': 'The triple of x is 6',
'quadruple_x_plus_one': 5,
}

# insert_assert(json.loads(m.model_dump_json()))
assert json.loads(m.model_dump_json()) == {
'x': 1,
'two_x': 'The double of x is 2',
'triple_x': 'The triple of x is 6',
'quadruple_x_plus_one': 5,
}

# insert_assert(Model.model_json_schema(mode='serialization'))
assert Model.model_json_schema(mode='serialization') == {
'properties': {
'x': {'title': 'X', 'type': 'integer'},
'two_x': {'readOnly': True, 'title': 'Two X', 'type': 'string'},
'triple_x': {'readOnly': True, 'title': 'Triple X', 'type': 'string'},
'quadruple_x_plus_one': {'readOnly': True, 'title': 'Quadruple X Plus One', 'type': 'integer'},
},
'required': ['x', 'two_x', 'triple_x', 'quadruple_x_plus_one'],
'title': 'Model',
'type': 'object',
}


def test_computed_field_custom_serializer_bad_signature():
error_msg = 'field_serializer on computed_field does not use info signature'

with pytest.raises(PydanticUserError, match=error_msg):

class Model(BaseModel):
x: int

@computed_field
@property
def two_x(self) -> int:
return self.x * 2

@field_serializer('two_x')
def ser_two_x_bad_signature(self, v, _info):
return f'The double of x is {v}'


@pytest.mark.skipif(sys.version_info < (3, 9), reason='@computed_field @classmethod @property only works in 3.9+')
def test_forward_ref_for_classmethod_computed_fields():
class Model(BaseModel):
Expand Down