Skip to content

Commit

Permalink
Support json schema generation for Enum types with no cases. (#7927)
Browse files Browse the repository at this point in the history
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
sydney-runkle and dmontagu committed Oct 25, 2023
1 parent cf213a7 commit 026be6a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 30 deletions.
37 changes: 19 additions & 18 deletions pydantic/_internal/_std_types_schema.py
Expand Up @@ -59,13 +59,27 @@ def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchem
def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
cases: list[Any] = list(enum_type.__members__.values())

enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
updates = {'title': enum_type.__name__, 'description': description}
updates = {k: v for k, v in updates.items() if v is not None}

def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
update_json_schema(original_schema, updates)
return json_schema

if not cases:
# Use an isinstance check for enums with no cases.
# This won't work with serialization or JSON schema, but that's okay -- the most important
# use case for this is creating typevar bounds for generics that should be restricted to enums.
# This is more consistent than it might seem at first, since you can only subclass enum.Enum
# (or subclasses of enum.Enum) if all parent classes have no cases.
return core_schema.is_instance_schema(enum_type)
# The most important use case for this is creating TypeVar bounds for generics that should
# be restricted to enums. This is more consistent than it might seem at first, since you can only
# subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
# We use the get_json_schema function when an Enum subclass has been declared with no cases
# so that we can still generate a valid json schema.
return core_schema.is_instance_schema(enum_type, metadata={'pydantic_js_functions': [get_json_schema]})

use_enum_values = config.get('use_enum_values', False)

Expand All @@ -84,19 +98,6 @@ def to_enum(__input_value: Any) -> Enum:
# The type: ignore on the next line is to ignore the requirement of LiteralString
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore

enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
updates = {'title': enum_type.__name__, 'description': description}
updates = {k: v for k, v in updates.items() if v is not None}

def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
update_json_schema(original_schema, updates)
return json_schema

strict_python_schema = core_schema.is_instance_schema(enum_type)
if use_enum_values:
strict_python_schema = core_schema.chain_schema(
Expand Down
23 changes: 11 additions & 12 deletions tests/test_types.py
Expand Up @@ -1621,12 +1621,6 @@ class MyEnum(Enum):
}
]

with pytest.raises(
PydanticInvalidForJsonSchema,
match=re.escape("Cannot generate a JsonSchema for core_schema.IsInstanceSchema (<enum 'Enum'>)"),
):
Model.model_json_schema()


def test_int_enum_type():
class Model(BaseModel):
Expand Down Expand Up @@ -1655,12 +1649,6 @@ class MyIntEnum(IntEnum):
}
]

with pytest.raises(
PydanticInvalidForJsonSchema,
match=re.escape("Cannot generate a JsonSchema for core_schema.IsInstanceSchema (<enum 'IntEnum'>)"),
):
Model.model_json_schema()


@pytest.mark.parametrize('enum_base,strict', [(Enum, False), (IntEnum, False), (IntEnum, True)])
def test_enum_from_json(enum_base, strict):
Expand Down Expand Up @@ -1723,6 +1711,17 @@ class User(BaseModel):
User(demo_strict=0, demo_not_strict=1)


def test_enum_with_no_cases() -> None:
class MyEnum(Enum):
pass

class MyModel(BaseModel):
e: MyEnum

json_schema = MyModel.model_json_schema()
assert json_schema['properties']['e']['enum'] == []


@pytest.mark.parametrize(
'kwargs,type_',
[
Expand Down

0 comments on commit 026be6a

Please sign in to comment.