Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tool for generating schema type annotations #7186

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
5 changes: 4 additions & 1 deletion edb/schema/annos.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@
from . import referencing
from . import objects as so
from . import utils
from .generated import annos as sg_annos

if TYPE_CHECKING:
from . import schema as s_schema


class AnnotationValue(
sg_annos.AnnotationValueMixin,
referencing.ReferencedInheritingObject,
qlkind=qltypes.SchemaObjectClass.ANNOTATION,
reflection=so.ReflectionMethod.AS_LINK,
Expand Down Expand Up @@ -101,7 +103,7 @@ def get_verbosename(
T = TypeVar("T")


class AnnotationSubject(so.Object):
class AnnotationSubject(so.Object, sg_annos.AnnotationSubjectMixin):

annotations_refs = so.RefDict(
attr='annotations',
Expand Down Expand Up @@ -177,6 +179,7 @@ def must_get_json_annotation(


class Annotation(
sg_annos.AnnotationMixin,
so.QualifiedObject,
so.InheritingObject,
AnnotationSubject,
Expand Down
2 changes: 2 additions & 0 deletions edb/schema/casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import types as s_types
from . import schema as s_schema
from . import utils
from .generated import casts as sg_casts


_NOT_REACHABLE = 10000000
Expand Down Expand Up @@ -206,6 +207,7 @@ def get_cast_fullname(


class Cast(
sg_casts.CastMixin,
so.QualifiedObject,
s_anno.AnnotationSubject,
s_func.VolatilitySubject,
Expand Down
6 changes: 5 additions & 1 deletion edb/schema/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from . import pseudo as s_pseudo
from . import referencing
from . import utils
from .generated import constraints as sg_constraints


if TYPE_CHECKING:
Expand Down Expand Up @@ -122,8 +123,10 @@ def get_key_for_name(


class Constraint(
sg_constraints.ConstraintMixin,
referencing.ReferencedInheritingObject,
s_func.CallableObject, s_abc.Constraint,
s_func.CallableObject,
s_abc.Constraint,
qlkind=ft.SchemaObjectClass.CONSTRAINT,
data_safe=True,
):
Expand Down Expand Up @@ -364,6 +367,7 @@ def get_default_base_name(self) -> sn.QualName:


class ConsistencySubject(
sg_constraints.ConsistencySubjectMixin,
so.QualifiedObject,
so.InheritingObject,
s_anno.AnnotationSubject,
Expand Down
2 changes: 2 additions & 0 deletions edb/schema/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from . import delta as sd
from . import objects as so
from . import schema as s_schema
from .generated import database as sg_database

from typing import cast


class Database(
sg_database.DatabaseMixin,
so.ExternalObject,
s_anno.AnnotationSubject,
s_abc.Database,
Expand Down
2 changes: 2 additions & 0 deletions edb/schema/expraliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import name as sn
from . import objects as so
from . import types as s_types
from .generated import expraliases as sg_expraliases


if TYPE_CHECKING:
Expand All @@ -42,6 +43,7 @@


class Alias(
sg_expraliases.AliasMixin,
so.QualifiedObject,
s_anno.AnnotationSubject,
qlkind=qltypes.SchemaObjectClass.ALIAS,
Expand Down
3 changes: 3 additions & 0 deletions edb/schema/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@
from . import name as sn
from . import objects as so
from . import schema as s_schema
from .generated import extensions as sg_extensions


class ExtensionPackage(
sg_extensions.ExtensionPackageMixin,
so.GlobalObject,
s_anno.AnnotationSubject,
qlkind=qltypes.SchemaObjectClass.EXTENSION_PACKAGE,
Expand Down Expand Up @@ -102,6 +104,7 @@ def get_displayname_static(cls, name: sn.Name) -> str:


class Extension(
sg_extensions.ExtensionMixin,
so.Object,
qlkind=qltypes.SchemaObjectClass.EXTENSION,
data_safe=False,
Expand Down
37 changes: 10 additions & 27 deletions edb/schema/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from . import referencing
from . import types as s_types
from . import utils
from .generated import functions as sg_functions


if TYPE_CHECKING:
Expand Down Expand Up @@ -151,32 +152,8 @@ def param_is_inherited(
return qualname != param_name.name


class ParameterLike(s_abc.Parameter):

def get_parameter_name(self, schema: s_schema.Schema) -> str:
raise NotImplementedError

def get_name(self, schema: s_schema.Schema) -> sn.Name:
raise NotImplementedError

def get_kind(self, _: s_schema.Schema) -> ft.ParameterKind:
raise NotImplementedError

def get_default(self, _: s_schema.Schema) -> Optional[s_expr.Expression]:
raise NotImplementedError

def get_type(self, _: s_schema.Schema) -> s_types.Type:
raise NotImplementedError

def get_typemod(self, _: s_schema.Schema) -> ft.TypeModifier:
raise NotImplementedError

def as_str(self, schema: s_schema.Schema) -> str:
raise NotImplementedError


# Non-schema description of a parameter.
class ParameterDesc(ParameterLike):
class ParameterDesc(s_abc.Parameter):

num: int
name: sn.Name
Expand Down Expand Up @@ -364,9 +341,10 @@ def make_func_param(


class Parameter(
sg_functions.ParameterMixin,
so.ObjectFragment,
so.Object, # Help reflection figure out the right db MRO
ParameterLike,
s_abc.Parameter,
qlkind=ft.SchemaObjectClass.PARAMETER,
data_safe=True,
):
Expand Down Expand Up @@ -489,6 +467,9 @@ def get_ast(self, schema: s_schema.Schema) -> qlast.FuncParam:
)


ParameterLike = ParameterDesc | Parameter


class CallableCommandContext(sd.ObjectCommandContext['CallableObject'],
s_anno.AnnotationSubjectCommandContext):
pass
Expand Down Expand Up @@ -756,7 +737,7 @@ def compare_values(
return 1.0


class VolatilitySubject(so.Object):
class VolatilitySubject(so.Object, sg_functions.VolatilitySubjectMixin):

volatility = so.SchemaField(
ft.Volatility, default=ft.Volatility.Volatile,
Expand Down Expand Up @@ -792,6 +773,7 @@ def get_abstract(self, schema: s_schema.Schema) -> bool:


class CallableObject(
sg_functions.CallableObjectMixin,
so.QualifiedObject,
s_anno.AnnotationSubject,
CallableLike,
Expand Down Expand Up @@ -1226,6 +1208,7 @@ def _delete_begin(


class Function(
sg_functions.FunctionMixin,
CallableObject,
VolatilitySubject,
s_abc.Function,
Expand Down
2 changes: 2 additions & 0 deletions edb/schema/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from . import objects as so
from . import name as sn
from . import schema as s_schema
from .generated import futures as sg_futures


class FutureBehavior(
sg_futures.FutureBehaviorMixin,
so.Object,
qlkind=qltypes.SchemaObjectClass.FUTURE,
data_safe=False,
Expand Down
57 changes: 57 additions & 0 deletions edb/schema/generated/annos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# DO NOT EDIT. This file was generated with:
#
# $ edb gen-schema-mixins

"""Type definitions for generated methods on schema classes"""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from edb.schema import schema as s_schema
from edb.schema import orm as s_orm
from edb.schema import objects
from edb.schema import annos


class AnnotationValueMixin:

def get_subject(
self, schema: 's_schema.Schema'
) -> 'objects.Object':
return s_orm.get_field_value( # type: ignore
self, schema, 'subject' # type: ignore
)

def get_annotation(
self, schema: 's_schema.Schema'
) -> 'annos.Annotation':
return s_orm.get_field_value( # type: ignore
self, schema, 'annotation' # type: ignore
)

def get_value(
self, schema: 's_schema.Schema'
) -> 'str':
return s_orm.get_field_value( # type: ignore
self, schema, 'value' # type: ignore
)


class AnnotationSubjectMixin:

def get_annotations(
self, schema: 's_schema.Schema'
) -> 'objects.ObjectIndexByShortname[annos.AnnotationValue]':
return s_orm.get_field_value( # type: ignore
self, schema, 'annotations' # type: ignore
)


class AnnotationMixin:

def get_inheritable(
self, schema: 's_schema.Schema'
) -> 'bool':
return s_orm.get_field_value( # type: ignore
self, schema, 'inheritable' # type: ignore
)
79 changes: 79 additions & 0 deletions edb/schema/generated/casts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# DO NOT EDIT. This file was generated with:
#
# $ edb gen-schema-mixins

"""Type definitions for generated methods on schema classes"""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from edb.schema import schema as s_schema
from edb.schema import orm as s_orm
from edb.edgeql import ast
from edb.schema import types


class CastMixin:

def get_from_type(
self, schema: 's_schema.Schema'
) -> 'types.Type':
return s_orm.get_field_value( # type: ignore
self, schema, 'from_type' # type: ignore
)

def get_to_type(
self, schema: 's_schema.Schema'
) -> 'types.Type':
return s_orm.get_field_value( # type: ignore
self, schema, 'to_type' # type: ignore
)

def get_allow_implicit(
self, schema: 's_schema.Schema'
) -> 'bool':
return s_orm.get_field_value( # type: ignore
self, schema, 'allow_implicit' # type: ignore
)

def get_allow_assignment(
self, schema: 's_schema.Schema'
) -> 'bool':
return s_orm.get_field_value( # type: ignore
self, schema, 'allow_assignment' # type: ignore
)

def get_language(
self, schema: 's_schema.Schema'
) -> 'ast.Language':
return s_orm.get_field_value( # type: ignore
self, schema, 'language' # type: ignore
)

def get_from_function(
self, schema: 's_schema.Schema'
) -> 'str':
return s_orm.get_field_value( # type: ignore
self, schema, 'from_function' # type: ignore
)

def get_from_expr(
self, schema: 's_schema.Schema'
) -> 'bool':
return s_orm.get_field_value( # type: ignore
self, schema, 'from_expr' # type: ignore
)

def get_from_cast(
self, schema: 's_schema.Schema'
) -> 'bool':
return s_orm.get_field_value( # type: ignore
self, schema, 'from_cast' # type: ignore
)

def get_code(
self, schema: 's_schema.Schema'
) -> 'str':
return s_orm.get_field_value( # type: ignore
self, schema, 'code' # type: ignore
)