Skip to content

Commit

Permalink
feat: enable json schema creation with type ByteSize (#8537)
Browse files Browse the repository at this point in the history
  • Loading branch information
geospackle committed Jan 15, 2024
1 parent 64d5223 commit c7d965c
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 38 deletions.
83 changes: 48 additions & 35 deletions pydantic/types.py
Expand Up @@ -1710,37 +1710,6 @@ def validate_brand(card_number: str) -> PaymentCardBrand:

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

BYTE_SIZES = {
'b': 1,
'kb': 10**3,
'mb': 10**6,
'gb': 10**9,
'tb': 10**12,
'pb': 10**15,
'eb': 10**18,
'kib': 2**10,
'mib': 2**20,
'gib': 2**30,
'tib': 2**40,
'pib': 2**50,
'eib': 2**60,
'bit': 1 / 8,
'kbit': 10**3 / 8,
'mbit': 10**6 / 8,
'gbit': 10**9 / 8,
'tbit': 10**12 / 8,
'pbit': 10**15 / 8,
'ebit': 10**18 / 8,
'kibit': 2**10 / 8,
'mibit': 2**20 / 8,
'gibit': 2**30 / 8,
'tibit': 2**40 / 8,
'pibit': 2**50 / 8,
'eibit': 2**60 / 8,
}
BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k})
byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE)


class ByteSize(int):
"""Converts a string representing a number of bytes with units (such as `'1KB'` or `'11.5MiB'`) into an integer.
Expand Down Expand Up @@ -1777,9 +1746,53 @@ class MyModel(BaseModel):
```
"""

byte_sizes = {
'b': 1,
'kb': 10**3,
'mb': 10**6,
'gb': 10**9,
'tb': 10**12,
'pb': 10**15,
'eb': 10**18,
'kib': 2**10,
'mib': 2**20,
'gib': 2**30,
'tib': 2**40,
'pib': 2**50,
'eib': 2**60,
'bit': 1 / 8,
'kbit': 10**3 / 8,
'mbit': 10**6 / 8,
'gbit': 10**9 / 8,
'tbit': 10**12 / 8,
'pbit': 10**15 / 8,
'ebit': 10**18 / 8,
'kibit': 2**10 / 8,
'mibit': 2**20 / 8,
'gibit': 2**30 / 8,
'tibit': 2**40 / 8,
'pibit': 2**50 / 8,
'eibit': 2**60 / 8,
}
byte_sizes.update({k.lower()[0]: v for k, v in byte_sizes.items() if 'i' not in k})

byte_string_pattern = r'^\s*(\d*\.?\d+)\s*(\w+)?'
byte_string_re = re.compile(byte_string_pattern, re.IGNORECASE)

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.with_info_plain_validator_function(cls._validate)
return core_schema.with_info_after_validator_function(
function=cls._validate,
schema=core_schema.union_schema(
[
core_schema.str_schema(pattern=cls.byte_string_pattern),
core_schema.int_schema(ge=0),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
int, return_schema=core_schema.int_schema(ge=0)
),
)

@classmethod
def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSize:
Expand All @@ -1788,7 +1801,7 @@ def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSiz
except ValueError:
pass

str_match = byte_string_re.match(str(__input_value))
str_match = cls.byte_string_re.match(str(__input_value))
if str_match is None:
raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string')

Expand All @@ -1797,7 +1810,7 @@ def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSiz
unit = 'b'

try:
unit_mult = BYTE_SIZES[unit.lower()]
unit_mult = cls.byte_sizes[unit.lower()]
except KeyError:
raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit})

Expand Down Expand Up @@ -1846,7 +1859,7 @@ def to(self, unit: str) -> float:
The byte size in the new unit.
"""
try:
unit_div = BYTE_SIZES[unit.lower()]
unit_div = self.byte_sizes[unit.lower()]
except KeyError:
raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit})

Expand Down
64 changes: 62 additions & 2 deletions tests/test_json_schema.py
Expand Up @@ -6,7 +6,14 @@
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum, IntEnum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -68,13 +75,22 @@
model_json_schema,
models_json_schema,
)
from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, MultiHostUrl, NameEmail
from pydantic.networks import (
AnyUrl,
EmailStr,
IPvAnyAddress,
IPvAnyInterface,
IPvAnyNetwork,
MultiHostUrl,
NameEmail,
)
from pydantic.type_adapter import TypeAdapter
from pydantic.types import (
UUID1,
UUID3,
UUID4,
UUID5,
ByteSize,
DirectoryPath,
FilePath,
Json,
Expand Down Expand Up @@ -1292,6 +1308,50 @@ class MyGenerator(GenerateJsonSchema):
assert model_schema['properties'] == properties


def test_byte_size_type():
class Model(BaseModel):
a: ByteSize
b: ByteSize = Field('1MB', validate_default=True)

model_json_schema_validation = Model.model_json_schema(mode='validation')
model_json_schema_serialization = Model.model_json_schema(mode='serialization')

print(model_json_schema_serialization)

assert model_json_schema_validation == {
'properties': {
'a': {
'anyOf': [
{'pattern': '^\\s*(\\d*\\.?\\d+)\\s*(\\w+)?', 'type': 'string'},
{'minimum': 0, 'type': 'integer'},
],
'title': 'A',
},
'b': {
'anyOf': [
{'pattern': '^\\s*(\\d*\\.?\\d+)\\s*(\\w+)?', 'type': 'string'},
{'minimum': 0, 'type': 'integer'},
],
'default': '1MB',
'title': 'B',
},
},
'required': ['a'],
'title': 'Model',
'type': 'object',
}

assert model_json_schema_serialization == {
'properties': {
'a': {'minimum': 0, 'title': 'A', 'type': 'integer'},
'b': {'default': '1MB', 'minimum': 0, 'title': 'B', 'type': 'integer'},
},
'required': ['a'],
'title': 'Model',
'type': 'object',
}


@pytest.mark.parametrize(
'type_,default_value,properties',
(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_types.py
Expand Up @@ -4434,6 +4434,7 @@ class FrozenSetModel(BaseModel):
@pytest.mark.parametrize(
'input_value,output,human_bin,human_dec',
(
(1, 1, '1B', '1B'),
('1', 1, '1B', '1B'),
('1.0', 1, '1B', '1B'),
('1b', 1, '1B', '1B'),
Expand Down Expand Up @@ -4476,7 +4477,7 @@ def test_bytesize_raises():
class Model(BaseModel):
size: ByteSize

with pytest.raises(ValidationError, match='parse value'):
with pytest.raises(ValidationError, match='should match'):
Model(size='d1MB')

with pytest.raises(ValidationError, match='byte unit'):
Expand Down

0 comments on commit c7d965c

Please sign in to comment.