Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable json schema creation with type ByteSize #8537

Merged
14 changes: 12 additions & 2 deletions pydantic/types.py
Expand Up @@ -1719,7 +1719,8 @@ def validate_brand(card_number: str) -> PaymentCardBrand:
'eib': 2**60,
}
BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k})
byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE)
byte_string_pattern = r'^\s*(\d*\.?\d+)\s*(\w+)?'
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
byte_string_re = re.compile(byte_string_pattern, re.IGNORECASE)


class ByteSize(int):
Expand Down Expand Up @@ -1759,7 +1760,16 @@ class MyModel(BaseModel):

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.with_info_plain_validator_function(cls._validate)
return core_schema.with_info_after_validator_function(
function=cls._validate,
schema=core_schema.union_schema(
[
core_schema.str_schema(pattern=byte_string_pattern),
core_schema.int_schema(ge=0),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(function=cls.to),
)

@classmethod
def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSize:
Expand Down
40 changes: 38 additions & 2 deletions tests/test_json_schema.py
Expand Up @@ -6,7 +6,14 @@
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum, IntEnum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -68,13 +75,22 @@
model_json_schema,
models_json_schema,
)
from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, MultiHostUrl, NameEmail
from pydantic.networks import (
AnyUrl,
EmailStr,
IPvAnyAddress,
IPvAnyInterface,
IPvAnyNetwork,
MultiHostUrl,
NameEmail,
)
from pydantic.type_adapter import TypeAdapter
from pydantic.types import (
UUID1,
UUID3,
UUID4,
UUID5,
ByteSize,
DirectoryPath,
FilePath,
Json,
Expand Down Expand Up @@ -1292,6 +1308,26 @@ class MyGenerator(GenerateJsonSchema):
assert model_schema['properties'] == properties


def test_byte_size_type():
class Model(BaseModel):
a: ByteSize

assert Model.model_json_schema() == {
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
'properties': {
'a': {
'anyOf': [
{'pattern': '^\\s*(\\d*\\.?\\d+)\\s*(\\w+)?', 'type': 'string'},
{'minimum': 0, 'type': 'integer'},
],
'title': 'A',
}
},
'required': ['a'],
'title': 'Model',
'type': 'object',
}


@pytest.mark.parametrize(
'type_,default_value,properties',
(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_types.py
Expand Up @@ -4435,6 +4435,7 @@ class FrozenSetModel(BaseModel):
@pytest.mark.parametrize(
'input_value,output,human_bin,human_dec',
(
(1, 1, '1B', '1B'),
('1', 1, '1B', '1B'),
('1.0', 1, '1B', '1B'),
('1b', 1, '1B', '1B'),
Expand Down Expand Up @@ -4473,7 +4474,7 @@ def test_bytesize_raises():
class Model(BaseModel):
size: ByteSize

with pytest.raises(ValidationError, match='parse value'):
with pytest.raises(ValidationError, match='should match'):
Model(size='d1MB')

with pytest.raises(ValidationError, match='byte unit'):
Expand Down