Skip to content

Commit

Permalink
Generate mixins instead of type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Apr 30, 2024
1 parent e7361dc commit a1ca831
Show file tree
Hide file tree
Showing 59 changed files with 1,834 additions and 172 deletions.
5 changes: 4 additions & 1 deletion edb/schema/annos.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@
from . import referencing
from . import objects as so
from . import utils
from .generated import annos as sg_annos

if TYPE_CHECKING:
from . import schema as s_schema


class AnnotationValue(
sg_annos.AnnotationValueMixin,
referencing.ReferencedInheritingObject,
qlkind=qltypes.SchemaObjectClass.ANNOTATION,
reflection=so.ReflectionMethod.AS_LINK,
Expand Down Expand Up @@ -101,7 +103,7 @@ def get_verbosename(
T = TypeVar("T")


class AnnotationSubject(so.Object):
class AnnotationSubject(so.Object, sg_annos.AnnotationSubjectMixin):

annotations_refs = so.RefDict(
attr='annotations',
Expand Down Expand Up @@ -180,6 +182,7 @@ class Annotation(
so.QualifiedObject,
so.InheritingObject,
AnnotationSubject,
sg_annos.AnnotationMixin,
qlkind=qltypes.SchemaObjectClass.ANNOTATION,
data_safe=True,
):
Expand Down
2 changes: 2 additions & 0 deletions edb/schema/casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import types as s_types
from . import schema as s_schema
from . import utils
from .generated import casts as sg_casts


_NOT_REACHABLE = 10000000
Expand Down Expand Up @@ -210,6 +211,7 @@ class Cast(
s_anno.AnnotationSubject,
s_func.VolatilitySubject,
s_abc.Cast,
sg_casts.CastMixin,
qlkind=qltypes.SchemaObjectClass.CAST,
data_safe=True,
):
Expand Down
6 changes: 5 additions & 1 deletion edb/schema/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from . import pseudo as s_pseudo
from . import referencing
from . import utils
from .generated import constraints as sg_constraints


if TYPE_CHECKING:
Expand Down Expand Up @@ -122,8 +123,10 @@ def get_key_for_name(


class Constraint(
sg_constraints.ConstraintMixin,
referencing.ReferencedInheritingObject,
s_func.CallableObject, s_abc.Constraint,
s_func.CallableObject,
s_abc.Constraint,
qlkind=ft.SchemaObjectClass.CONSTRAINT,
data_safe=True,
):
Expand Down Expand Up @@ -367,6 +370,7 @@ class ConsistencySubject(
so.QualifiedObject,
so.InheritingObject,
s_anno.AnnotationSubject,
sg_constraints.ConsistencySubjectMixin,
):
constraints_refs = so.RefDict(
attr='constraints',
Expand Down
2 changes: 2 additions & 0 deletions edb/schema/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from . import delta as sd
from . import objects as so
from . import schema as s_schema
from .generated import database as sg_database

from typing import cast

Expand All @@ -40,6 +41,7 @@ class Database(
so.ExternalObject,
s_anno.AnnotationSubject,
s_abc.Database,
sg_database.DatabaseMixin,
qlkind=qltypes.SchemaObjectClass.DATABASE,
data_safe=False,
):
Expand Down
2 changes: 2 additions & 0 deletions edb/schema/expraliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import name as sn
from . import objects as so
from . import types as s_types
from .generated import expraliases as sg_expraliases


if TYPE_CHECKING:
Expand All @@ -42,6 +43,7 @@


class Alias(
sg_expraliases.AliasMixin,
so.QualifiedObject,
s_anno.AnnotationSubject,
qlkind=qltypes.SchemaObjectClass.ALIAS,
Expand Down
4 changes: 3 additions & 1 deletion edb/schema/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@
from . import name as sn
from . import objects as so
from . import schema as s_schema

from .generated import extensions as sg_extensions

class ExtensionPackage(
so.GlobalObject,
s_anno.AnnotationSubject,
sg_extensions.ExtensionPackageMixin,
qlkind=qltypes.SchemaObjectClass.EXTENSION_PACKAGE,
data_safe=False,
):
Expand Down Expand Up @@ -103,6 +104,7 @@ def get_displayname_static(cls, name: sn.Name) -> str:

class Extension(
so.Object,
sg_extensions.ExtensionMixin,
qlkind=qltypes.SchemaObjectClass.EXTENSION,
data_safe=False,
):
Expand Down
37 changes: 10 additions & 27 deletions edb/schema/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from . import referencing
from . import types as s_types
from . import utils
from .generated import functions as sg_functions
from .generated import objects as sg_objects


if TYPE_CHECKING:
Expand Down Expand Up @@ -151,32 +153,8 @@ def param_is_inherited(
return qualname != param_name.name


class ParameterLike(s_abc.Parameter):

def get_parameter_name(self, schema: s_schema.Schema) -> str:
raise NotImplementedError

def get_name(self, schema: s_schema.Schema) -> sn.Name:
raise NotImplementedError

def get_kind(self, _: s_schema.Schema) -> ft.ParameterKind:
raise NotImplementedError

def get_default(self, _: s_schema.Schema) -> Optional[s_expr.Expression]:
raise NotImplementedError

def get_type(self, _: s_schema.Schema) -> s_types.Type:
raise NotImplementedError

def get_typemod(self, _: s_schema.Schema) -> ft.TypeModifier:
raise NotImplementedError

def as_str(self, schema: s_schema.Schema) -> str:
raise NotImplementedError


# Non-schema description of a parameter.
class ParameterDesc(ParameterLike):
class ParameterDesc(s_abc.Parameter):

num: int
name: sn.Name
Expand Down Expand Up @@ -366,7 +344,8 @@ def make_func_param(
class Parameter(
so.ObjectFragment,
so.Object, # Help reflection figure out the right db MRO
ParameterLike,
sg_functions.ParameterMixin,
s_abc.Parameter,
qlkind=ft.SchemaObjectClass.PARAMETER,
data_safe=True,
):
Expand Down Expand Up @@ -489,6 +468,8 @@ def get_ast(self, schema: s_schema.Schema) -> qlast.FuncParam:
)


ParameterLike = ParameterDesc | Parameter

class CallableCommandContext(sd.ObjectCommandContext['CallableObject'],
s_anno.AnnotationSubjectCommandContext):
pass
Expand Down Expand Up @@ -756,7 +737,7 @@ def compare_values(
return 1.0


class VolatilitySubject(so.Object):
class VolatilitySubject(so.Object, sg_functions.VolatilitySubjectMixin):

volatility = so.SchemaField(
ft.Volatility, default=ft.Volatility.Volatile,
Expand Down Expand Up @@ -794,6 +775,7 @@ def get_abstract(self, schema: s_schema.Schema) -> bool:
class CallableObject(
so.QualifiedObject,
s_anno.AnnotationSubject,
sg_functions.CallableObjectMixin,
CallableLike,
):

Expand Down Expand Up @@ -1229,6 +1211,7 @@ class Function(
CallableObject,
VolatilitySubject,
s_abc.Function,
sg_functions.FunctionMixin,
qlkind=ft.SchemaObjectClass.FUNCTION,
data_safe=True,
):
Expand Down
3 changes: 2 additions & 1 deletion edb/schema/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@
from . import objects as so
from . import name as sn
from . import schema as s_schema

from .generated import futures as sg_futures

class FutureBehavior(
so.Object,
sg_futures.FutureBehaviorMixin,
qlkind=qltypes.SchemaObjectClass.FUTURE,
data_safe=False,
):
Expand Down
51 changes: 51 additions & 0 deletions edb/schema/generated/annos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# DO NOT EDIT. This file was generated with:
#
# $ gen-schema-mixins

"""Type definitions for generated methods on schema classes"""

from typing import cast, TYPE_CHECKING
if TYPE_CHECKING:
from edb.schema import schema as s_schema
from edb.schema import orm as s_orm
from edb.schema import objects
from edb.schema import annos


class AnnotationValueMixin:

def get_subject(
self, schema: 's_schema.Schema'
) -> 'objects.Object':
val = s_orm.get_field_value(self, schema, 'subject')
return cast(objects.Object, val)

def get_annotation(
self, schema: 's_schema.Schema'
) -> 'annos.Annotation':
val = s_orm.get_field_value(self, schema, 'annotation')
return cast(annos.Annotation, val)

def get_value(
self, schema: 's_schema.Schema'
) -> 'str':
val = s_orm.get_field_value(self, schema, 'value')
return cast(str, val)


class AnnotationSubjectMixin:

def get_annotations(
self, schema: 's_schema.Schema'
) -> 'objects.ObjectIndexByShortname[annos.AnnotationValue]':
val = s_orm.get_field_value(self, schema, 'annotations')
return cast(objects.ObjectIndexByShortname[annos.AnnotationValue], val)


class AnnotationMixin:

def get_inheritable(
self, schema: 's_schema.Schema'
) -> 'bool':
val = s_orm.get_field_value(self, schema, 'inheritable')
return cast(bool, val)
69 changes: 69 additions & 0 deletions edb/schema/generated/casts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# DO NOT EDIT. This file was generated with:
#
# $ gen-schema-mixins

"""Type definitions for generated methods on schema classes"""

from typing import cast, TYPE_CHECKING
if TYPE_CHECKING:
from edb.schema import schema as s_schema
from edb.schema import orm as s_orm
from edb.edgeql import ast
from edb.schema import types


class CastMixin:

def get_from_type(
self, schema: 's_schema.Schema'
) -> 'types.Type':
val = s_orm.get_field_value(self, schema, 'from_type')
return cast(types.Type, val)

def get_to_type(
self, schema: 's_schema.Schema'
) -> 'types.Type':
val = s_orm.get_field_value(self, schema, 'to_type')
return cast(types.Type, val)

def get_allow_implicit(
self, schema: 's_schema.Schema'
) -> 'bool':
val = s_orm.get_field_value(self, schema, 'allow_implicit')
return cast(bool, val)

def get_allow_assignment(
self, schema: 's_schema.Schema'
) -> 'bool':
val = s_orm.get_field_value(self, schema, 'allow_assignment')
return cast(bool, val)

def get_language(
self, schema: 's_schema.Schema'
) -> 'ast.Language':
val = s_orm.get_field_value(self, schema, 'language')
return cast(ast.Language, val)

def get_from_function(
self, schema: 's_schema.Schema'
) -> 'str':
val = s_orm.get_field_value(self, schema, 'from_function')
return cast(str, val)

def get_from_expr(
self, schema: 's_schema.Schema'
) -> 'bool':
val = s_orm.get_field_value(self, schema, 'from_expr')
return cast(bool, val)

def get_from_cast(
self, schema: 's_schema.Schema'
) -> 'bool':
val = s_orm.get_field_value(self, schema, 'from_cast')
return cast(bool, val)

def get_code(
self, schema: 's_schema.Schema'
) -> 'str':
val = s_orm.get_field_value(self, schema, 'code')
return cast(str, val)

0 comments on commit a1ca831

Please sign in to comment.