Skip to content

Commit

Permalink
Add more intersection handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed May 7, 2024
1 parent 18ac49a commit 9db9b8d
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 11 deletions.
6 changes: 6 additions & 0 deletions edb/edgeql/compiler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,12 @@ def _validate_object_search_call(
idx = _validate_has_object_index(
variant, schema, span, context, index_name)
indexes[typegen.type_to_typeref(variant, ctx.env)] = idx
elif intersection_variants := stype.get_intersection_of(schema):
for variant in intersection_variants.objects(schema):
schema, variant = variant.material_type(schema)
idx = _validate_has_object_index(
variant, schema, span, context, index_name)
indexes[typegen.type_to_typeref(variant, ctx.env)] = idx
else:
idx = _validate_has_object_index(
stype, schema, span, context, index_name)
Expand Down
6 changes: 6 additions & 0 deletions edb/edgeql/compiler/schemactx.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,12 @@ def get_all_concrete(
for t in union.objects(ctx.env.schema)
for x in get_all_concrete(t, ctx=ctx)
}
elif intersection := stype.get_intersection_of(ctx.env.schema):
return {
x
for t in intersection.objects(ctx.env.schema)
for x in get_all_concrete(t, ctx=ctx)
}
return {stype} | {
x for x in stype.descendants(ctx.env.schema)
if x.is_material_object_type(ctx.env.schema)
Expand Down
6 changes: 6 additions & 0 deletions edb/edgeql/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,6 +1484,12 @@ def _resolve_type_expr(
_resolve_type_expr(texpr.right, ctx=ctx),
])

elif texpr.op == '&':
return qltracer.IntersectionType([
_resolve_type_expr(texpr.left, ctx=ctx),
_resolve_type_expr(texpr.right, ctx=ctx),
])

else:
raise NotImplementedError(
f'unsupported type operation: {texpr.op}')
Expand Down
22 changes: 22 additions & 0 deletions edb/edgeql/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,22 @@ def is_object_type(self) -> bool:
return True


class IntersectionType(Type):

def __init__(
self, types: List[Union[Type, IntersectionType, so.Object]]
) -> None:
self.types = types

def get_name(self, schema: s_schema.Schema) -> sn.QualName:
component_ids = sorted(str(t.get_name(schema)) for t in self.types)
nqname = f"({' & '.join(component_ids)})"
return sn.QualName(name=nqname, module='__derived__')

def is_object_type(self) -> bool:
return True


class Pointer(Source):

def __init__(
Expand Down Expand Up @@ -941,6 +957,12 @@ def _resolve_type_expr(
_resolve_type_expr(texpr.right, ctx=ctx),
])

elif texpr.op == '&':
return IntersectionType([
_resolve_type_expr(texpr.left, ctx=ctx),
_resolve_type_expr(texpr.right, ctx=ctx),
])

else:
raise NotImplementedError(
f'unsupported type operation: {texpr.op}')
Expand Down
11 changes: 9 additions & 2 deletions edb/pgsql/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6206,6 +6206,8 @@ def get_target_objs(self, link, schema):
tgt = link.get_target(schema)
if union := tgt.get_union_of(schema).objects(schema):
objs = set(union)
elif intersection := tgt.get_intersection_of(schema).objects(schema):
objs = set(intersection)
else:
objs = {tgt}
objs |= {
Expand Down Expand Up @@ -6803,9 +6805,10 @@ def apply(
# triggers updated, so track them down.
all_affected_targets = set()
for target in affected_targets:
union_of = target.get_union_of(schema)
if union_of:
if union_of := target.get_union_of(schema):
objtypes = tuple(union_of.objects(schema))
elif intersection_of := target.get_intersection_of(schema):
objtypes = tuple(intersection_of.objects(schema))
else:
objtypes = (target,)

Expand Down Expand Up @@ -6836,6 +6839,10 @@ def apply(
target, scls_type=s_objtypes.ObjectType,
field_name='union_of'
),
schema.get_referrers(
target, scls_type=s_objtypes.ObjectType,
field_name='intersection_of'
),
):
inbound_links |= schema.get_referrers(
ancestor, scls_type=s_links.Link, field_name='target')
Expand Down
2 changes: 1 addition & 1 deletion edb/schema/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _only_generic(

obj = cast(s_objtypes.ObjectType,
relevant_schema.get(cmd.classname))
if obj.is_union_type(relevant_schema):
if obj.is_compound_type(relevant_schema):
continue

result.add(cmd)
Expand Down
50 changes: 50 additions & 0 deletions edb/schema/objtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,13 @@ def getrptrs(
for union in unions:
ptrs.update(union.getrptrs(schema, name, sources=sources))

intersections = schema.get_referrers(
self, scls_type=ObjectType, field_name='intersection_of'
)

for intersection in intersections:
ptrs.update(intersection.getrptrs(schema, name, sources=sources))

return ptrs

def get_relevant_triggers(
Expand Down Expand Up @@ -683,6 +690,49 @@ def _alter_finalize(
schema = diff.apply(schema, context)
self.add(diff)

# Do the same for intersections
intersections = schema.get_referrers(
self.scls, scls_type=ObjectType, field_name='intersection_of')

orig_disable = context.disable_dep_verification

for intersection in intersections:
delete = (
intersection.init_delta_command(schema, sd.DeleteObject)
)

context.disable_dep_verification = True
delete.apply(schema, context)
context.disable_dep_verification = orig_disable
# We run the delete to populate the tree, but then instead
# of actually deleting the object, we just remove the names.
# This is because the pointers in the types we are looking
# at might themselves reference the intersection, so we need
# them in the schema to produce the correct as_alter_delta.
nschema = _delete_to_delist(delete, schema)

nschema, nintersection, _ = utils.ensure_intersection_type(
nschema,
types=(
intersection
.get_intersection_of(schema)
.objects(schema)
),
module=intersection.get_name(schema).module,
)
assert isinstance(nintersection, ObjectType)

diff = intersection.as_alter_delta(
other=nintersection,
self_schema=schema,
other_schema=nschema,
confidence=1.0,
context=so.ComparisonContext(),
)

schema = diff.apply(schema, context)
self.add(diff)

return super()._alter_finalize(schema, context)


Expand Down
17 changes: 12 additions & 5 deletions edb/schema/pointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import collections.abc
import enum
import itertools
import json
import operator

Expand Down Expand Up @@ -2289,12 +2290,18 @@ def _canonicalize(

# Any union type that references this field needs to have it
# deleted.
unions = schema.get_referrers(
self.scls, scls_type=Pointer, field_name='union_of')
for union in unions:
group, op, _ = union.init_delta_branch(
referrers = itertools.chain(
schema.get_referrers(
self.scls, scls_type=Pointer, field_name='union_of'
),
schema.get_referrers(
self.scls, scls_type=Pointer, field_name='intersection_of'
),
)
for referrer in referrers:
group, op, _ = referrer.init_delta_branch(
schema, context, sd.DeleteObject)
op.update(op._canonicalize(schema, context, union))
op.update(op._canonicalize(schema, context, referrer))
commands.append(group)

return commands
Expand Down
21 changes: 19 additions & 2 deletions edb/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,17 +570,25 @@ def typeref_to_ast(
for st in t.get_subtypes(schema)
]
)
elif isinstance(t, s_types.Type) and t.is_union_type(schema):
elif isinstance(t, s_types.Type) and (t.is_compound_type(schema)):
object_set = t.get_union_of(schema)
assert object_set is not None

component_objects = tuple(object_set.objects(schema))
result = typeref_to_ast(schema, component_objects[0],
disambiguate_std=disambiguate_std)

if t.is_union_type(schema):
op = '|'
elif t.is_intersection_type(schema):
op = '&'
else:
raise NotImplementedError

for component_object in component_objects[1:]:
result = qlast.TypeOp(
left=result,
op='|',
op=op,
right=typeref_to_ast(schema, component_object,
disambiguate_std=disambiguate_std),
)
Expand Down Expand Up @@ -689,6 +697,15 @@ def shell_to_ast(
op='|',
right=typeref_to_ast(schema, component),
)
elif isinstance(t, s_types.IntersectionTypeShell):
components = t.get_components(schema)
result = typeref_to_ast(schema, components[0])
for component in components[1:]:
result = qlast.TypeOp(
left=result,
op='&',
right=typeref_to_ast(schema, component),
)
elif isinstance(t, s_scalars.AnonymousEnumTypeShell):
result = qlast.TypeName(
name=_name,
Expand Down
2 changes: 1 addition & 1 deletion edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,7 @@ def describe_database_dump(

cfg_object = schema.get('cfg::ConfigObject', type=s_objtypes.ObjectType)
for objtype in objtypes:
if objtype.is_union_type(schema) or objtype.is_view(schema):
if objtype.is_compound_type(schema) or objtype.is_view(schema):
continue
if objtype.issubclass(schema, cfg_object):
continue
Expand Down
6 changes: 6 additions & 0 deletions edb/tools/experimental_interpreter/elaboration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
FunAppExpr,
IndirectionIndexOp,
InsertExpr,
IntersectTp,
IntVal,
Label,
LinkPropLabel,
Expand Down Expand Up @@ -611,6 +612,11 @@ def elab_single_type_expr(typedef: qlast.TypeExpr) -> Tp:
left=elab_single_type_expr(left_type),
right=elab_single_type_expr(right_type),
)
elif op_name == "&":
return IntersectTp(
left=elab_single_type_expr(left_type),
right=elab_single_type_expr(right_type),
)
else:
raise ValueError("Unknown Type Op")
raise ValueError("MATCH")
Expand Down

0 comments on commit 9db9b8d

Please sign in to comment.