Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable json schema creation with type ByteSize #8537

Merged
83 changes: 48 additions & 35 deletions pydantic/types.py
Expand Up @@ -1710,37 +1710,6 @@ def validate_brand(card_number: str) -> PaymentCardBrand:

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Multiplier, in bytes, for every unit suffix recognised in a byte-size string.
# Whole-byte units are listed explicitly: decimal (powers of 10) first, then
# binary (powers of 2).
BYTE_SIZES = {
    'b': 1,
    'kb': 10**3,
    'mb': 10**6,
    'gb': 10**9,
    'tb': 10**12,
    'pb': 10**15,
    'eb': 10**18,
    'kib': 2**10,
    'mib': 2**20,
    'gib': 2**30,
    'tib': 2**40,
    'pib': 2**50,
    'eib': 2**60,
}
# Every bit unit is one eighth of its byte counterpart; the name is obtained by
# swapping the trailing 'b' for 'bit' ('b' -> 'bit', 'kb' -> 'kbit', 'kib' -> 'kibit').
# True division is deliberate: these multipliers are floats.
BYTE_SIZES.update({f'{name[:-1]}bit': factor / 8 for name, factor in BYTE_SIZES.items()})
# Single-letter shorthands ('k', 'm', ...) alias the decimal byte units. Any name
# containing an 'i' (all binary units and all bit units) is skipped.
BYTE_SIZES.update({name[0]: factor for name, factor in BYTE_SIZES.items() if 'i' not in name})

# Group 1: a decimal number without sign; group 2: an optional trailing unit word.
byte_string_re = re.compile(
    r'^\s*(\d*\.?\d+)\s*(\w+)?',
    re.IGNORECASE,
)


class ByteSize(int):
"""Converts a string representing a number of bytes with units (such as `'1KB'` or `'11.5MiB'`) into an integer.
Expand Down Expand Up @@ -1777,9 +1746,53 @@ class MyModel(BaseModel):
```
"""

# Multiplier, in bytes, for every unit suffix recognised in a byte-size string.
# Whole-byte units are listed explicitly: decimal (powers of 10) first, then
# binary (powers of 2).
byte_sizes = {
    'b': 1,
    'kb': 10**3,
    'mb': 10**6,
    'gb': 10**9,
    'tb': 10**12,
    'pb': 10**15,
    'eb': 10**18,
    'kib': 2**10,
    'mib': 2**20,
    'gib': 2**30,
    'tib': 2**40,
    'pib': 2**50,
    'eib': 2**60,
}
# Every bit unit is one eighth of its byte counterpart; the name is obtained by
# swapping the trailing 'b' for 'bit' ('b' -> 'bit', 'kb' -> 'kbit', 'kib' -> 'kibit').
# True division is deliberate: these multipliers are floats.
byte_sizes.update({f'{name[:-1]}bit': factor / 8 for name, factor in byte_sizes.items()})
# Single-letter shorthands ('k', 'm', ...) alias the decimal byte units. Any name
# containing an 'i' (all binary units and all bit units) is skipped.
byte_sizes.update({name[0]: factor for name, factor in byte_sizes.items() if 'i' not in name})

# The raw pattern is kept as a string so it can be reused outside the compiled regex.
# Group 1: a decimal number without sign; group 2: an optional trailing unit word.
byte_string_pattern = r'^\s*(\d*\.?\d+)\s*(\w+)?'
byte_string_re = re.compile(
    byte_string_pattern,
    re.IGNORECASE,
)

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
    """Build the core schema used to validate and serialize this type.

    The input is first checked against a union of a string matching
    `cls.byte_string_pattern` or an integer `>= 0`; `cls._validate` then runs
    on the result. Serialization goes through `int`, with a non-negative
    integer return schema.

    Args:
        source: The annotated source type (not used here).
        handler: The core-schema handler (not used here).

    Returns:
        An after-validator core schema wrapping the string/integer union.
    """
    # The earlier plain-validator return was removed: it short-circuited the
    # method and left the union schema and serializer below unreachable.
    return core_schema.with_info_after_validator_function(
        function=cls._validate,
        schema=core_schema.union_schema(
            [
                core_schema.str_schema(pattern=cls.byte_string_pattern),
                core_schema.int_schema(ge=0),
            ]
        ),
        serialization=core_schema.plain_serializer_function_ser_schema(
            int, return_schema=core_schema.int_schema(ge=0)
        ),
    )

@classmethod
def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSize:
Expand All @@ -1788,7 +1801,7 @@ def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSiz
except ValueError:
pass

str_match = byte_string_re.match(str(__input_value))
str_match = cls.byte_string_re.match(str(__input_value))
if str_match is None:
raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string')

Expand All @@ -1797,7 +1810,7 @@ def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSiz
unit = 'b'

try:
unit_mult = BYTE_SIZES[unit.lower()]
unit_mult = cls.byte_sizes[unit.lower()]
except KeyError:
raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit})

Expand Down Expand Up @@ -1846,7 +1859,7 @@ def to(self, unit: str) -> float:
The byte size in the new unit.
"""
try:
unit_div = BYTE_SIZES[unit.lower()]
unit_div = self.byte_sizes[unit.lower()]
except KeyError:
raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit})

Expand Down
64 changes: 62 additions & 2 deletions tests/test_json_schema.py
Expand Up @@ -6,7 +6,14 @@
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum, IntEnum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -68,13 +75,22 @@
model_json_schema,
models_json_schema,
)
from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, MultiHostUrl, NameEmail
from pydantic.networks import (
AnyUrl,
EmailStr,
IPvAnyAddress,
IPvAnyInterface,
IPvAnyNetwork,
MultiHostUrl,
NameEmail,
)
from pydantic.type_adapter import TypeAdapter
from pydantic.types import (
UUID1,
UUID3,
UUID4,
UUID5,
ByteSize,
DirectoryPath,
FilePath,
Json,
Expand Down Expand Up @@ -1292,6 +1308,50 @@ class MyGenerator(GenerateJsonSchema):
assert model_schema['properties'] == properties


def test_byte_size_type():
    """Check the JSON schema generated for `ByteSize` fields in both modes.

    In validation mode each field is `anyOf` a pattern-constrained string or a
    non-negative integer; in serialization mode each field is a non-negative
    integer. Field `b` carries its declared default `'1MB'` in both schemas.
    """

    class Model(BaseModel):
        a: ByteSize
        b: ByteSize = Field('1MB', validate_default=True)

    model_json_schema_validation = Model.model_json_schema(mode='validation')
    model_json_schema_serialization = Model.model_json_schema(mode='serialization')

    assert model_json_schema_validation == {
        'properties': {
            'a': {
                'anyOf': [
                    {'pattern': '^\\s*(\\d*\\.?\\d+)\\s*(\\w+)?', 'type': 'string'},
                    {'minimum': 0, 'type': 'integer'},
                ],
                'title': 'A',
            },
            'b': {
                'anyOf': [
                    {'pattern': '^\\s*(\\d*\\.?\\d+)\\s*(\\w+)?', 'type': 'string'},
                    {'minimum': 0, 'type': 'integer'},
                ],
                'default': '1MB',
                'title': 'B',
            },
        },
        'required': ['a'],
        'title': 'Model',
        'type': 'object',
    }

    assert model_json_schema_serialization == {
        'properties': {
            'a': {'minimum': 0, 'title': 'A', 'type': 'integer'},
            'b': {'default': '1MB', 'minimum': 0, 'title': 'B', 'type': 'integer'},
        },
        'required': ['a'],
        'title': 'Model',
        'type': 'object',
    }


@pytest.mark.parametrize(
'type_,default_value,properties',
(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_types.py
Expand Up @@ -4434,6 +4434,7 @@ class FrozenSetModel(BaseModel):
@pytest.mark.parametrize(
'input_value,output,human_bin,human_dec',
(
(1, 1, '1B', '1B'),
('1', 1, '1B', '1B'),
('1.0', 1, '1B', '1B'),
('1b', 1, '1B', '1B'),
Expand Down Expand Up @@ -4476,7 +4477,7 @@ def test_bytesize_raises():
class Model(BaseModel):
size: ByteSize

with pytest.raises(ValidationError, match='parse value'):
with pytest.raises(ValidationError, match='should match'):
Model(size='d1MB')

with pytest.raises(ValidationError, match='byte unit'):
Expand Down