diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index aa978e06..62b903c8 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -10,9 +10,11 @@ go_library( "cel.go", "decls.go", "env.go", + "folding.go", "io.go", "library.go", "macro.go", + "optimizer.go", "options.go", "program.go", "validator.go", @@ -56,6 +58,7 @@ go_test( "cel_test.go", "decls_test.go", "env_test.go", + "folding_test.go", "io_test.go", "validator_test.go", ], diff --git a/cel/env.go b/cel/env.go index 473604bc..7209edc1 100644 --- a/cel/env.go +++ b/cel/env.go @@ -43,6 +43,9 @@ type Ast struct { } // Expr returns the proto serializable instance of the parsed/checked expression. +// +// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr() +// the result instead. func (ast *Ast) Expr() *exprpb.Expr { if ast == nil { return nil @@ -221,6 +224,11 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { source: ast.Source(), impl: checked} + // Avoid creating a validator config if it's not needed. + if len(e.validators) == 0 { + return ast, nil + } + // Generate a validator configuration from the set of configured validators. vConfig := newValidatorConfig() for _, v := range e.validators { diff --git a/cel/folding.go b/cel/folding.go new file mode 100644 index 00000000..fce15da5 --- /dev/null +++ b/cel/folding.go @@ -0,0 +1,450 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate +// literal values within function calls and select statements with their evaluated result. +func NewConstantFoldingOptimizer() ASTOptimizer { + return &constantFoldingOptimizer{} +} + +type constantFoldingOptimizer struct{} + +// Optimize queries the expression graph for scalar and aggregate literal expressions within call and +// select statements and then evaluates them and replaces the call site with the literal result. +// +// Note: only values which can be represented as literals in CEL syntax are supported. +func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + + // Walk the list of foldable expression and continue to fold until there are no more folds left. + // All of the fold candidates returned by the constantExprMatcher should succeed unless there's + // a logic bug with the selection of expressions. + foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + for len(foldableExprs) != 0 { + for _, fold := range foldableExprs { + // If the expression could be folded because it's a non-strict call, and the + // branches are pruned, continue to the next fold. + if fold.Kind() == ast.CallKind && maybePruneBranches(fold) { + continue + } + // Otherwise, assume all context is needed to evaluate the expression. + err := tryFold(ctx, a, fold) + if err != nil { + ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error()) + return a + } + } + foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + } + // Once all of the constants have been folded, try to run through the remaining comprehensions + // one last time. In this case, there's no guarantee they'll run, so we only update the + // target comprehension node with the literal value if the evaluation succeeds. + for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) { + tryFold(ctx, a, compre) + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + pruneOptionalElements(ctx, root) + + // Ensure that all intermediate values in the folded expression can be represented as valid + // CEL literals within the AST structure. + ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() != ast.LiteralKind { + return + } + val := e.AsLiteral() + adapted, err := adaptLiteral(ctx, val) + if err != nil { + ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error()) + return + } + e.SetKindCase(adapted) + })) + + return a +} + +// tryFold attempts to evaluate a sub-expression to a literal. +// +// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise +// the method will return an error. +func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { + // Assume all context is needed to evaluate the expression. + subAST := &Ast{ + impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()), + } + prg, err := ctx.Program(subAST) + if err != nil { + return err + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + return err + } + // Clear any macro metadata associated with the fold. + a.SourceInfo().ClearMacroCall(expr.ID()) + // Update the fold expression to be a literal. + expr.SetKindCase(ctx.NewLiteral(out)) + return nil +} + +// maybePruneBranches inspects the non-strict call expression to determine whether +// a branch can be removed. Evaluation will naturally prune logical and / or calls, +// but conditional will not be pruned cleanly, so this is one small area where the +// constant folding step reimplements a portion of the evaluator. +func maybePruneBranches(expr ast.NavigableExpr) bool { + call := expr.AsCall() + switch call.FunctionName() { + case operators.Conditional: + args := call.Args() + cond := args[0] + truthy := args[1] + falsy := args[2] + if cond.AsLiteral() == types.True { + expr.SetKindCase(truthy) + } else { + expr.SetKindCase(falsy) + } + return true + } + return false +} + +// pruneOptionalElements works from the bottom up to resolve optional elements within +// aggregate literals. +// +// Note, may aggregate literals will be resolved as arguments to functions or select +// statements, so this method exists to handle the case where the literal could not be +// fully resolved or exists outside of a call, select, or comprehension context. +func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { + aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher) + for _, lit := range aggregateLiterals { + switch lit.Kind() { + case ast.ListKind: + pruneOptionalListElements(ctx, lit) + case ast.MapKind: + pruneOptionalMapEntries(ctx, lit) + case ast.StructKind: + pruneOptionalStructFields(ctx, lit) + } + } +} + +func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) { + l := e.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + if len(optIndices) == 0 { + return + } + updatedElems := []ast.Expr{} + updatedIndices := []int32{} + for i, e := range elems { + if !l.IsOptional(int32(i)) { + updatedElems = append(updatedElems, e) + continue + } + if e.Kind() != ast.LiteralKind { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + optElemVal, ok := e.AsLiteral().(*types.Optional) + if !ok { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + if !optElemVal.HasValue() { + continue + } + e.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedElems = append(updatedElems, e) + } + e.SetKindCase(ctx.NewList(updatedElems, updatedIndices)) +} + +func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { + m := e.AsMap() + entries := m.Entries() + updatedEntries := []ast.EntryExpr{} + modified := false + for _, e := range entries { + entry := e.AsMapEntry() + key := entry.Key() + val := entry.Value() + if !entry.IsOptional() || val.Kind() != ast.LiteralKind { + updatedEntries = append(updatedEntries, e) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedEntries = append(updatedEntries, e) + continue + } + if key.Kind() != ast.LiteralKind { + undoOptVal, err := adaptLiteral(ctx, optElemVal) + if err != nil { + ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err) + } + val.SetKindCase(undoOptVal) + updatedEntries = append(updatedEntries, e) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedEntry := ctx.NewMapEntry(key, val, false) + updatedEntries = append(updatedEntries, updatedEntry) + } + if modified { + e.SetKindCase(ctx.NewMap(updatedEntries)) + } +} + +func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) { + s := e.AsStruct() + fields := s.Fields() + updatedFields := []ast.EntryExpr{} + modified := false + for _, f := range fields { + field := f.AsStructField() + val := field.Value() + if !field.IsOptional() || val.Kind() != ast.LiteralKind { + updatedFields = append(updatedFields, f) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedFields = append(updatedFields, f) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedField := ctx.NewStructField(field.Name(), val, false) + updatedFields = append(updatedFields, updatedField) + } + if modified { + e.SetKindCase(ctx.NewStruct(s.TypeName(), updatedFields)) + } +} + +// adaptLiteral converts a runtime CEL value to its equivalent literal expression. +// +// For strongly typed values, the type-provider will be used to reconstruct the fields +// which are present in the literal and their equivalent initialization values. +func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { + switch t := val.Type().(type) { + case *types.Type: + switch t { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + return ctx.NewLiteral(val), nil + case types.DurationType: + return ctx.NewCall( + overloads.TypeConvertDuration, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.TimestampType: + return ctx.NewCall( + overloads.TypeConvertTimestamp, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.OptionalType: + opt := val.(*types.Optional) + if !opt.HasValue() { + return ctx.NewCall("optional.none"), nil + } + target, err := adaptLiteral(ctx, opt.GetValue()) + if err != nil { + return nil, err + } + return ctx.NewCall("optional.of", target), nil + case types.TypeType: + return ctx.NewIdent(val.(*types.Type).TypeName()), nil + case types.ListType: + l, ok := val.(traits.Lister) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + elems := make([]ast.Expr, l.Size().(types.Int)) + idx := 0 + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + elemExpr, err := adaptLiteral(ctx, elemVal) + if err != nil { + return nil, err + } + elems[idx] = elemExpr + idx++ + } + return ctx.NewList(elems, []int32{}), nil + case types.MapType: + m, ok := val.(traits.Mapper) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + entries := make([]ast.EntryExpr, m.Size().(types.Int)) + idx := 0 + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + keyExpr, err := adaptLiteral(ctx, keyVal) + if err != nil { + return nil, err + } + valVal := m.Get(keyVal) + valExpr, err := adaptLiteral(ctx, valVal) + if err != nil { + return nil, err + } + entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false) + idx++ + } + return ctx.NewMap(entries), nil + default: + provider := ctx.CELTypeProvider() + fields, found := provider.FindStructFieldNames(t.TypeName()) + if !found { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + tester := val.(traits.FieldTester) + indexer := val.(traits.Indexer) + fieldInits := []ast.EntryExpr{} + for _, f := range fields { + field := types.String(f) + if tester.IsSet(field) != types.True { + continue + } + fieldVal := indexer.Get(field) + fieldExpr, err := adaptLiteral(ctx, fieldVal) + if err != nil { + return nil, err + } + fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false)) + } + return ctx.NewStruct(t.TypeName(), fieldInits), nil + } + } + return nil, fmt.Errorf("failed to adapt %v to literal", val) +} + +// constantExprMatcher matches calls, select statements, and comprehensions whose arguments +// are all constant scalar or aggregate literal values. +// +// Only comprehensions which are not nested are included as possible constant folds, and only +// if all variables referenced in the comprehension stack exist are only iteration or +// accumulation variables. +func constantExprMatcher(e ast.NavigableExpr) bool { + switch e.Kind() { + case ast.CallKind: + return constantCallMatcher(e) + case ast.SelectKind: + sel := e.AsSelect() // guaranteed to be a navigable value + return constantMatcher(sel.Operand().(ast.NavigableExpr)) + case ast.ComprehensionKind: + if isNestedComprehension(e) { + return false + } + vars := map[string]bool{} + constantExprs := true + visitor := ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() == ast.ComprehensionKind { + nested := e.AsComprehension() + vars[nested.AccuVar()] = true + vars[nested.IterVar()] = true + } + if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { + constantExprs = false + } + }) + ast.PreOrderVisit(e, visitor) + return constantExprs + default: + return false + } +} + +// constantCallMatcher identifies strict and non-strict calls which can be folded. +func constantCallMatcher(e ast.NavigableExpr) bool { + call := e.AsCall() + children := e.Children() + fnName := call.FunctionName() + if fnName == operators.LogicalAnd { + for _, child := range children { + if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.False { + return true + } + } + } + if fnName == operators.LogicalOr { + for _, child := range children { + if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.True { + return true + } + } + } + if fnName == operators.Conditional { + cond := children[0] + if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType { + return true + } + } + // convert all other calls with constant arguments + for _, child := range children { + if !constantMatcher(child) { + return false + } + } + return true +} + +func isNestedComprehension(e ast.NavigableExpr) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.ComprehensionKind { + return true + } + parent, found = parent.Parent() + } + return false +} + +func aggregateLiteralMatcher(e ast.NavigableExpr) bool { + return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind +} + +var ( + constantMatcher = ast.ConstantValueMatcher() +) diff --git a/cel/folding_test.go b/cel/folding_test.go new file mode 100644 index 00000000..c871a5bd --- /dev/null +++ b/cel/folding_test.go @@ -0,0 +1,195 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel_test + +import ( + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/test/proto3pb" +) + +func TestConstantFoldingOptimizer(t *testing.T) { + tests := []struct { + expr string + folded string + }{ + { + expr: `[1, 1 + 2, 1 + (2 + 3)]`, + folded: `[1, 3, 6]`, + }, + { + expr: `6 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `true`, + }, + { + expr: `5 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `false`, + }, + { + expr: `x in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `x in [1, 3, 6]`, + }, + { + expr: `1 in [1, x + 2, 1 + (2 + 3)]`, + folded: `1 in [1, x + 2, 6]`, + }, + { + expr: `{'hello': 'world'}.hello == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}.?hello.orValue('default') == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `optional.of("hello")`, + folded: `optional.of("hello")`, + }, + { + expr: `optional.ofNonZeroValue("")`, + folded: `optional.none()`, + }, + { + expr: `{?'hello': optional.of('world')}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `duration(string(7 * 24) + 'h')`, + folded: `duration("604800s")`, + }, + { + expr: `timestamp("1970-01-01T00:00:00Z")`, + folded: `timestamp("1970-01-01T00:00:00Z")`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 10)`, + folded: `true`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 1 % 2)`, + folded: `false`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, + folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has(m.a))`, + folded: `[{"a": 1}]`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has({'a': true}.a))`, + folded: `[{}, {"a": 1}, {"b": 2}]`, + }, + { + expr: `type(1)`, + folded: `int`, + }, + { + expr: `[google.expr.proto3.test.TestAllTypes{single_int32: 2 + 3}].map(i, i)[0]`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 5}`, + }, + { + expr: `[1, ?optional.ofNonZeroValue(0)]`, + folded: `[1]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y]`, + folded: `[1, x, 3, ?x.?y]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y].size() > 3`, + folded: `[1, x, 3, ?x.?y].size() > 3`, + }, + { + expr: `{?'a': optional.of('hello'), ?x : optional.of(1), ?'b': optional.none()}`, + folded: `{"a": "hello", ?x: optional.of(1)}`, + }, + { + expr: `true ? x + 1 : x + 2`, + folded: `x + 1`, + }, + { + expr: `false ? x + 1 : x + 2`, + folded: `x + 2`, + }, + { + expr: `false ? x + 'world' : 'hello' + 'world'`, + folded: `"helloworld"`, + }, + { + expr: `null`, + folded: `null`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(1)}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(0)}`, + folded: `google.expr.proto3.test.TestAllTypes{}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + }, + { + expr: `x + dyn([1, 2] + [3, 4])`, + folded: `x + [1, 2, 3, 4]`, + }, + { + expr: `dyn([1, 2]) + [3.0, 4.0]`, + folded: `[1, 2, 3.0, 4.0]`, + }, + { + expr: `{'a': dyn([1, 2]), 'b': x}`, + folded: `{"a": [1, 2], "b": x}`, + }, + } + e, err := cel.NewEnv( + cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + cel.Types(&proto3pb.TestAllTypes{}), + cel.Variable("x", cel.DynType)) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + opt := cel.NewStaticOptimizer(cel.NewConstantFoldingOptimizer()) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} diff --git a/cel/inlining.go b/cel/inlining.go new file mode 100644 index 00000000..70a6161d --- /dev/null +++ b/cel/inlining.go @@ -0,0 +1,105 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/containers" +) + +// InlineVariable holds a variable name to be matched and an AST representing +// the expression graph which should be used to replace it. +type InlineVariable struct { + name string + alias string + def *ast.AST +} + +// NewInlineVariable declares a variable name to be replaced by a checked expression. +func NewInlineVariable(name string, definition *Ast) *InlineVariable { + return NewInlineVariableWithAlias(name, name, definition) +} + +// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression. +// If the variable occurs more than once, the provided alias will be used to replace the expressions +// where the variable name occurs. +func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable { + return &InlineVariable{name: name, alias: alias, def: definition.impl} +} + +// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions. +// +// If a variable occurs one time, the variable is replaced by the inline definition. If the +// variable occurs more than once, the variable occurences are replaced by a cel.bind() call. +func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer { + return &inliningOptimizer{variables: inlineVars} +} + +type inliningOptimizer struct { + variables []*InlineVariable +} + +func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + for _, inlineVar := range opt.variables { + matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.name)) + // Skip cases where the variable isn't in the expression graph + if len(matches) == 0 { + continue + } + + // For a single match, do a direct replacement of the expression sub-graph. + if len(matches) == 1 { + matches[0].SetKindCase(ctx.CopyExpr(inlineVar.def.Expr())) + continue + } + + // For multiple matches, find the least common ancestor (lca) and insert the + // variable as a cel.bind() macro. + var lca ast.NavigableExpr = nil + ancestors := map[int64]bool{} + for _, match := range matches { + // Update the identifier matches with the provided alias. + match.SetKindCase(ctx.NewIdent(inlineVar.alias)) + parent, found := match, true + for found { + _, hasAncestor := ancestors[parent.ID()] + if hasAncestor && (lca == nil || lca.Depth() < parent.Depth()) { + lca = parent + } + ancestors[parent.ID()] = true + parent, found = parent.Parent() + } + } + + // Update the least common ancestor by inserting a cel.bind() call to the alias. + lca.SetKindCase( + ctx.NewBindMacro(lca.ID(), inlineVar.alias, inlineVar.def.Expr(), lca)) + } + return a +} + +func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher { + return func(e ast.NavigableExpr) bool { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + return true + } + if e.Kind() == ast.SelectKind { + qualName, found := containers.ToQualifiedName(e) + return found && qualName == varName + } + return false + } +} diff --git a/cel/inlining_test.go b/cel/inlining_test.go new file mode 100644 index 00000000..d08b0d7c --- /dev/null +++ b/cel/inlining_test.go @@ -0,0 +1,247 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel_test + +import ( + "testing" + + "github.com/google/cel-go/cel" +) + +func TestInliningOptimizer(t *testing.T) { + type varExpr struct { + name string + alias string + t *cel.Type + expr string + } + tests := []struct { + expr string + vars []varExpr + inlined string + folded string + }{ + { + expr: `a || b`, + vars: []varExpr{ + { + name: "a", + t: cel.BoolType, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + }, + inlined: `a || "hello".contains("lo")`, + folded: `true`, + }, + { + expr: `a + [a]`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.DynType, + expr: `dyn([1, 2])`, + }, + }, + inlined: `cel.bind(alpha, dyn([1, 2]), alpha + [alpha])`, + folded: `[1, 2, [1, 2]]`, + }, + { + expr: `a && (a || b)`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + }, + inlined: `cel.bind(alpha, "hello".contains("lo"), alpha && (alpha || b))`, + folded: `true`, + }, + { + expr: `a && b && a`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + }, + inlined: `cel.bind(alpha, "hello".contains("lo"), alpha && b && alpha)`, + folded: `cel.bind(alpha, true, alpha && b && alpha)`, + }, + { + expr: `(c || d) || (a && (a || b))`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + { + name: "c", + t: cel.BoolType, + }, + { + name: "d", + t: cel.BoolType, + expr: "!false", + }, + }, + inlined: `c || !false || cel.bind(alpha, "hello".contains("lo"), alpha && (alpha || b))`, + folded: `true`, + }, + { + expr: `a && (a || b)`, + vars: []varExpr{ + { + name: "a", + t: cel.BoolType, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + }, + inlined: `a && (a || "hello".contains("lo"))`, + folded: `a && true`, + }, + { + expr: `a && b`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `!'hello'.contains('lo')`, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + }, + }, + inlined: `!"hello".contains("lo") && b`, + folded: `false`, + }, + { + expr: `operation.system.consumers + operation.destination_consumers`, + vars: []varExpr{ + { + name: "operation.system", + t: cel.DynType, + }, + { + name: "operation.destination_consumers", + t: cel.ListType(cel.IntType), + expr: `productsToConsumers(operation.destination_products)`, + }, + { + name: "operation.destination_products", + t: cel.ListType(cel.IntType), + expr: `operation.system.products`, + }, + }, + inlined: `operation.system.consumers + productsToConsumers(operation.system.products)`, + folded: `operation.system.consumers + productsToConsumers(operation.system.products)`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + opts := []cel.EnvOption{cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + cel.Function("productsToConsumers", + cel.Overload("productsToConsumers_list", + []*cel.Type{cel.ListType(cel.IntType)}, + cel.ListType(cel.IntType)))} + + varDecls := make([]cel.EnvOption, len(tc.vars)) + for i, v := range tc.vars { + varDecls[i] = cel.Variable(v.name, v.t) + } + e, err := cel.NewEnv(append(varDecls, opts...)...) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + inlinedVars := []*cel.InlineVariable{} + for _, v := range tc.vars { + if v.expr == "" { + continue + } + checked, iss := e.Compile(v.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", v.expr, iss.Err()) + } + if v.alias == "" { + inlinedVars = append(inlinedVars, cel.NewInlineVariable(v.name, checked)) + } else { + inlinedVars = append(inlinedVars, cel.NewInlineVariableWithAlias(v.name, v.alias, checked)) + } + } + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + + opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + inlined, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if inlined != tc.inlined { + t.Errorf("got %q, wanted %q", inlined, tc.inlined) + } + opt = cel.NewStaticOptimizer(cel.NewConstantFoldingOptimizer()) + optimized, iss = opt.Optimize(e, optimized) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} diff --git a/cel/optimizer.go b/cel/optimizer.go new file mode 100644 index 00000000..5a7abeb5 --- /dev/null +++ b/cel/optimizer.go @@ -0,0 +1,278 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order. +// +// The static optimizer normalizes expression ids and type-checking run between optimization +// passes to ensure that the final optimized output is a valid expression with metadata consistent +// with what would have been generated from a parsed and checked expression. +// +// Note: source position information is best-effort and likely wrong, but optimized expressions +// should be suitable for calls to parser.Unparse. +type StaticOptimizer struct { + optimizers []ASTOptimizer +} + +// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied +// to a checked expression. +func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { + return &StaticOptimizer{ + optimizers: optimizers, + } +} + +// Optimize applies a sequence of optimizations to an Ast within a given environment. +// +// If issues are encountered, the Issues.Err() return value will be non-nil. +func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { + // Make a copy of the AST to be optimized. + optimized := ast.Copy(a.impl) + + // Create the optimizer context, could be pooled in the future. + issues := NewIssues(common.NewErrors(a.Source())) + ids := newMonotonicIDGen(ast.MaxID(a.impl)) + fac := &optimizerExprFactory{ + nextID: ids.nextID, + renumberID: ids.renumberID, + fac: ast.NewExprFactory(), + sourceInfo: optimized.SourceInfo(), + } + ctx := &OptimizerContext{ + optimizerExprFactory: fac, + Env: env, + Issues: issues, + } + + // Apply the optimizations sequentially. + for _, o := range opt.optimizers { + optimized = o.Optimize(ctx, optimized) + if issues.Err() != nil { + return nil, issues + } + // Normalize expression id metadata including coordination with macro call metadata. + normalizeIDs(env, optimized) + + // Recheck the updated expression for any possible type-agreement or validation errors. + parsed := &Ast{ + source: a.Source(), + impl: ast.NewAST(optimized.Expr(), optimized.SourceInfo())} + checked, iss := ctx.Check(parsed) + if iss.Err() != nil { + return nil, iss + } + optimized = checked.impl + } + + // Return the optimized result. + return &Ast{ + source: a.Source(), + impl: optimized, + }, nil +} + +func normalizeIDs(e *Env, optimized *ast.AST) { + ids := newStableIDGen() + optimized.Expr().RenumberIDs(ids.renumberID) + allExprMap := make(map[int64]ast.Expr) + ast.PostOrderVisit(optimized.Expr(), ast.NewExprVisitor(func(e ast.Expr) { + allExprMap[e.ID()] = e + })) + info := optimized.SourceInfo() + + // First, update the macro call ids themselves. + for id, call := range info.MacroCalls() { + info.ClearMacroCall(id) + callID := ids.renumberID(id) + if e, found := allExprMap[callID]; found && e.Kind() == ast.LiteralKind { + continue + } + info.SetMacroCall(callID, call) + } + + // Second, update the macro call id references to ensure that macro pointers are' + // updated consistently across macros. + for id, call := range info.MacroCalls() { + call.RenumberIDs(ids.renumberID) + resetMacroCall(optimized, call, allExprMap) + info.SetMacroCall(id, call) + } +} + +func resetMacroCall(optimized *ast.AST, call ast.Expr, allExprMap map[int64]ast.Expr) { + modified := []ast.Expr{} + ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) { + if _, found := allExprMap[e.ID()]; found { + modified = append(modified, e) + } + })) + for _, m := range modified { + updated := allExprMap[m.ID()] + m.SetKindCase(updated) + } +} + +// newMonotonicIDGen increments numbers from an initial seed value. +func newMonotonicIDGen(seed int64) *monotonicIDGenerator { + return &monotonicIDGenerator{seed: seed} +} + +type monotonicIDGenerator struct { + seed int64 +} + +func (gen *monotonicIDGenerator) nextID() int64 { + gen.seed++ + return gen.seed +} + +func (gen *monotonicIDGenerator) renumberID(int64) int64 { + return gen.nextID() +} + +// newStableIDGen ensures that new ids are only created the first time they are encountered. +func newStableIDGen() *stableIDGenerator { + return &stableIDGenerator{ + idMap: make(map[int64]int64), + } +} + +type stableIDGenerator struct { + idMap map[int64]int64 + nextID int64 +} + +func (gen *stableIDGenerator) renumberID(id int64) int64 { + if id == 0 { + return 0 + } + if newID, found := gen.idMap[id]; found { + return newID + } + gen.nextID++ + gen.idMap[id] = gen.nextID + return gen.nextID +} + +// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate +// subexpressions and report any errors encountered along the way. The context also embeds the +// optimizerExprFactory which can be used to generate new sub-expressions with expression ids +// consistent with the expectations of a parsed expression. +type OptimizerContext struct { + *Env + *optimizerExprFactory + *Issues +} + +// ASTOptimizer applies an optimization over an AST and returns the optimized result. +type ASTOptimizer interface { + // Optimize optimizes a type-checked AST within an Environment and accumulates any issues. + Optimize(*OptimizerContext, *ast.AST) *ast.AST +} + +type optimizerExprFactory struct { + nextID func() int64 + renumberID ast.IDGenerator + fac ast.ExprFactory + sourceInfo *ast.SourceInfo +} + +func (opt *optimizerExprFactory) CopyExpr(e ast.Expr) ast.Expr { + copy := opt.fac.CopyExpr(e) + copy.RenumberIDs(opt.renumberID) + return copy +} + +func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) ast.Expr { + bindID := opt.nextID() + varID := opt.nextID() + + varInit = opt.CopyExpr(varInit) + varInit.RenumberIDs(opt.renumberID) + + remaining = opt.fac.CopyExpr(remaining) + remaining.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined + // call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewMemberCall(0, "bind", + opt.fac.NewIdent(opt.nextID(), "cel"), + opt.fac.NewIdent(varID, varName), + varInit, + remaining)) + + // Replace the parent node with the intercepted inlining using cel.bind()-like + // generated comprehension AST. + return opt.fac.NewComprehension(bindID, + opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}), + "#unused", + varName, + opt.fac.CopyExpr(varInit), + opt.fac.NewLiteral(opt.nextID(), types.False), + opt.fac.NewIdent(varID, varName), + opt.fac.CopyExpr(remaining)) +} + +func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr { + return opt.fac.NewCall(opt.nextID(), function, args...) +} + +func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return opt.fac.NewMemberCall(opt.nextID(), function, target, args...) +} + +func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr { + return opt.fac.NewIdent(opt.nextID(), name) +} + +func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr { + return opt.fac.NewLiteral(opt.nextID(), value) +} + +func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr { + return opt.fac.NewList(opt.nextID(), elems, optIndices) +} + +func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr { + return opt.fac.NewMap(opt.nextID(), entries) +} + +func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) +} + +func (opt *optimizerExprFactory) NewPresenceTest(operand ast.Expr, field string) ast.Expr { + return opt.fac.NewPresenceTest(opt.nextID(), operand, field) +} + +func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr { + return opt.fac.NewSelect(opt.nextID(), operand, field) +} + +func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr { + return opt.fac.NewStruct(opt.nextID(), typeName, fields) +} + +func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewStructField(opt.nextID(), field, value, isOptional) +} diff --git a/common/ast/ast.go b/common/ast/ast.go index 7610b467..7845520d 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -249,10 +249,16 @@ func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) { // SetMacroCall records a macro call at a specific location. func (s *SourceInfo) SetMacroCall(id int64, e Expr) { - if s == nil { - return + if s != nil { + s.macroCalls[id] = e + } +} + +// ClearMacroCall removes the macro call at the given expression id. +func (s *SourceInfo) ClearMacroCall(id int64) { + if s != nil { + delete(s.macroCalls, id) } - s.macroCalls[id] = e } // OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either: diff --git a/common/ast/expr.go b/common/ast/expr.go index 5811e395..aac3bf3d 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -184,6 +184,9 @@ type ListExpr interface { // OptionalIndicies returns the list of optional indices in the list literal. OptionalIndices() []int32 + // IsOptional indicates whether the given element index is optional. + IsOptional(int32) bool + // Size returns the number of elements in the list. Size() int @@ -606,6 +609,15 @@ func (e *baseListExpr) Elements() []Expr { return e.elements } +func (e *baseListExpr) IsOptional(index int32) bool { + for _, optIndex := range e.OptionalIndices() { + if optIndex == index { + return true + } + } + return false +} + func (e *baseListExpr) OptionalIndices() []int32 { if e == nil { return []int32{} diff --git a/common/ast/navigable.go b/common/ast/navigable.go index 2836b565..f5ddf6aa 100644 --- a/common/ast/navigable.go +++ b/common/ast/navigable.go @@ -423,6 +423,10 @@ func (l navigableListImpl) Elements() []Expr { return elems } +func (l navigableListImpl) IsOptional(index int32) bool { + return l.Expr.AsList().IsOptional(index) +} + func (l navigableListImpl) OptionalIndices() []int32 { return l.Expr.AsList().OptionalIndices() }