Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Base64Url types #7286

Merged
merged 1 commit into from Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 44 additions & 1 deletion docs/usage/types/encoded.md
Expand Up @@ -86,7 +86,8 @@ except ValidationError as e:

Internally, Pydantic uses the [`EncodedBytes`][pydantic.types.EncodedBytes] and [`EncodedStr`][pydantic.types.EncodedStr]
annotations with [`Base64Encoder`][pydantic.types.Base64Encoder] to implement base64 encoding/decoding in the
[`Base64Bytes`][pydantic.types.Base64Bytes] and [`Base64Str`][pydantic.types.Base64Str] types, respectively.
[`Base64Bytes`][pydantic.types.Base64Bytes], [`Base64UrlBytes`][pydantic.types.Base64UrlBytes],
[`Base64Str`][pydantic.types.Base64Str], and [`Base64UrlStr`][pydantic.types.Base64UrlStr] types.

```py
from typing import Optional
Expand Down Expand Up @@ -131,3 +132,45 @@ except ValidationError as e:
Base64 decoding error: 'Incorrect padding' [type=base64_decode, input_value=b'undecodable', input_type=bytes]
"""
```

If you need url-safe base64 encoding, you can use the `Base64UrlBytes` and `Base64UrlStr` types. The following snippet
demonstrates the difference in alphabets used by the url-safe and non-url-safe encodings:

```py
from pydantic import (
Base64Bytes,
Base64Str,
Base64UrlBytes,
Base64UrlStr,
BaseModel,
)


class Model(BaseModel):
base64_bytes: Base64Bytes
base64_str: Base64Str
base64url_bytes: Base64UrlBytes
base64url_str: Base64UrlStr


# Initialize the model with base64 data
m = Model(
base64_bytes=b'SHc/dHc+TXc==',
base64_str='SHc/dHc+TXc==',
base64url_bytes=b'SHc_dHc-TXc==',
base64url_str='SHc_dHc-TXc==',
)
print(m)
"""
base64_bytes=b'Hw?tw>Mw' base64_str='Hw?tw>Mw' base64url_bytes=b'Hw?tw>Mw' base64url_str='Hw?tw>Mw'
"""
```

!!! note
Under the hood, `Base64Bytes` and `Base64Str` use the standard library `base64.encodebytes` and `base64.decodebytes`
functions, while `Base64UrlBytes` and `Base64UrlStr` use the `base64.urlsafe_b64encode` and
`base64.urlsafe_b64decode` functions.

As a result, the `Base64UrlBytes` and `Base64UrlStr` types can be used to faithfully decode "vanilla" base64 data
(using `'+'` and `'/'`), but the reverse is not true — attempting to decode url-safe base64 data using the
`Base64Bytes` and `Base64Str` types may fail or produce an incorrect decoding.
2 changes: 2 additions & 0 deletions pydantic/__init__.py
Expand Up @@ -179,6 +179,8 @@
'Base64Encoder',
'Base64Bytes',
'Base64Str',
'Base64UrlBytes',
'Base64UrlStr',
'SkipValidation',
'InstanceOf',
'WithJsonSchema',
Expand Down
52 changes: 49 additions & 3 deletions pydantic/types.py
Expand Up @@ -96,6 +96,8 @@
'Base64Encoder',
'Base64Bytes',
'Base64Str',
'Base64UrlBytes',
'Base64UrlStr',
'GetPydanticSchema',
'StringConstraints',
)
Expand Down Expand Up @@ -1233,7 +1235,7 @@ def get_json_format(cls) -> str:


class Base64Encoder(EncoderProtocol):
"""Base64 encoder."""
"""Standard (non-URL-safe) Base64 encoder."""

@classmethod
def decode(cls, data: bytes) -> bytes:
Expand Down Expand Up @@ -1272,6 +1274,46 @@ def get_json_format(cls) -> Literal['base64']:
return 'base64'


class Base64UrlEncoder(EncoderProtocol):
    """URL-safe Base64 encoder.

    Delegates to the standard library's `base64.urlsafe_b64encode` and `base64.urlsafe_b64decode`.
    """

    @classmethod
    def decode(cls, data: bytes) -> bytes:
        """Decode URL-safe base64 encoded bytes back to the original bytes.

        Args:
            data: The base64 encoded data to decode.

        Returns:
            The decoded data.

        Raises:
            PydanticCustomError: With type `'base64_decode'` when the standard library decoder
                rejects `data` with a `ValueError`.
        """
        try:
            return base64.urlsafe_b64decode(data)
        except ValueError as e:
            # Chain explicitly so the underlying stdlib failure stays attached as `__cause__`.
            raise PydanticCustomError('base64_decode', "Base64 decoding error: '{error}'", {'error': str(e)}) from e

    @classmethod
    def encode(cls, value: bytes) -> bytes:
        """Encode bytes as URL-safe base64 encoded bytes.

        Args:
            value: The data to encode.

        Returns:
            The encoded data.
        """
        return base64.urlsafe_b64encode(value)

    @classmethod
    def get_json_format(cls) -> Literal['base64url']:
        """Get the JSON format for the encoded data.

        Returns:
            The JSON format for the encoded data, always `'base64url'`.
        """
        return 'base64url'


@_dataclasses.dataclass(**_internal_dataclass.slots_true)
class EncodedBytes:
"""A bytes type that is encoded and decoded using the specified encoder."""
Expand Down Expand Up @@ -1356,9 +1398,13 @@ def encode_str(self, value: str) -> str:


# Standard-alphabet aliases: `bytes`/`str` annotated with `Base64Encoder`.
Base64Bytes = Annotated[bytes, EncodedBytes(encoder=Base64Encoder)]
"""A bytes type that is encoded and decoded using the base64 encoder."""
"""A bytes type that is encoded and decoded using the standard (non-URL-safe) base64 encoder."""
Base64Str = Annotated[str, EncodedStr(encoder=Base64Encoder)]
"""A str type that is encoded and decoded using the base64 encoder."""
"""A str type that is encoded and decoded using the standard (non-URL-safe) base64 encoder."""
# URL-safe aliases: same wrappers, but annotated with `Base64UrlEncoder`.
Base64UrlBytes = Annotated[bytes, EncodedBytes(encoder=Base64UrlEncoder)]
"""A bytes type that is encoded and decoded using the URL-safe base64 encoder."""
Base64UrlStr = Annotated[str, EncodedStr(encoder=Base64UrlEncoder)]
"""A str type that is encoded and decoded using the URL-safe base64 encoder."""


__getattr__ = getattr_migration(__name__)
Expand Down
102 changes: 102 additions & 0 deletions tests/test_types.py
Expand Up @@ -50,6 +50,8 @@
AwareDatetime,
Base64Bytes,
Base64Str,
Base64UrlBytes,
Base64UrlStr,
BaseModel,
ByteSize,
ConfigDict,
Expand Down Expand Up @@ -4882,6 +4884,13 @@ class Model(BaseModel):
pytest.param(
Base64Str, bytearray(b'Zm9vIGJhcg=='), 'foo bar', 'Zm9vIGJhcg==\n', id='Base64Str-bytearray-input'
),
pytest.param(
Base64Bytes,
b'BCq+6+1/Paun/Q==',
b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd',
b'BCq+6+1/Paun/Q==\n',
id='Base64Bytes-bytes-alphabet-vanilla',
),
],
)
def test_base64(field_type, input_data, expected_value, serialized_data):
Expand Down Expand Up @@ -4946,6 +4955,99 @@ class Model(BaseModel):
]


@pytest.mark.parametrize(
    ('field_type', 'input_data', 'expected_value', 'serialized_data'),
    [
        pytest.param(Base64UrlBytes, b'Zm9vIGJhcg==\n', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-reversible'),
        pytest.param(Base64UrlStr, 'Zm9vIGJhcg==\n', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-reversible'),
        pytest.param(Base64UrlBytes, b'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-bytes-input'),
        pytest.param(Base64UrlBytes, 'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-str-input'),
        pytest.param(
            Base64UrlBytes, bytearray(b'Zm9vIGJhcg=='), b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-bytearray-input'
        ),
        pytest.param(Base64UrlStr, b'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-bytes-input'),
        pytest.param(Base64UrlStr, 'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-str-input'),
        pytest.param(
            Base64UrlStr, bytearray(b'Zm9vIGJhcg=='), 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-bytearray-input'
        ),
        pytest.param(
            Base64UrlBytes,
            b'BCq-6-1_Paun_Q==',
            b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd',
            b'BCq-6-1_Paun_Q==',
            id='Base64UrlBytes-bytes-alphabet-url',
        ),
        pytest.param(
            Base64UrlBytes,
            b'BCq+6+1/Paun/Q==',
            b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd',
            b'BCq-6-1_Paun_Q==',
            id='Base64UrlBytes-bytes-alphabet-vanilla',
        ),
    ],
)
def test_base64url(field_type, input_data, expected_value, serialized_data):
    """Check validation, serialization and JSON schema of the URL-safe base64 field types."""

    class Model(BaseModel):
        base64url_value: field_type
        base64url_value_or_none: Optional[field_type] = None

    # Validation decodes the encoded input into the plain value.
    validated = Model(base64url_value=input_data)
    assert validated.base64url_value == expected_value

    # `model_construct` stores the already-decoded value as given.
    constructed = Model.model_construct(base64url_value=expected_value)
    assert constructed.base64url_value == expected_value

    # Dumping re-encodes the value; the optional field keeps its `None` default.
    expected_dump = {
        'base64url_value': serialized_data,
        'base64url_value_or_none': None,
    }
    assert constructed.model_dump() == expected_dump

    string_schema = {'type': 'string', 'format': 'base64url'}
    expected_schema = {
        'title': 'Model',
        'type': 'object',
        'required': ['base64url_value'],
        'properties': {
            'base64url_value': {**string_schema, 'title': 'Base64Url Value'},
            'base64url_value_or_none': {
                'anyOf': [dict(string_schema), {'type': 'null'}],
                'default': None,
                'title': 'Base64Url Value Or None',
            },
        },
    }
    assert Model.model_json_schema() == expected_schema


@pytest.mark.parametrize(
    ('field_type', 'input_data'),
    [
        pytest.param(Base64UrlBytes, b'Zm9vIGJhcg', id='Base64UrlBytes-invalid-base64-bytes'),
        pytest.param(Base64UrlBytes, 'Zm9vIGJhcg', id='Base64UrlBytes-invalid-base64-str'),
        pytest.param(Base64UrlStr, b'Zm9vIGJhcg', id='Base64UrlStr-invalid-base64-bytes'),
        pytest.param(Base64UrlStr, 'Zm9vIGJhcg', id='Base64UrlStr-invalid-base64-str'),
    ],
)
def test_base64url_invalid(field_type, input_data):
    """Check that unpadded input is rejected with a single `base64_decode` error."""

    class Model(BaseModel):
        base64url_value: field_type

    expected_error = {
        'type': 'base64_decode',
        'loc': ('base64url_value',),
        'msg': "Base64 decoding error: 'Incorrect padding'",
        'input': input_data,
        'ctx': {'error': 'Incorrect padding'},
    }

    with pytest.raises(ValidationError) as exc_info:
        Model(base64url_value=input_data)

    reported = exc_info.value.errors(include_url=False)
    assert reported == [expected_error]


def test_sequence_subclass_without_core_schema() -> None:
class MyList(List[int]):
# The point of this is that subclasses can do arbitrary things
Expand Down