Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable json schema creation with type ByteSize #8537

Merged
14 changes: 12 additions & 2 deletions pydantic/types.py
Expand Up @@ -1719,7 +1719,8 @@ def validate_brand(card_number: str) -> PaymentCardBrand:
'eib': 2**60,
}
BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k})
byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE)
byte_string_pattern = r'^\s*(\d*\.?\d+)\s*(\w+)?'
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
byte_string_re = re.compile(byte_string_pattern, re.IGNORECASE)


class ByteSize(int):
Expand Down Expand Up @@ -1759,7 +1760,16 @@ class MyModel(BaseModel):

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.with_info_plain_validator_function(cls._validate)
return core_schema.with_info_after_validator_function(
function=cls._validate,
schema=core_schema.union_schema(
[
core_schema.str_schema(pattern=byte_string_pattern),
core_schema.int_schema(ge=0),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(function=cls.to),
)

@classmethod
def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSize:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_types.py
Expand Up @@ -4435,6 +4435,7 @@ class FrozenSetModel(BaseModel):
@pytest.mark.parametrize(
'input_value,output,human_bin,human_dec',
(
(1, 1, '1B', '1B'),
('1', 1, '1B', '1B'),
('1.0', 1, '1B', '1B'),
('1b', 1, '1B', '1B'),
Expand Down Expand Up @@ -4473,7 +4474,7 @@ def test_bytesize_raises():
class Model(BaseModel):
size: ByteSize

with pytest.raises(ValidationError, match='parse value'):
with pytest.raises(ValidationError, match='should match'):
Model(size='d1MB')

with pytest.raises(ValidationError, match='byte unit'):
Expand Down