Skip to content

Commit

Permalink
Add Secret base type (#8519)
Browse files Browse the repository at this point in the history
Co-authored-by: sydney-runkle <sydneymarierunkle@gmail.com>
Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
4 people committed Feb 9, 2024
1 parent 5463055 commit 1c91c86
Show file tree
Hide file tree
Showing 5 changed files with 433 additions and 42 deletions.
3 changes: 3 additions & 0 deletions docs/api/types.md
@@ -1 +1,4 @@
::: pydantic.types
options:
show_root_heading: true
merge_init_into_class: false
100 changes: 100 additions & 0 deletions docs/examples/secrets.md
Expand Up @@ -37,3 +37,103 @@ print(model.model_dump())
print(model.model_dump_json())
#> {"password":"IAmSensitive","password_bytes":"IAmSensitiveBytes"}
```

## Create your own Secret field

Pydantic provides the generic `Secret` class as a mechanism for creating custom secret types.

??? api "API Documentation"
[`pydantic.types.Secret`][pydantic.types.Secret]<br>

Pydantic provides the generic `Secret` class as a mechanism for creating custom secret types.
You can either directly parametrize `Secret`, or subclass from a parametrized `Secret` to customize the `str()` and `repr()` of a secret type.

```py
from datetime import date

from pydantic import BaseModel, Secret

# Using the default representation
SecretDate = Secret[date]


# Overwriting the representation
class SecretSalary(Secret[float]):
def _display(self) -> str:
return '$****.**'


class Employee(BaseModel):
date_of_birth: SecretDate
salary: SecretSalary


employee = Employee(date_of_birth='1990-01-01', salary=42)

print(employee)
#> date_of_birth=Secret('**********') salary=SecretSalary('$****.**')

print(employee.salary)
#> $****.**

print(employee.salary.get_secret_value())
#> 42.0

print(employee.date_of_birth)
#> **********

print(employee.date_of_birth.get_secret_value())
#> 1990-01-01
```

You can enforce constraints on the underlying type through annotations:
For example:

```py
from typing_extensions import Annotated

from pydantic import BaseModel, Field, Secret, ValidationError

SecretPosInt = Secret[Annotated[int, Field(gt=0, strict=True)]]


class Model(BaseModel):
sensitive_int: SecretPosInt


m = Model(sensitive_int=42)
print(m.model_dump())
#> {'sensitive_int': Secret('**********')}

try:
m = Model(sensitive_int=-42) # (1)!
except ValidationError as exc_info:
print(exc_info.errors(include_url=False, include_input=False))
"""
[
{
'type': 'greater_than',
'loc': ('sensitive_int',),
'msg': 'Input should be greater than 0',
'ctx': {'gt': 0},
}
]
"""

try:
m = Model(sensitive_int='42') # (2)!
except ValidationError as exc_info:
print(exc_info.errors(include_url=False, include_input=False))
"""
[
{
'type': 'int_type',
'loc': ('sensitive_int',),
'msg': 'Input should be a valid integer',
}
]
"""
```

1. The input value is not greater than 0, so it raises a validation error.
2. The input value is not an integer, so it raises a validation error because the `SecretPosInt` type has strict mode enabled.
2 changes: 2 additions & 0 deletions pydantic/__init__.py
Expand Up @@ -162,6 +162,7 @@
'DirectoryPath',
'NewPath',
'Json',
'Secret',
'SecretStr',
'SecretBytes',
'StrictBool',
Expand Down Expand Up @@ -310,6 +311,7 @@
'DirectoryPath': (__package__, '.types'),
'NewPath': (__package__, '.types'),
'Json': (__package__, '.types'),
'Secret': (__package__, '.types'),
'SecretStr': (__package__, '.types'),
'SecretBytes': (__package__, '.types'),
'StrictBool': (__package__, '.types'),
Expand Down
191 changes: 152 additions & 39 deletions pydantic/types.py
Expand Up @@ -24,6 +24,8 @@
TypeVar,
Union,
cast,
get_args,
get_origin,
)
from uuid import UUID

Expand Down Expand Up @@ -75,6 +77,7 @@
'DirectoryPath',
'NewPath',
'Json',
'Secret',
'SecretStr',
'SecretBytes',
'StrictBool',
Expand Down Expand Up @@ -1332,7 +1335,8 @@ class Model(BaseModel):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

if TYPE_CHECKING:
Json = Annotated[AnyType, ...] # Json[list[str]] will be recognized by type checkers as list[str]
# Json[list[str]] will be recognized by type checkers as list[str]
Json = Annotated[AnyType, ...]

else:

Expand Down Expand Up @@ -1439,10 +1443,10 @@ def __eq__(self, other: Any) -> bool:

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

SecretType = TypeVar('SecretType', str, bytes)
SecretType = TypeVar('SecretType')


class _SecretField(Generic[SecretType]):
class _SecretBase(Generic[SecretType]):
def __init__(self, secret_value: SecretType) -> None:
self._secret_value: SecretType = secret_value

Expand All @@ -1460,29 +1464,124 @@ def __eq__(self, other: Any) -> bool:
def __hash__(self) -> int:
return hash(self.get_secret_value())

def __len__(self) -> int:
return len(self._secret_value)

def __str__(self) -> str:
return str(self._display())

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self._display()!r})'

def _display(self) -> SecretType:
def _display(self) -> str | bytes:
raise NotImplementedError


class Secret(_SecretBase[SecretType]):
"""A generic base class used for defining a field with sensitive information that you do not want to be visible in logging or tracebacks.
You may either directly parametrize `Secret` with a type, or subclass from `Secret` with a parametrized type. The benefit of subclassing
is that you can define a custom `_display` method, which will be used for `repr()` and `str()` methods. The examples below demonstrate both
ways of using `Secret` to create a new secret type.
1. Directly parametrizing `Secret` with a type:
```py
from pydantic import BaseModel, Secret
SecretBool = Secret[bool]
class Model(BaseModel):
secret_bool: SecretBool
m = Model(secret_bool=True)
print(m.model_dump())
#> {'secret_bool': Secret('**********')}
print(m.model_dump_json())
#> {"secret_bool":"**********"}
print(m.secret_bool.get_secret_value())
#> True
```
2. Subclassing from parametrized `Secret`:
```py
from datetime import date
from pydantic import BaseModel, Secret
class SecretDate(Secret[date]):
def _display(self) -> str:
return '****/**/**'
class Model(BaseModel):
secret_date: SecretDate
m = Model(secret_date=date(2022, 1, 1))
print(m.model_dump())
#> {'secret_date': SecretDate('****/**/**')}
print(m.model_dump_json())
#> {"secret_date":"****/**/**"}
print(m.secret_date.get_secret_value())
#> 2022-01-01
```
The value returned by the `_display` method will be used for `repr()` and `str()`.
"""

def _display(self) -> str | bytes:
return '**********' if self.get_secret_value() else ''

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
if issubclass(source, SecretStr):
field_type = str
inner_schema = core_schema.str_schema()
inner_type = None
# if origin_type is Secret, then cls is a GenericAlias, and we can extract the inner type directly
origin_type = get_origin(source)
if origin_type is not None:
inner_type = get_args(source)[0]
# otherwise, we need to get the inner type from the base class
else:
assert issubclass(source, SecretBytes)
field_type = bytes
inner_schema = core_schema.bytes_schema()
error_kind = 'string_type' if field_type is str else 'bytes_type'
bases = getattr(cls, '__orig_bases__', getattr(cls, '__bases__', []))
for base in bases:
if get_origin(base) is Secret:
inner_type = get_args(base)[0]
if bases == [] or inner_type is None:
raise TypeError(
f"Can't get secret type from {cls.__name__}. "
'Please use Secret[<type>], or subclass from Secret[<type>] instead.'
)

inner_schema = handler.generate_schema(inner_type) # type: ignore

def validate_secret_value(value, handler) -> Secret[SecretType]:
if isinstance(value, Secret):
value = value.get_secret_value()
validated_inner = handler(value)
return cls(validated_inner)

return core_schema.json_or_python_schema(
python_schema=core_schema.no_info_wrap_validator_function(
validate_secret_value,
inner_schema,
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x),
),
json_schema=core_schema.no_info_after_validator_function(
lambda x: cls(x), inner_schema, serialization=core_schema.to_string_ser_schema(when_used='json')
),
)


def _secret_display(value: SecretType) -> str: # type: ignore
return '**********' if value else ''


class _SecretField(_SecretBase[SecretType]):
_inner_schema: ClassVar[CoreSchema]
_error_kind: ClassVar[str]

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def serialize(
value: _SecretField[SecretType], info: core_schema.SerializationInfo
) -> str | _SecretField[SecretType]:
Expand All @@ -1494,7 +1593,7 @@ def serialize(
return value

def get_json_schema(_core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(inner_schema)
json_schema = handler(cls._inner_schema)
_utils.update_not_none(
json_schema,
type='string',
Expand All @@ -1505,31 +1604,33 @@ def get_json_schema(_core_schema: core_schema.CoreSchema, handler: GetJsonSchema

json_schema = core_schema.no_info_after_validator_function(
source, # construct the type
inner_schema,
cls._inner_schema,
)
s = core_schema.json_or_python_schema(
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(source),
json_schema,
],
strict=True,
custom_error_type=error_kind,
),
json_schema=json_schema,
serialization=core_schema.plain_serializer_function_ser_schema(
serialize,
info_arg=True,
return_schema=core_schema.str_schema(),
when_used='json',
),
)
s.setdefault('metadata', {}).setdefault('pydantic_js_functions', []).append(get_json_schema)
return s

def get_secret_schema(strict: bool) -> CoreSchema:
return core_schema.json_or_python_schema(
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(source),
json_schema,
],
custom_error_type=cls._error_kind,
strict=strict,
),
json_schema=json_schema,
serialization=core_schema.plain_serializer_function_ser_schema(
serialize,
info_arg=True,
return_schema=core_schema.str_schema(),
when_used='json',
),
)

def _secret_display(value: str | bytes) -> str:
return '**********' if value else ''
return core_schema.lax_or_strict_schema(
lax_schema=get_secret_schema(strict=False),
strict_schema=get_secret_schema(strict=True),
metadata={'pydantic_js_functions': [get_json_schema]},
)


class SecretStr(_SecretField[str]):
Expand All @@ -1556,8 +1657,14 @@ class User(BaseModel):
```
"""

_inner_schema: ClassVar[CoreSchema] = core_schema.str_schema()
_error_kind: ClassVar[str] = 'string_type'

def __len__(self) -> int:
return len(self._secret_value)

def _display(self) -> str:
return _secret_display(self.get_secret_value())
return _secret_display(self._secret_value)


class SecretBytes(_SecretField[bytes]):
Expand All @@ -1583,8 +1690,14 @@ class User(BaseModel):
```
"""

_inner_schema: ClassVar[CoreSchema] = core_schema.bytes_schema()
_error_kind: ClassVar[str] = 'bytes_type'

def __len__(self) -> int:
return len(self._secret_value)

def _display(self) -> bytes:
return _secret_display(self.get_secret_value()).encode()
return _secret_display(self._secret_value).encode()


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down

0 comments on commit 1c91c86

Please sign in to comment.