Skip to content

Commit

Permalink
Add Base64Url types (#7286)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Aug 29, 2023
1 parent 2575e71 commit 84282ef
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 4 deletions.
45 changes: 44 additions & 1 deletion docs/usage/types/encoded.md
Expand Up @@ -86,7 +86,8 @@ except ValidationError as e:

Internally, Pydantic uses the [`EncodedBytes`][pydantic.types.EncodedBytes] and [`EncodedStr`][pydantic.types.EncodedStr]
annotations with [`Base64Encoder`][pydantic.types.Base64Encoder] to implement base64 encoding/decoding in the
[`Base64Bytes`][pydantic.types.Base64Bytes] and [`Base64Str`][pydantic.types.Base64Str] types, respectively.
[`Base64Bytes`][pydantic.types.Base64Bytes], [`Base64UrlBytes`][pydantic.types.Base64UrlBytes],
[`Base64Str`][pydantic.types.Base64Str], and [`Base64UrlStr`][pydantic.types.Base64Str] types.

```py
from typing import Optional
Expand Down Expand Up @@ -131,3 +132,45 @@ except ValidationError as e:
Base64 decoding error: 'Incorrect padding' [type=base64_decode, input_value=b'undecodable', input_type=bytes]
"""
```

If you need url-safe base64 encoding, you can use the `Base64UrlBytes` and `Base64UrlStr` types. The following snippet
demonstrates the difference in alphabets used by the url-safe and non-url-safe encodings:

```py
from pydantic import (
Base64Bytes,
Base64Str,
Base64UrlBytes,
Base64UrlStr,
BaseModel,
)


class Model(BaseModel):
base64_bytes: Base64Bytes
base64_str: Base64Str
base64url_bytes: Base64UrlBytes
base64url_str: Base64UrlStr


# Initialize the model with base64 data
m = Model(
base64_bytes=b'SHc/dHc+TXc==',
base64_str='SHc/dHc+TXc==',
base64url_bytes=b'SHc_dHc-TXc==',
base64url_str='SHc_dHc-TXc==',
)
print(m)
"""
base64_bytes=b'Hw?tw>Mw' base64_str='Hw?tw>Mw' base64url_bytes=b'Hw?tw>Mw' base64url_str='Hw?tw>Mw'
"""
```

!!! note
Under the hood, `Base64Bytes` and `Base64Str` use the standard library `base64.encodebytes` and `base64.decodebytes`
functions, while `Base64UrlBytes` and `Base64UrlStr` use the `base64.urlsafe_b64encode` and
`base64.urlsafe_b64decode` functions.

As a result, the `Base64UrlBytes` and `Base64UrlStr` types can be used to faithfully decode "vanilla" base64 data
(using `'+'` and `'/'`), but the reverse is not true — attempting to decode url-safe base64 data using the
`Base64Bytes` and `Base64Str` types may fail or produce an incorrect decoding.
2 changes: 2 additions & 0 deletions pydantic/__init__.py
Expand Up @@ -179,6 +179,8 @@
'Base64Encoder',
'Base64Bytes',
'Base64Str',
'Base64UrlBytes',
'Base64UrlStr',
'SkipValidation',
'InstanceOf',
'WithJsonSchema',
Expand Down
52 changes: 49 additions & 3 deletions pydantic/types.py
Expand Up @@ -96,6 +96,8 @@
'Base64Encoder',
'Base64Bytes',
'Base64Str',
'Base64UrlBytes',
'Base64UrlStr',
'GetPydanticSchema',
'StringConstraints',
)
Expand Down Expand Up @@ -1233,7 +1235,7 @@ def get_json_format(cls) -> str:


class Base64Encoder(EncoderProtocol):
"""Base64 encoder."""
"""Standard (non-URL-safe) Base64 encoder."""

@classmethod
def decode(cls, data: bytes) -> bytes:
Expand Down Expand Up @@ -1272,6 +1274,46 @@ def get_json_format(cls) -> Literal['base64']:
return 'base64'


class Base64UrlEncoder(EncoderProtocol):
"""URL-safe Base64 encoder."""

@classmethod
def decode(cls, data: bytes) -> bytes:
"""Decode the data from base64 encoded bytes to original bytes data.
Args:
data: The data to decode.
Returns:
The decoded data.
"""
try:
return base64.urlsafe_b64decode(data)
except ValueError as e:
raise PydanticCustomError('base64_decode', "Base64 decoding error: '{error}'", {'error': str(e)})

@classmethod
def encode(cls, value: bytes) -> bytes:
"""Encode the data from bytes to a base64 encoded bytes.
Args:
value: The data to encode.
Returns:
The encoded data.
"""
return base64.urlsafe_b64encode(value)

@classmethod
def get_json_format(cls) -> Literal['base64url']:
"""Get the JSON format for the encoded data.
Returns:
The JSON format for the encoded data.
"""
return 'base64url'


@_dataclasses.dataclass(**_internal_dataclass.slots_true)
class EncodedBytes:
"""A bytes type that is encoded and decoded using the specified encoder."""
Expand Down Expand Up @@ -1356,9 +1398,13 @@ def encode_str(self, value: str) -> str:


Base64Bytes = Annotated[bytes, EncodedBytes(encoder=Base64Encoder)]
"""A bytes type that is encoded and decoded using the base64 encoder."""
"""A bytes type that is encoded and decoded using the standard (non-URL-safe) base64 encoder."""
Base64Str = Annotated[str, EncodedStr(encoder=Base64Encoder)]
"""A str type that is encoded and decoded using the base64 encoder."""
"""A str type that is encoded and decoded using the standard (non-URL-safe) base64 encoder."""
Base64UrlBytes = Annotated[bytes, EncodedBytes(encoder=Base64UrlEncoder)]
"""A bytes type that is encoded and decoded using the URL-safe base64 encoder."""
Base64UrlStr = Annotated[str, EncodedStr(encoder=Base64UrlEncoder)]
"""A str type that is encoded and decoded using the URL-safe base64 encoder."""


__getattr__ = getattr_migration(__name__)
Expand Down
102 changes: 102 additions & 0 deletions tests/test_types.py
Expand Up @@ -50,6 +50,8 @@
AwareDatetime,
Base64Bytes,
Base64Str,
Base64UrlBytes,
Base64UrlStr,
BaseModel,
ByteSize,
ConfigDict,
Expand Down Expand Up @@ -4882,6 +4884,13 @@ class Model(BaseModel):
pytest.param(
Base64Str, bytearray(b'Zm9vIGJhcg=='), 'foo bar', 'Zm9vIGJhcg==\n', id='Base64Str-bytearray-input'
),
pytest.param(
Base64Bytes,
b'BCq+6+1/Paun/Q==',
b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd',
b'BCq+6+1/Paun/Q==\n',
id='Base64Bytes-bytes-alphabet-vanilla',
),
],
)
def test_base64(field_type, input_data, expected_value, serialized_data):
Expand Down Expand Up @@ -4946,6 +4955,99 @@ class Model(BaseModel):
]


@pytest.mark.parametrize(
('field_type', 'input_data', 'expected_value', 'serialized_data'),
[
pytest.param(Base64UrlBytes, b'Zm9vIGJhcg==\n', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-reversible'),
pytest.param(Base64UrlStr, 'Zm9vIGJhcg==\n', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-reversible'),
pytest.param(Base64UrlBytes, b'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-bytes-input'),
pytest.param(Base64UrlBytes, 'Zm9vIGJhcg==', b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-str-input'),
pytest.param(
Base64UrlBytes, bytearray(b'Zm9vIGJhcg=='), b'foo bar', b'Zm9vIGJhcg==', id='Base64UrlBytes-bytearray-input'
),
pytest.param(Base64UrlStr, b'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-bytes-input'),
pytest.param(Base64UrlStr, 'Zm9vIGJhcg==', 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-str-input'),
pytest.param(
Base64UrlStr, bytearray(b'Zm9vIGJhcg=='), 'foo bar', 'Zm9vIGJhcg==', id='Base64UrlStr-bytearray-input'
),
pytest.param(
Base64UrlBytes,
b'BCq-6-1_Paun_Q==',
b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd',
b'BCq-6-1_Paun_Q==',
id='Base64UrlBytes-bytes-alphabet-url',
),
pytest.param(
Base64UrlBytes,
b'BCq+6+1/Paun/Q==',
b'\x04*\xbe\xeb\xed\x7f=\xab\xa7\xfd',
b'BCq-6-1_Paun_Q==',
id='Base64UrlBytes-bytes-alphabet-vanilla',
),
],
)
def test_base64url(field_type, input_data, expected_value, serialized_data):
class Model(BaseModel):
base64url_value: field_type
base64url_value_or_none: Optional[field_type] = None

m = Model(base64url_value=input_data)
assert m.base64url_value == expected_value

m = Model.model_construct(base64url_value=expected_value)
assert m.base64url_value == expected_value

assert m.model_dump() == {
'base64url_value': serialized_data,
'base64url_value_or_none': None,
}

assert Model.model_json_schema() == {
'properties': {
'base64url_value': {
'format': 'base64url',
'title': 'Base64Url Value',
'type': 'string',
},
'base64url_value_or_none': {
'anyOf': [{'type': 'string', 'format': 'base64url'}, {'type': 'null'}],
'default': None,
'title': 'Base64Url Value Or None',
},
},
'required': ['base64url_value'],
'title': 'Model',
'type': 'object',
}


@pytest.mark.parametrize(
('field_type', 'input_data'),
[
pytest.param(Base64UrlBytes, b'Zm9vIGJhcg', id='Base64UrlBytes-invalid-base64-bytes'),
pytest.param(Base64UrlBytes, 'Zm9vIGJhcg', id='Base64UrlBytes-invalid-base64-str'),
pytest.param(Base64UrlStr, b'Zm9vIGJhcg', id='Base64UrlStr-invalid-base64-bytes'),
pytest.param(Base64UrlStr, 'Zm9vIGJhcg', id='Base64UrlStr-invalid-base64-str'),
],
)
def test_base64url_invalid(field_type, input_data):
class Model(BaseModel):
base64url_value: field_type

with pytest.raises(ValidationError) as e:
Model(base64url_value=input_data)

assert e.value.errors(include_url=False) == [
{
'ctx': {'error': 'Incorrect padding'},
'input': input_data,
'loc': ('base64url_value',),
'msg': "Base64 decoding error: 'Incorrect padding'",
'type': 'base64_decode',
},
]


def test_sequence_subclass_without_core_schema() -> None:
class MyList(List[int]):
# The point of this is that subclasses can do arbitrary things
Expand Down

0 comments on commit 84282ef

Please sign in to comment.