Skip to content

Commit

Permalink
add field_serializer to computed_field (#6965)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
  • Loading branch information
andresliszt and adriangb committed Aug 1, 2023
1 parent 586ed21 commit de6fc67
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 12 deletions.
11 changes: 10 additions & 1 deletion pydantic/_internal/_decorators.py
Expand Up @@ -527,7 +527,9 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
)


def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
def inspect_field_serializer(
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
) -> tuple[bool, bool]:
"""Look at a field serializer function and determine if it is a field serializer,
and whether it takes an info argument.
Expand All @@ -536,6 +538,8 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
Args:
serializer: The serializer function to inspect.
mode: The serializer mode, either 'plain' or 'wrap'.
computed_field: When serializer is applied on computed_field. It doesn't require
info signature.
Returns:
Tuple of (is_field_serializer, info_arg).
Expand All @@ -557,6 +561,11 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
if info_arg and computed_field:
raise PydanticUserError(
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
)

else:
return is_field_serializer, info_arg

Expand Down
41 changes: 32 additions & 9 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -464,13 +464,14 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:

fields = cls.model_fields
decorators = cls.__pydantic_decorators__
computed_fields = decorators.computed_fields
check_decorator_fields_exist(
chain(
decorators.field_validators.values(),
decorators.field_serializers.values(),
decorators.validators.values(),
),
fields.keys(),
{*fields.keys(), *computed_fields.keys()},
)
config_wrapper = ConfigWrapper(cls.model_config, check=False)
core_config = config_wrapper.core_config(cls)
Expand Down Expand Up @@ -515,11 +516,13 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
else:
fields_schema: core_schema.CoreSchema = core_schema.model_fields_schema(
{k: self._generate_md_field_schema(k, v, decorators) for k, v in fields.items()},
computed_fields=[self._computed_field_schema(d) for d in decorators.computed_fields.values()],
computed_fields=[
self._computed_field_schema(d, decorators.field_serializers)
for d in computed_fields.values()
],
extra_validator=extra_validator,
model_name=cls.__name__,
)

inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None)
inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs())
inner_schema = apply_model_validators(inner_schema, model_validators, 'inner')
Expand Down Expand Up @@ -1083,7 +1086,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

td_schema = core_schema.typed_dict_schema(
fields,
computed_fields=[self._computed_field_schema(d) for d in decorators.computed_fields.values()],
computed_fields=[
self._computed_field_schema(d, decorators.field_serializers)
for d in decorators.computed_fields.values()
],
ref=typed_dict_ref,
metadata=metadata,
config=core_config,
Expand Down Expand Up @@ -1316,7 +1322,10 @@ def _dataclass_schema(
args_schema = core_schema.dataclass_args_schema(
dataclass.__name__,
args,
computed_fields=[self._computed_field_schema(d) for d in decorators.computed_fields.values()],
computed_fields=[
self._computed_field_schema(d, decorators.field_serializers)
for d in decorators.computed_fields.values()
],
collect_init_only=has_post_init,
)

Expand Down Expand Up @@ -1402,7 +1411,11 @@ def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.
else:
return core_schema.any_schema()

def _computed_field_schema(self, d: Decorator[ComputedFieldInfo]) -> core_schema.ComputedField:
def _computed_field_schema(
self,
d: Decorator[ComputedFieldInfo],
field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]],
) -> core_schema.ComputedField:
try:
return_type = _decorators.get_function_return_type(d.func, d.info.return_type, self._types_namespace)
except NameError as e:
Expand All @@ -1415,7 +1428,12 @@ def _computed_field_schema(self, d: Decorator[ComputedFieldInfo]) -> core_schema
)

return_type_schema = self.generate_schema(return_type)

# Apply serializers to computed field if there exist
return_type_schema = self._apply_field_serializers(
return_type_schema,
filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name),
computed_field=True,
)
# Handle alias_generator using similar logic to that from
# pydantic._internal._generate_schema.GenerateSchema._common_field_schema,
# with field_info -> d.info and name -> d.cls_var_name
Expand Down Expand Up @@ -1677,7 +1695,10 @@ def new_handler(source: Any) -> core_schema.CoreSchema:
return CallbackGetCoreSchemaHandler(new_handler, self)

def _apply_field_serializers(
self, schema: core_schema.CoreSchema, serializers: list[Decorator[FieldSerializerDecoratorInfo]]
self,
schema: core_schema.CoreSchema,
serializers: list[Decorator[FieldSerializerDecoratorInfo]],
computed_field: bool = False,
) -> core_schema.CoreSchema:
"""Apply field serializers to a schema."""
if serializers:
Expand All @@ -1693,7 +1714,9 @@ def _apply_field_serializers(

# use the last serializer to make it easy to override a serializer set on a parent model
serializer = serializers[-1]
is_field_serializer, info_arg = inspect_field_serializer(serializer.func, serializer.info.mode)
is_field_serializer, info_arg = inspect_field_serializer(
serializer.func, serializer.info.mode, computed_field=computed_field
)

try:
return_type = _decorators.get_function_return_type(
Expand Down
14 changes: 12 additions & 2 deletions tests/test_computed_fields.py
Expand Up @@ -15,6 +15,7 @@
TypeAdapter,
computed_field,
dataclasses,
field_serializer,
field_validator,
)
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -119,14 +120,23 @@ class Square(BaseModel):
def area(self) -> float:
return self.side**2

@computed_field
@property
def area_string(self) -> str:
return f'{self.area} square units'

@field_serializer('area_string')
def serialize_area_string(self, area_string):
return area_string.upper()

@area.setter
def area(self, new_area: int):
self.side = new_area**0.5

s = Square(side=10)
assert s.model_dump() == {'side': 10.0, 'area': 100.0}
assert s.model_dump() == {'side': 10.0, 'area': 100.0, 'area_string': '100.0 SQUARE UNITS'}
s.area = 64
assert s.model_dump() == {'side': 8.0, 'area': 64.0}
assert s.model_dump() == {'side': 8.0, 'area': 64.0, 'area_string': '64.0 SQUARE UNITS'}


def test_computed_fields_del():
Expand Down
93 changes: 93 additions & 0 deletions tests/test_serialize.py
Expand Up @@ -15,6 +15,7 @@
BaseModel,
Field,
FieldSerializationInfo,
PydanticUserError,
SecretStr,
SerializationInfo,
SerializeAsAny,
Expand Down Expand Up @@ -985,6 +986,98 @@ def two_x(self) -> 'IntAlias': # noqa F821
assert Model(x=1).model_dump() == {'two_x': 2, 'x': 1}


def test_computed_field_custom_serializer():
class Model(BaseModel):
x: int

@computed_field
@property
def two_x(self) -> int:
return self.x * 2

@field_serializer('two_x', when_used='json')
def ser_two_x(self, v):
return f'The double of x is {v}'

m = Model(x=1)

assert m.model_dump() == {'two_x': 2, 'x': 1}
assert json.loads(m.model_dump_json()) == {'two_x': 'The double of x is 2', 'x': 1}


def test_annotated_computed_field_custom_serializer():
class Model(BaseModel):
x: int

@computed_field
@property
def two_x(self) -> Annotated[int, PlainSerializer(lambda v: f'The double of x is {v}', return_type=str)]:
return self.x * 2

@computed_field
@property
def triple_x(self) -> Annotated[int, PlainSerializer(lambda v: f'The triple of x is {v}', return_type=str)]:
return self.two_x * 3

@computed_field
@property
def quadruple_x_plus_one(self) -> Annotated[int, PlainSerializer(lambda v: v + 1, return_type=int)]:
return self.two_x * 2

m = Model(x=1)
assert m.x == 1
assert m.two_x == 2
assert m.triple_x == 6
assert m.quadruple_x_plus_one == 4

# insert_assert(m.model_dump())
assert m.model_dump() == {
'x': 1,
'two_x': 'The double of x is 2',
'triple_x': 'The triple of x is 6',
'quadruple_x_plus_one': 5,
}

# insert_assert(json.loads(m.model_dump_json()))
assert json.loads(m.model_dump_json()) == {
'x': 1,
'two_x': 'The double of x is 2',
'triple_x': 'The triple of x is 6',
'quadruple_x_plus_one': 5,
}

# insert_assert(Model.model_json_schema(mode='serialization'))
assert Model.model_json_schema(mode='serialization') == {
'properties': {
'x': {'title': 'X', 'type': 'integer'},
'two_x': {'readOnly': True, 'title': 'Two X', 'type': 'string'},
'triple_x': {'readOnly': True, 'title': 'Triple X', 'type': 'string'},
'quadruple_x_plus_one': {'readOnly': True, 'title': 'Quadruple X Plus One', 'type': 'integer'},
},
'required': ['x', 'two_x', 'triple_x', 'quadruple_x_plus_one'],
'title': 'Model',
'type': 'object',
}


def test_computed_field_custom_serializer_bad_signature():
error_msg = 'field_serializer on computed_field does not use info signature'

with pytest.raises(PydanticUserError, match=error_msg):

class Model(BaseModel):
x: int

@computed_field
@property
def two_x(self) -> int:
return self.x * 2

@field_serializer('two_x')
def ser_two_x_bad_signature(self, v, _info):
return f'The double of x is {v}'


@pytest.mark.skipif(sys.version_info < (3, 9), reason='@computed_field @classmethod @property only works in 3.9+')
def test_forward_ref_for_classmethod_computed_fields():
class Model(BaseModel):
Expand Down

0 comments on commit de6fc67

Please sign in to comment.