Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: google/cel-go
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: v0.17.7
Choose a base ref
...
head repository: google/cel-go
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: v0.18.0
Choose a head ref

Commits on Jul 20, 2023

  1. Bump word-wrap from 1.2.3 to 1.2.4 in /repl/appengine/web (#783)

    Bumps [word-wrap](https://github.com/jonschlinkert/word-wrap) from 1.2.3 to 1.2.4.
    - [Release notes](https://github.com/jonschlinkert/word-wrap/releases)
    - [Commits](jonschlinkert/word-wrap@1.2.3...1.2.4)
    
    ---
    updated-dependencies:
    - dependency-name: word-wrap
      dependency-type: indirect
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
    dependabot[bot] authored Jul 20, 2023
    Copy the full SHA
    766076f View commit details

Commits on Jul 31, 2023

  1. String format validator (#775)

    * String format validator
    * Remove unused method from prior string validation
    TristonianJones authored Jul 31, 2023
    Copy the full SHA
    965e9c8 View commit details
  2. Copy the full SHA
    5359cfd View commit details

Commits on Aug 4, 2023

  1. Copy the full SHA
    20720f3 View commit details

Commits on Aug 9, 2023

  1. Split Expr from NavigableExpr with interpreter support (#788)

    * Split Expr from NavigableExpr with interpreter support
    * Handle errors during proto transformation in the type-checker
    TristonianJones authored Aug 9, 2023
    Copy the full SHA
    a6388a3 View commit details

Commits on Aug 11, 2023

  1. Migrate the type-checker to a native AST representation (#793)

    * Migrate the type-checker to a native AST representation
    * Minor doc update to the type-checker
    * Patch cost_test.go fixes
    * Removed nil field initializer
    TristonianJones authored Aug 11, 2023
    Copy the full SHA
    337fc07 View commit details

Commits on Aug 14, 2023

  1. Copy the full SHA
    b99d122 View commit details
  2. Copy the full SHA
    bf4f82c View commit details

Commits on Aug 15, 2023

  1. Copy the full SHA
    2ef121b View commit details
  2. Copy the full SHA
    036015e View commit details
  3. Migrate cel.Ast to be a thin layer on ast.AST (#806)

    * Migrate cel.Ast to be a thin layer on ast.AST
    * Adjust IsChecked() criteria
    TristonianJones authored Aug 15, 2023
    Copy the full SHA
    51cf846 View commit details

Commits on Aug 16, 2023

  1. Copy the full SHA
    fc3b794 View commit details
  2. Copy the full SHA
    6643a4a View commit details
  3. Copy the full SHA
    0bd4d39 View commit details

Commits on Aug 18, 2023

  1. Creating a Function that Reverses a String (#796)

    Creating a reverse function for CEL that takes a string and returns a
    new string whose characters are in the reverse order of the original.
    bboogler authored Aug 18, 2023
    Copy the full SHA
    5be9464 View commit details
  2. Migrate the parser to the Go-native Expr (#797)

    * Migrate the type-checker to a native AST representation
    * Parser expr update
    * Unify cel.MacroExprFactory with ast.ExprFactory
    TristonianJones authored Aug 18, 2023
    Copy the full SHA
    0b8fcf3 View commit details
  3. Migrate the checker.Coster to the ast.Expr (#798)

    Migrate Cost calculations to Go-native Expr.
    
    Note, this is a breaking changes as the type of the checker.AstNode
    has been modified to return a Go-native Expr value rather than its
    protobuf equivalent.
    TristonianJones authored Aug 18, 2023
    Copy the full SHA
    0453692 View commit details
  4. Migrate the interpreter.PruneAst to the Go-native Expr (#799)

    Migrate pruner to Go-native Expr
    TristonianJones authored Aug 18, 2023
    Copy the full SHA
    b6d0e04 View commit details
  5. Migrate the parser.Unparse to the Go-native Expr (#800)

    Migrate unparser internals to native Expr type
    TristonianJones authored Aug 18, 2023
    Copy the full SHA
    26aa367 View commit details
  6. Copy the full SHA
    eaebecb View commit details

Commits on Aug 19, 2023

  1. Introduce pre-order / post-order visitor pattern (#813)

    * Introduce pre-order / post-order visitor pattern
    TristonianJones authored Aug 19, 2023
    Copy the full SHA
    1a6373d View commit details
  2. Copy the full SHA
    2de9952 View commit details

Commits on Aug 22, 2023

  1. Copy the full SHA
    78039f1 View commit details

Commits on Aug 23, 2023

  1. Copy the full SHA
    8a45955 View commit details
  2. Copy the full SHA
    dd6d31d View commit details
  3. Copy the full SHA
    509c1d6 View commit details

Commits on Aug 30, 2023

  1. Static optimizer for constant folding (#804)

    * Optimizer API with Constant Folding implementatiton
    * Better logical folds and additional tests
    * Add a configurable limit to constant folding
    TristonianJones authored Aug 30, 2023
    Copy the full SHA
    705546a View commit details

Commits on Aug 31, 2023

  1. Copy the full SHA
    bfccebd View commit details

Commits on Sep 1, 2023

  1. Copy the full SHA
    4eebcf3 View commit details
  2. Upgrade go-genproto to latest (#831)

    * Upgrade go-genproto to latest
    
    See googleapis/go-genproto#1015.
    
    * Update WORKSPACE
    l46kok authored Sep 1, 2023
    Copy the full SHA
    8943046 View commit details
  3. Inlining optimizer (#827)

    Inliner with support for identifiers and simple select expressions
    
    Note: optional field selections are not supported for inlining matches
    at this time.
    TristonianJones authored Sep 1, 2023
    Copy the full SHA
    5db3640 View commit details
Showing with 10,947 additions and 4,346 deletions.
  1. +13 −13 WORKSPACE
  2. +4 −0 cel/BUILD.bazel
  3. +140 −2 cel/cel_test.go
  4. +0 −40 cel/decls.go
  5. +62 −70 cel/env.go
  6. +47 −0 cel/env_test.go
  7. +558 −0 cel/folding.go
  8. +676 −0 cel/folding_test.go
  9. +220 −0 cel/inlining.go
  10. +577 −0 cel/inlining_test.go
  11. +8 −28 cel/io.go
  12. +1 −1 cel/io_test.go
  13. +30 −31 cel/library.go
  14. +444 −12 cel/macro.go
  15. +390 −0 cel/optimizer.go
  16. +2 −0 cel/options.go
  17. +6 −54 cel/program.go
  18. +19 −32 cel/validator.go
  19. +146 −210 checker/checker.go
  20. +22 −12 checker/checker_test.go
  21. +112 −106 checker/cost.go
  22. +16 −0 checker/cost_test.go
  23. +7 −11 checker/errors.go
  24. +17 −17 checker/printer.go
  25. +9 −0 common/ast/BUILD.bazel
  26. +332 −127 common/ast/ast.go
  27. +168 −101 common/ast/ast_test.go
  28. +632 −0 common/ast/conversion.go
  29. +469 −0 common/ast/conversion_test.go
  30. +559 −408 common/ast/expr.go
  31. +392 −332 common/ast/expr_test.go
  32. +303 −0 common/ast/factory.go
  33. +652 −0 common/ast/navigable.go
  34. +601 −0 common/ast/navigable_test.go
  35. +2 −2 common/containers/BUILD.bazel
  36. +11 −11 common/containers/container.go
  37. +8 −28 common/containers/container_test.go
  38. +3 −1 common/debug/BUILD.bazel
  39. +77 −79 common/debug/debug.go
  40. +2 −1 common/decls/decls.go
  41. +1 −1 common/errors.go
  42. +23 −2 common/types/provider.go
  43. +34 −0 common/types/provider_test.go
  44. +2 −2 ext/BUILD.bazel
  45. +14 −0 ext/README.md
  46. +12 −12 ext/bindings.go
  47. +904 −0 ext/formatting.go
  48. +5 −6 ext/guards.go
  49. +33 −34 ext/math.go
  50. +18 −0 ext/native.go
  51. +35 −0 ext/native_test.go
  52. +22 −23 ext/protos.go
  53. +62 −392 ext/strings.go
  54. +67 −21 ext/strings_test.go
  55. +4 −4 go.mod
  56. +8 −8 go.sum
  57. +0 −1 interpreter/BUILD.bazel
  58. +0 −383 interpreter/formatting.go
  59. +3 −23 interpreter/interpreter.go
  60. +49 −88 interpreter/interpreter_test.go
  61. +122 −157 interpreter/planner.go
  62. +210 −286 interpreter/prune.go
  63. +15 −13 interpreter/prune_test.go
  64. +5 −0 parser/BUILD.bazel
  65. +226 −352 parser/helper.go
  66. +17 −10 parser/helper_test.go
  67. +88 −106 parser/macro.go
  68. +66 −70 parser/parser.go
  69. +66 −52 parser/parser_test.go
  70. +107 −115 parser/unparser.go
  71. +16 −13 parser/unparser_test.go
  72. +3 −3 repl/appengine/web/package-lock.json
  73. +10 −4 vendor/google.golang.org/protobuf/encoding/protojson/encode.go
  74. +10 −4 vendor/google.golang.org/protobuf/encoding/prototext/encode.go
  75. +6 −4 vendor/google.golang.org/protobuf/internal/encoding/json/encode.go
  76. +6 −4 vendor/google.golang.org/protobuf/internal/encoding/text/encode.go
  77. +48 −0 vendor/google.golang.org/protobuf/internal/genid/descriptor_gen.go
  78. +6 −0 vendor/google.golang.org/protobuf/internal/genid/type_gen.go
  79. +1 −1 vendor/google.golang.org/protobuf/internal/order/order.go
  80. +1 −1 vendor/google.golang.org/protobuf/internal/version/version.go
  81. +7 −3 vendor/google.golang.org/protobuf/proto/size.go
  82. +27 −0 vendor/google.golang.org/protobuf/reflect/protoreflect/source_gen.go
  83. +633 −378 vendor/google.golang.org/protobuf/types/descriptorpb/descriptor.pb.go
  84. +177 −0 vendor/google.golang.org/protobuf/types/dynamicpb/types.go
  85. +35 −35 vendor/google.golang.org/protobuf/types/known/anypb/any.pb.go
  86. +1 −1 vendor/google.golang.org/protobuf/types/known/structpb/struct.pb.go
  87. +1 −1 vendor/google.golang.org/protobuf/types/known/timestamppb/timestamp.pb.go
  88. +4 −4 vendor/modules.txt
26 changes: 13 additions & 13 deletions WORKSPACE
Original file line number Diff line number Diff line change
@@ -30,13 +30,13 @@ http_archive(
],
)

# googleapis as of 05/26/2023
# googleapis as of 08/31/2023
http_archive(
name = "com_google_googleapis",
strip_prefix = "googleapis-07c27163ac591955d736f3057b1619ece66f5b99",
sha256 = "bd8e735d881fb829751ecb1a77038dda4a8d274c45490cb9fcf004583ee10571",
sha256 = "5c56500adf7b1b7a3a2ee5ca5b77500617ad80afb808e3d3979f582e64c0523d",
strip_prefix = "googleapis-25f99371444ea7fd0dc1523ca6925e91cc48a664",
urls = [
"https://github.com/googleapis/googleapis/archive/07c27163ac591955d736f3057b1619ece66f5b99.tar.gz",
"https://github.com/googleapis/googleapis/archive/25f99371444ea7fd0dc1523ca6925e91cc48a664.tar.gz",
],
)

@@ -68,22 +68,22 @@ go_repository(
tag = "v1.28.1",
)

# Generated Google APIs protos for Golang 05/25/2023
# Generated Google APIs protos for Golang 08/03/2023
go_repository(
name = "org_golang_google_genproto_googleapis_api",
build_file_proto_mode = "disable_global",
importpath = "google.golang.org/genproto/googleapis/api",
sum = "h1:m8v1xLLLzMe1m5P+gCTF8nJB9epwZQUBERm20Oy1poQ=",
version = "v0.0.0-20230525234035-dd9d682886f9",
sum = "h1:nIgk/EEq3/YlnmVVXVnm14rC2oxgs1o0ong4sD/rd44=",
version = "v0.0.0-20230803162519-f966b187b2e5",
)

# Generated Google APIs protos for Golang 05/25/2023
# Generated Google APIs protos for Golang 08/03/2023
go_repository(
name = "org_golang_google_genproto_googleapis_rpc",
build_file_proto_mode = "disable_global",
importpath = "google.golang.org/genproto/googleapis/rpc",
sum = "h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM=",
version = "v0.0.0-20230525234030-28d5490b6b19",
sum = "h1:eSaPbMR4T7WfH9FvABk36NBMacoTUKdWCvV0dx+KfOg=",
version = "v0.0.0-20230803162519-f966b187b2e5",
)

# gRPC deps for v1.49.0 (including x/text and x/net)
@@ -119,8 +119,8 @@ go_repository(
# CEL Spec deps v0.9.0
go_repository(
name = "com_google_cel_spec",
importpath = "github.com/google/cel-spec",
commit = "51af45e2b75a8aa2b3108b00f0e91cd172cfbea1",
importpath = "github.com/google/cel-spec",
)

# strcase deps
@@ -134,16 +134,16 @@ go_repository(
# Readline for repl
go_repository(
name = "com_github_chzyer_readline",
commit = "62c6fe6193755f722b8b8788aa7357be55a50ff1", # v1.4
importpath = "github.com/chzyer/readline",
commit = "62c6fe6193755f722b8b8788aa7357be55a50ff1" # v1.4
)

# golang.org/x/exp deps
go_repository(
name = "org_golang_x_exp",
importpath = "golang.org/x/exp",
sum = "h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA=",
version = "v0.0.0-20220722155223-a9213eeb770e"
version = "v0.0.0-20220722155223-a9213eeb770e",
)

# Run the dependencies at the end. These will silently try to import some
4 changes: 4 additions & 0 deletions cel/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -10,9 +10,11 @@ go_library(
"cel.go",
"decls.go",
"env.go",
"folding.go",
"io.go",
"library.go",
"macro.go",
"optimizer.go",
"options.go",
"program.go",
"validator.go",
@@ -56,7 +58,9 @@ go_test(
"cel_test.go",
"decls_test.go",
"env_test.go",
"folding_test.go",
"io_test.go",
"validator_test.go",
],
data = [
"//cel/testdata:gen_test_fds",
142 changes: 140 additions & 2 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@ import (
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
"github.com/google/cel-go/test"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -699,6 +700,142 @@ func TestCustomMacro(t *testing.T) {
}
}

func TestMacroInterop(t *testing.T) {
existsOneMacro := NewReceiverMacro("exists_one", 2,
func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return ExistsOneMacroExpander(meh, iterRange, args)
})
transformMacro := NewReceiverMacro("transform", 2,
func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return MapMacroExpander(meh, iterRange, args)
})
filterMacro := NewReceiverMacro("filter", 2,
func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return FilterMacroExpander(meh, iterRange, args)
})
pairMacro := NewGlobalMacro("pair", 2,
func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return meh.NewMap(meh.NewMapEntry(args[0], args[1], false)), nil
})
getMacro := NewReceiverMacro("get", 2,
func(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return meh.GlobalCall(
operators.Conditional,
meh.PresenceTest(meh.Copy(target), args[0].GetIdentExpr().GetName()),
meh.Select(meh.Copy(target), args[0].GetIdentExpr().GetName()),
meh.Copy(args[1]),
), nil
})
env := testEnv(t, Macros(existsOneMacro, transformMacro, filterMacro, pairMacro, getMacro))
tests := []struct {
expr string
out ref.Val
}{
{
expr: `['tr', 's', 'fri'].filter(i, i.size() > 1).transform(i, i + 'end').exists_one(i, i == 'friend')`,
out: types.True,
},
{
expr: `pair('a', 'b')`,
out: types.DefaultTypeAdapter.NativeToValue(map[string]string{"a": "b"}),
},
{
expr: `{}.get(a, 'default')`,
out: types.String("default"),
},
{
expr: `{'a': 'b'}.get(a, 'default')`,
out: types.String("b"),
},
}

for _, tst := range tests {
ast, iss := env.Compile(tst.expr)
if iss.Err() != nil {
t.Fatal(iss.Err())
}
prg, err := env.Program(ast, EvalOptions(OptExhaustiveEval))
if err != nil {
t.Fatalf("program creation error: %s\n", err)
}
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatal(err)
}
if out.Equal(tst.out) != types.True {
t.Errorf("got %v, wanted %v", out, tst.out)
}
}
}

func TestMacroModern(t *testing.T) {
existsOneMacro := ReceiverMacro("exists_one", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeExistsOne(mef, iterRange, args)
})
transformMacro := ReceiverMacro("transform", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeMap(mef, iterRange, args)
})
filterMacro := ReceiverMacro("filter", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeFilter(mef, iterRange, args)
})
pairMacro := GlobalMacro("pair", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return mef.NewMap(mef.NewMapEntry(args[0], args[1], false)), nil
})
getMacro := ReceiverMacro("get", 2,
func(mef MacroExprFactory, target celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return mef.NewCall(
operators.Conditional,
mef.NewPresenceTest(mef.Copy(target), args[0].AsIdent()),
mef.NewSelect(mef.Copy(target), args[0].AsIdent()),
mef.Copy(args[1]),
), nil
})
env := testEnv(t, Macros(existsOneMacro, transformMacro, filterMacro, pairMacro, getMacro))
tests := []struct {
expr string
out ref.Val
}{
{
expr: `['tr', 's', 'fri'].filter(i, i.size() > 1).transform(i, i + 'end').exists_one(i, i == 'friend')`,
out: types.True,
},
{
expr: `pair('a', 'b')`,
out: types.DefaultTypeAdapter.NativeToValue(map[string]string{"a": "b"}),
},
{
expr: `{}.get(a, 'default')`,
out: types.String("default"),
},
{
expr: `{'a': 'b'}.get(a, 'default')`,
out: types.String("b"),
},
}

for _, tst := range tests {
ast, iss := env.Compile(tst.expr)
if iss.Err() != nil {
t.Fatal(iss.Err())
}
prg, err := env.Program(ast, EvalOptions(OptExhaustiveEval))
if err != nil {
t.Fatalf("program creation error: %s\n", err)
}
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatal(err)
}
if out.Equal(tst.out) != types.True {
t.Errorf("got %v, wanted %v", out, tst.out)
}
}
}

func TestCustomExistsMacro(t *testing.T) {
env := testEnv(t,
Variable("attr", MapType(StringType, BoolType)),
@@ -1944,7 +2081,8 @@ func TestDynamicDispatch(t *testing.T) {
),
)
out, err := interpret(t, env, `
[1, 2].first() == 1
dyn([]).first() == 0
&& [1, 2].first() == 1
&& [1.0, 2.0].first() == 1.0
&& ["hello", "world"].first() == "hello"
&& [["hello"], ["world", "!"]].first().first() == "hello"
@@ -2063,7 +2201,7 @@ func TestOptionalValuesCompile(t *testing.T) {
if iss.Err() != nil {
t.Fatalf("%v failed: %v", tc.expr, iss.Err())
}
for id, reference := range ast.refMap {
for id, reference := range ast.impl.ReferenceMap() {
other, found := tc.references[id]
if !found {
t.Errorf("Compile(%v) expected reference %d: %v", tc.expr, id, reference)
40 changes: 0 additions & 40 deletions cel/decls.go
Original file line number Diff line number Diff line change
@@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
return nil, fmt.Errorf("unsupported decl: %v", d)
}
}

func typeValueToKind(tv ref.Type) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}
132 changes: 62 additions & 70 deletions cel/env.go
Original file line number Diff line number Diff line change
@@ -38,36 +38,44 @@ type Source = common.Source
// Ast representing the checked or unchecked expression, its source, and related metadata such as
// source position information.
type Ast struct {
expr *exprpb.Expr
info *exprpb.SourceInfo
source Source
refMap map[int64]*celast.ReferenceInfo
typeMap map[int64]*types.Type
source Source
impl *celast.AST
}

// Expr returns the proto serializable instance of the parsed/checked expression.
//
// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr()
// the result instead.
func (ast *Ast) Expr() *exprpb.Expr {
return ast.expr
if ast == nil {
return nil
}
pbExpr, _ := celast.ExprToProto(ast.impl.Expr())
return pbExpr
}

// IsChecked returns whether the Ast value has been successfully type-checked.
func (ast *Ast) IsChecked() bool {
return ast.typeMap != nil && len(ast.typeMap) > 0
if ast == nil {
return false
}
return ast.impl.IsChecked()
}

// SourceInfo returns character offset and newline position information about expression elements.
func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
return ast.info
if ast == nil {
return nil
}
pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo())
return pbInfo
}

// ResultType returns the output type of the expression if the Ast has been type-checked, else
// returns chkdecls.Dyn as the parse step cannot infer the type.
//
// Deprecated: use OutputType
func (ast *Ast) ResultType() *exprpb.Type {
if !ast.IsChecked() {
return chkdecls.Dyn
}
out := ast.OutputType()
t, err := TypeToExprType(out)
if err != nil {
@@ -79,16 +87,18 @@ func (ast *Ast) ResultType() *exprpb.Type {
// OutputType returns the output type of the expression if the Ast has been type-checked, else
// returns cel.DynType as the parse step cannot infer types.
func (ast *Ast) OutputType() *Type {
t, found := ast.typeMap[ast.expr.GetId()]
if !found {
return DynType
if ast == nil {
return types.ErrorType
}
return t
return ast.impl.GetType(ast.impl.Expr().ID())
}

// Source returns a view of the input used to create the Ast. This source may be complete or
// constructed from the SourceInfo.
func (ast *Ast) Source() Source {
if ast == nil {
return nil
}
return ast.source
}

@@ -196,29 +206,28 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
// It is possible to have both non-nil Ast and Issues values returned from this call: however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
// Note, errors aren't currently possible on the Ast to ParsedExpr conversion.
pe, _ := AstToParsedExpr(ast)

// Construct the internal checker env, erroring if there is an issue adding the declarations.
chk, err := e.initChecker()
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}

res, errs := checker.Check(pe, ast.Source(), chk)
checked, errs := checker.Check(ast.impl, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
ast = &Ast{
source: ast.Source(),
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
typeMap: res.TypeMap}
source: ast.Source(),
impl: checked}

// Avoid creating a validator config if it's not needed.
if len(e.validators) == 0 {
return ast, nil
}

// Generate a validator configuration from the set of configured validators.
vConfig := newValidatorConfig()
@@ -228,9 +237,9 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
}
}
// Apply additional validators on the type-checked result.
iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo())
iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
for _, v := range e.validators {
v.Validate(e, vConfig, res, iss)
v.Validate(e, vConfig, checked, iss)
}
if iss.Err() != nil {
return nil, iss
@@ -388,6 +397,15 @@ func (e *Env) HasLibrary(libName string) bool {
return exists && configured
}

// Libraries returns a list of SingletonLibrary that have been configured in the environment.
func (e *Env) Libraries() []string {
libraries := make([]string, 0, len(e.libraries))
for libName := range e.libraries {
libraries = append(libraries, libName)
}
return libraries
}

// HasValidator returns whether a specific ASTValidator has been configured in the environment.
func (e *Env) HasValidator(name string) bool {
for _, v := range e.validators {
@@ -415,16 +433,11 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) {
// It is possible to have both non-nil Ast and Issues values returned from this call; however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) ParseSource(src Source) (*Ast, *Issues) {
res, errs := e.prsr.Parse(src)
parsed, errs := e.prsr.Parse(src)
if len(errs.GetErrors()) > 0 {
return nil, &Issues{errs: errs}
}
// Manually create the Ast to ensure that the text source information is propagated on
// subsequent calls to Check.
return &Ast{
source: src,
expr: res.GetExpr(),
info: res.GetSourceInfo()}, nil
return &Ast{source: src, impl: parsed}, nil
}

// Program generates an evaluable instance of the Ast within the environment (Env).
@@ -520,8 +533,9 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(pruned))
pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State())
newAST := &Ast{source: a.Source(), impl: pruned}
expr, err := AstToString(newAST)
if err != nil {
return nil, err
}
@@ -542,13 +556,7 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
return checker.Cost(checked, estimator, opts...)
return checker.Cost(ast.impl, estimator, opts...)
}

// configure applies a series of EnvOptions to the current environment.
@@ -690,7 +698,7 @@ type Error = common.Error
// Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct.
type Issues struct {
errs *common.Errors
info *exprpb.SourceInfo
info *celast.SourceInfo
}

// NewIssues returns an Issues struct from a common.Errors object.
@@ -701,7 +709,7 @@ func NewIssues(errs *common.Errors) *Issues {
// NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata
// which can be used with the `ReportErrorAtID` method for additional error reports within the context
// information that's inferred from an expression id.
func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues {
func NewIssuesWithSourceInfo(errs *common.Errors, info *celast.SourceInfo) *Issues {
return &Issues{
errs: errs,
info: info,
@@ -751,30 +759,7 @@ func (i *Issues) String() string {
// The source metadata for the expression at `id`, if present, is attached to the error report.
// To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo.
func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) {
i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...)
}

// locationByID returns a common.Location given an expression id.
//
// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source
// as this implementation relies on the abstractions present in the protobuf SourceInfo object,
// and is replicated in the checker.
func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location {
positions := sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
i.errs.ReportErrorAtID(id, i.info.GetStartLocation(id), message, args...)
}

// getStdEnv lazy initializes the CEL standard environment.
@@ -805,6 +790,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b
return nil, false
}

// FindStructFieldNames returns an empty set of field for the interop provider.
//
// To inspect the field names, migrate to a `types.Provider` implementation.
func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return []string{}, false
}

// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//
47 changes: 47 additions & 0 deletions cel/env_test.go
Original file line number Diff line number Diff line change
@@ -29,6 +29,25 @@ import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

func TestAstNil(t *testing.T) {
var ast *Ast
if ast.IsChecked() {
t.Error("ast.IsChecked() returned true for nil ast")
}
if ast.Expr() != nil {
t.Errorf("ast.Expr() got %v, wanted nil", ast.Expr())
}
if ast.SourceInfo() != nil {
t.Errorf("ast.SourceInfo() got %v, wanted nil", ast.SourceInfo())
}
if ast.OutputType() != types.ErrorType {
t.Errorf("ast.OutputType() got %v, wanted error type", ast.OutputType())
}
if ast.Source() != nil {
t.Errorf("ast.Source() got %v, wanted nil", ast.Source())
}
}

func TestIssuesNil(t *testing.T) {
var iss *Issues
iss = iss.Append(iss)
@@ -228,6 +247,30 @@ func TestTypeProviderInterop(t *testing.T) {
}
}

func TestLibraries(t *testing.T) {
e, err := NewEnv(OptionalTypes())
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
for _, expected := range []string{"cel.lib.std", "cel.lib.optional"} {
if !e.HasLibrary(expected) {
t.Errorf("Expected HasLibrary() to return true for '%s'", expected)
}
libMap := map[string]struct{}{}
libraries := e.Libraries()
for _, lib := range libraries {
libMap[lib] = struct{}{}
}
if len(libraries) != 2 {
t.Errorf("Expected HasLibrary() to contain exactly 2 libraries but got: %v", libraries)
}

if _, ok := libMap[expected]; !ok {
t.Errorf("Expected Libraries() to include '%s'", expected)
}
}
}

func BenchmarkNewCustomEnvLazy(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -366,6 +409,10 @@ func (p *customCELProvider) FindStructType(typeName string) (*types.Type, bool)
return p.provider.FindStructType(typeName)
}

func (p *customCELProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return p.provider.FindStructFieldNames(typeName)
}

func (p *customCELProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) {
return p.provider.FindStructFieldType(structType, fieldName)
}
558 changes: 558 additions & 0 deletions cel/folding.go

Large diffs are not rendered by default.

676 changes: 676 additions & 0 deletions cel/folding_test.go

Large diffs are not rendered by default.

220 changes: 220 additions & 0 deletions cel/inlining.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cel

import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/traits"
)

// InlineVariable holds a variable name to be matched and an AST representing
// the expression graph which should be used to replace it.
type InlineVariable struct {
name string
alias string
def *ast.AST
}

// Name returns the qualified variable or field selection to replace.
func (v *InlineVariable) Name() string {
return v.name
}

// Alias returns the alias to use when performing cel.bind() calls during inlining.
func (v *InlineVariable) Alias() string {
return v.alias
}

// Expr returns the inlined expression value.
func (v *InlineVariable) Expr() ast.Expr {
return v.def.Expr()
}

// Type indicates the inlined expression type.
func (v *InlineVariable) Type() *Type {
return v.def.GetType(v.def.Expr().ID())
}

// NewInlineVariable declares a variable name to be replaced by a checked expression.
func NewInlineVariable(name string, definition *Ast) *InlineVariable {
return NewInlineVariableWithAlias(name, name, definition)
}

// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression.
// If the variable occurs more than once, the provided alias will be used to replace the expressions
// where the variable name occurs.
func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable {
return &InlineVariable{name: name, alias: alias, def: definition.impl}
}

// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions.
//
// If a variable occurs one time, the variable is replaced by the inline definition. If the
// variable occurs more than once, the variable occurences are replaced by a cel.bind() call.
func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer {
return &inliningOptimizer{variables: inlineVars}
}

type inliningOptimizer struct {
variables []*InlineVariable
}

func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
for _, inlineVar := range opt.variables {
matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name()))
// Skip cases where the variable isn't in the expression graph
if len(matches) == 0 {
continue
}

// For a single match, do a direct replacement of the expression sub-graph.
if len(matches) == 1 {
opt.inlineExpr(ctx, matches[0], ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type())
continue
}

if !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
for _, match := range matches {
opt.inlineExpr(ctx, match, ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type())
}
continue
}
// For multiple matches, find the least common ancestor (lca) and insert the
// variable as a cel.bind() macro.
var lca ast.NavigableExpr = nil
ancestors := map[int64]bool{}
for _, match := range matches {
// Update the identifier matches with the provided alias.
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
parent, found := match, true
for found {
_, hasAncestor := ancestors[parent.ID()]
if hasAncestor && (lca == nil || lca.Depth() < parent.Depth()) {
lca = parent
}
ancestors[parent.ID()] = true
parent, found = parent.Parent()
}
}

// Update the least common ancestor by inserting a cel.bind() call to the alias.
inlined := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), inlineVar.Expr(), lca)
opt.inlineExpr(ctx, lca, inlined, inlineVar.Type())
}
return a
}

// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining
// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is
// made to determine whether the inlined value can be presence or existence tested.
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
switch prev.Kind() {
case ast.SelectKind:
sel := prev.AsSelect()
if !sel.IsTestOnly() {
prev.SetKindCase(inlined)
return
}
opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType)
default:
prev.SetKindCase(inlined)
}
}

// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe
// expression appropriate for the inlined type, if possible.
//
// If the rewrite is not possible an error is reported at the inline expression site.
func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
// If the input inlined expression is not a select expression it won't work with the has()
// macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error.
ctx.sourceInfo.ClearMacroCall(prev.ID())
if inlined.Kind() == ast.SelectKind {
inlinedSel := inlined.AsSelect()
prev.SetKindCase(
ctx.NewPresenceTest(prev.ID(), inlinedSel.Operand(), inlinedSel.FieldName()))
return
}
if inlinedType.IsAssignableType(NullType) {
prev.SetKindCase(
ctx.NewCall(operators.NotEquals,
inlined,
ctx.NewLiteral(types.NullValue),
))
return
}
if inlinedType.HasTrait(traits.SizerType) {
prev.SetKindCase(
ctx.NewCall(operators.NotEquals,
ctx.NewMemberCall(overloads.Size, inlined),
ctx.NewLiteral(types.IntZero),
))
return
}
ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType)
}

// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression
// being replaced occurs within a presence test. Value types with a size() method or field selection
// support can be bound.
//
// In future iterations, support may also be added for indexer types which can be rewritten as an `in`
// expression; however, this would imply a rewrite of the inlined expression that may not be necessary
// in most cases.
func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool {
if inlinedType.IsAssignableType(NullType) ||
inlinedType.HasTrait(traits.SizerType) ||
inlinedType.HasTrait(traits.FieldTesterType) {
return true
}
for _, m := range matches {
if m.Kind() != ast.SelectKind {
continue
}
sel := m.AsSelect()
if sel.IsTestOnly() {
return false
}
}
return true
}

// matchVariable matches simple identifiers, select expressions, and presence test expressions
// which match the (potentially) qualified variable name provided as input.
//
// Note, this function does not support inlining against select expressions which includes optional
// field selection. This may be a future refinement.
func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() == ast.IdentKind && e.AsIdent() == varName {
return true
}
if e.Kind() == ast.SelectKind {
sel := e.AsSelect()
// While the `ToQualifiedName` call could take the select directly, this
// would skip presence tests from possible matches, which we would like
// to include.
qualName, found := containers.ToQualifiedName(sel.Operand())
return found && qualName+"."+sel.FieldName() == varName
}
return false
}
}
577 changes: 577 additions & 0 deletions cel/inlining_test.go

Large diffs are not rendered by default.

36 changes: 8 additions & 28 deletions cel/io.go
Original file line number Diff line number Diff line change
@@ -47,17 +47,11 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
//
// Prefer CheckedExprToAst if loading expressions from storage.
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) {
checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr)
checked, err := ast.ToAST(checkedExpr)
if err != nil {
return nil, err
}
return &Ast{
expr: checkedAST.Expr,
info: checkedAST.SourceInfo,
source: src,
refMap: checkedAST.ReferenceMap,
typeMap: checkedAST.TypeMap,
}, nil
return &Ast{source: src, impl: checked}, nil
}

// AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value.
@@ -67,13 +61,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
cAst := &ast.CheckedAST{
Expr: a.expr,
SourceInfo: a.info,
ReferenceMap: a.refMap,
TypeMap: a.typeMap,
}
return ast.CheckedASTToCheckedExpr(cAst)
return ast.ToProto(a.impl)
}

// ParsedExprToAst converts a parsed expression proto message to an Ast.
@@ -89,18 +77,12 @@ func ParsedExprToAst(parsedExpr *exprpb.ParsedExpr) *Ast {
//
// Prefer ParsedExprToAst if loading expressions from storage.
func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast {
si := parsedExpr.GetSourceInfo()
if si == nil {
si = &exprpb.SourceInfo{}
}
info, _ := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo())
if src == nil {
src = common.NewInfoSource(si)
}
return &Ast{
expr: parsedExpr.GetExpr(),
info: si,
source: src,
src = common.NewInfoSource(parsedExpr.GetSourceInfo())
}
e, _ := ast.ProtoToExpr(parsedExpr.GetExpr())
return &Ast{source: src, impl: ast.NewAST(e, info)}
}

// AstToParsedExpr converts an Ast to an protobuf ParsedExpr value.
@@ -116,9 +98,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) {
// Note, the conversion may not be an exact replica of the original expression, but will produce
// a string that is semantically equivalent and whose textual representation is stable.
func AstToString(a *Ast) (string, error) {
expr := a.Expr()
info := a.SourceInfo()
return parser.Unparse(expr, info)
return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo())
}

// RefValueToValue converts between ref.Val and api.expr.Value.
2 changes: 1 addition & 1 deletion cel/io_test.go
Original file line number Diff line number Diff line change
@@ -107,7 +107,7 @@ func TestAstToProto(t *testing.T) {
}
checked, err := AstToCheckedExpr(ast)
if err != nil {
t.Fatalf("AstToCheckeExpr(ast) failed: %v", err)
t.Fatalf("AstToCheckedExpr(ast) failed: %v", err)
}
ast4 := CheckedExprToAst(checked)
if !proto.Equal(ast4.Expr(), ast.Expr()) {
61 changes: 30 additions & 31 deletions cel/library.go
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ import (
"strings"
"time"

"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
@@ -28,8 +29,6 @@ import (
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

const (
@@ -313,7 +312,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Types(types.OptionalType),

// Configure the optMap and optFlatMap macros.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
Macros(ReceiverMacro(optMapMacro, 2, optMap)),

// Global and member functions for working with optional values.
Function(optionalOfFunc,
@@ -374,7 +373,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
if lib.version >= 1 {
opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
}
return opts
}
@@ -386,57 +385,57 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption {
}
}

func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewCall(optionalOfFunc,
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}

func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
func optFlatMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier")
return nil, meh.NewError(varIdent.ID(), "optFlatMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
return meh.NewCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.Fold(
unusedIterVar,
meh.NewMemberCall(hasValueFunc, target),
meh.NewComprehension(
meh.NewList(),
unusedIterVar,
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
meh.NewMemberCall(valueFunc, target),
meh.NewLiteral(types.False),
meh.NewIdent(varName),
mapExpr,
),
meh.GlobalCall(optionalNoneFunc),
meh.NewCall(optionalNoneFunc),
), nil
}

456 changes: 444 additions & 12 deletions cel/macro.go

Large diffs are not rendered by default.

390 changes: 390 additions & 0 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,390 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cel

import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)

// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order.
//
// The static optimizer normalizes expression ids and type-checking run between optimization
// passes to ensure that the final optimized output is a valid expression with metadata consistent
// with what would have been generated from a parsed and checked expression.
//
// Note: source position information is best-effort and likely wrong, but optimized expressions
// should be suitable for calls to parser.Unparse.
type StaticOptimizer struct {
optimizers []ASTOptimizer
}

// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
// to a checked expression.
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
return &StaticOptimizer{
optimizers: optimizers,
}
}

// Optimize applies a sequence of optimizations to an Ast within a given environment.
//
// If issues are encountered, the Issues.Err() return value will be non-nil.
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// Make a copy of the AST to be optimized.
optimized := ast.Copy(a.impl)

// Create the optimizer context, could be pooled in the future.
issues := NewIssues(common.NewErrors(a.Source()))
ids := newMonotonicIDGen(ast.MaxID(a.impl))
fac := &optimizerExprFactory{
nextID: ids.nextID,
renumberID: ids.renumberID,
fac: ast.NewExprFactory(),
sourceInfo: optimized.SourceInfo(),
}
ctx := &OptimizerContext{
optimizerExprFactory: fac,
Env: env,
Issues: issues,
}

// Apply the optimizations sequentially.
for _, o := range opt.optimizers {
optimized = o.Optimize(ctx, optimized)
if issues.Err() != nil {
return nil, issues
}
// Normalize expression id metadata including coordination with macro call metadata.
normalizeIDs(env, optimized)

// Recheck the updated expression for any possible type-agreement or validation errors.
parsed := &Ast{
source: a.Source(),
impl: ast.NewAST(optimized.Expr(), optimized.SourceInfo())}
checked, iss := ctx.Check(parsed)
if iss.Err() != nil {
return nil, iss
}
optimized = checked.impl
}

// Return the optimized result.
return &Ast{
source: a.Source(),
impl: optimized,
}, nil
}

// normalizeIDs ensures that the metadata present with an AST is reset in a manner such
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(e *Env, optimized *ast.AST) {
ids := newStableIDGen()
optimized.Expr().RenumberIDs(ids.renumberID)
allExprMap := make(map[int64]ast.Expr)
ast.PostOrderVisit(optimized.Expr(), ast.NewExprVisitor(func(e ast.Expr) {
allExprMap[e.ID()] = e
}))
info := optimized.SourceInfo()

// First, update the macro call ids themselves.
for id, call := range info.MacroCalls() {
info.ClearMacroCall(id)
callID := ids.renumberID(id)
if e, found := allExprMap[callID]; found && e.Kind() == ast.LiteralKind {
continue
}
info.SetMacroCall(callID, call)
}

// Second, update the macro call id references to ensure that macro pointers are'
// updated consistently across macros.
for id, call := range info.MacroCalls() {
call.RenumberIDs(ids.renumberID)
resetMacroCall(optimized, call, allExprMap)
info.SetMacroCall(id, call)
}
}

func resetMacroCall(optimized *ast.AST, call ast.Expr, allExprMap map[int64]ast.Expr) {
modified := []ast.Expr{}
ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) {
if _, found := allExprMap[e.ID()]; found {
modified = append(modified, e)
}
}))
for _, m := range modified {
updated := allExprMap[m.ID()]
m.SetKindCase(updated)
}
}

// newMonotonicIDGen increments numbers from an initial seed value.
func newMonotonicIDGen(seed int64) *monotonicIDGenerator {
return &monotonicIDGenerator{seed: seed}
}

type monotonicIDGenerator struct {
seed int64
}

func (gen *monotonicIDGenerator) nextID() int64 {
gen.seed++
return gen.seed
}

func (gen *monotonicIDGenerator) renumberID(int64) int64 {
return gen.nextID()
}

// newStableIDGen ensures that new ids are only created the first time they are encountered.
func newStableIDGen() *stableIDGenerator {
return &stableIDGenerator{
idMap: make(map[int64]int64),
}
}

type stableIDGenerator struct {
idMap map[int64]int64
nextID int64
}

func (gen *stableIDGenerator) renumberID(id int64) int64 {
if id == 0 {
return 0
}
if newID, found := gen.idMap[id]; found {
return newID
}
gen.nextID++
gen.idMap[id] = gen.nextID
return gen.nextID
}

// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate
// subexpressions and report any errors encountered along the way. The context also embeds the
// optimizerExprFactory which can be used to generate new sub-expressions with expression ids
// consistent with the expectations of a parsed expression.
type OptimizerContext struct {
*Env
*optimizerExprFactory
*Issues
}

// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.
Optimize(*OptimizerContext, *ast.AST) *ast.AST
}

type optimizerExprFactory struct {
nextID func() int64
renumberID ast.IDGenerator
fac ast.ExprFactory
sourceInfo *ast.SourceInfo
}

// CopyExpr copies the structure of the input ast.Expr and renumbers the identifiers in a manner
// consistent with the CEL parser / checker.
func (opt *optimizerExprFactory) CopyExpr(e ast.Expr) ast.Expr {
copy := opt.fac.CopyExpr(e)
copy.RenumberIDs(opt.renumberID)
return copy
}

// NewBindMacro creates a cel.bind() call with a variable name, initialization expression, and remaining expression.
//
// Note: the macroID indicates the insertion point, the call id that matched the macro signature, which will be used
// for coordinating macro metadata with the bind call. This piece of data is what makes it possible to unparse
// optimized expressions which use the bind() call.
//
// Example:
//
// cel.bind(myVar, a && b || c, !myVar || (myVar && d))
// - varName: myVar
// - varInit: a && b || c
// - remaining: !myVar || (myVar && d)
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) ast.Expr {
bindID := opt.nextID()
varID := opt.nextID()

varInit = opt.CopyExpr(varInit)
varInit.RenumberIDs(opt.renumberID)

remaining = opt.fac.CopyExpr(remaining)
remaining.RenumberIDs(opt.renumberID)

// Place the expanded macro form in the macro calls list so that the inlined
// call can be unparsed.
opt.sourceInfo.SetMacroCall(macroID,
opt.fac.NewMemberCall(0, "bind",
opt.fac.NewIdent(opt.nextID(), "cel"),
opt.fac.NewIdent(varID, varName),
varInit,
remaining))

// Replace the parent node with the intercepted inlining using cel.bind()-like
// generated comprehension AST.
return opt.fac.NewComprehension(bindID,
opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}),
"#unused",
varName,
opt.fac.CopyExpr(varInit),
opt.fac.NewLiteral(opt.nextID(), types.False),
opt.fac.NewIdent(varID, varName),
opt.fac.CopyExpr(remaining))
}

// NewCall creates a global function call invocation expression.
//
// Example:
//
// countByField(list, fieldName)
// - function: countByField
// - args: [list, fieldName]
func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr {
return opt.fac.NewCall(opt.nextID(), function, args...)
}

// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call.
//
// Example:
//
// list.countByField(fieldName)
// - function: countByField
// - target: list
// - args: [fieldName]
func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
return opt.fac.NewMemberCall(opt.nextID(), function, target, args...)
}

// NewIdent creates a new identifier expression.
//
// Examples:
//
// - simple_var_name
// - qualified.subpackage.var_name
func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr {
return opt.fac.NewIdent(opt.nextID(), name)
}

// NewLiteral creates a new literal expression value.
//
// The range of valid values for a literal generated during optimization is different than for expressions
// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can
// be converted back to a literal-like form.
func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr {
return opt.fac.NewLiteral(opt.nextID(), value)
}

// NewList creates a list expression with a set of optional indices.
//
// Examples:
//
// [a, b]
// - elems: [a, b]
// - optIndices: []
//
// [a, ?b, ?c]
// - elems: [a, b, c]
// - optIndices: [1, 2]
func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr {
return opt.fac.NewList(opt.nextID(), elems, optIndices)
}

// NewMap creates a map from a set of entry expressions which contain a key and value expression.
func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr {
return opt.fac.NewMap(opt.nextID(), entries)
}

// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the
// entry is optional.
//
// Examples:
//
// {a: b}
// - key: a
// - value: b
// - optional: false
//
// {?a: ?b}
// - key: a
// - value: b
// - optional: true
func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional)
}

// NewPresenceTest creates a new presence test macro call.
//
// Example:
//
// has(msg.field_name)
// - operand: msg
// - field: field_name
func (opt *optimizerExprFactory) NewPresenceTest(macroID int64, operand ast.Expr, field string) ast.Expr {
// Copy the input operand and renumber it.
operand = opt.CopyExpr(operand)
operand.RenumberIDs(opt.renumberID)

// Place the expanded macro form in the macro calls list so that the inlined call can be unparsed.
opt.sourceInfo.SetMacroCall(macroID,
opt.fac.NewCall(0, "has",
opt.fac.NewSelect(opt.nextID(), operand, field)))

// Generate a new presence test macro.
return opt.fac.NewPresenceTest(opt.nextID(), opt.CopyExpr(operand), field)
}

// NewSelect creates a select expression where a field value is selected from an operand.
//
// Example:
//
// msg.field_name
// - operand: msg
// - field: field_name
func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr {
return opt.fac.NewSelect(opt.nextID(), operand, field)
}

// NewStruct creates a new typed struct value with an set of field initializations.
//
// Example:
//
// pkg.TypeName{field: value}
// - typeName: pkg.TypeName
// - fields: [{field: value}]
func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr {
return opt.fac.NewStruct(opt.nextID(), typeName, fields)
}

// NewStructField creates a struct field initialization.
//
// Examples:
//
// {count: 3u}
// - field: count
// - value: 3u
// - optional: false
//
// {?count: x}
// - field: count
// - value: x
// - optional: true
func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr {
return opt.fac.NewStructField(opt.nextID(), field, value, isOptional)
}
2 changes: 2 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
@@ -447,6 +447,8 @@ const (
OptTrackCost EvalOption = 1 << iota

// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
//
// Deprecated: use ext.ValidateFormatString() as this option is now a no-op.
OptCheckStringFormat EvalOption = 1 << iota
)

60 changes: 6 additions & 54 deletions cel/program.go
Original file line number Diff line number Diff line change
@@ -19,7 +19,6 @@ import (
"fmt"
"sync"

celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@@ -148,7 +147,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()

@@ -208,34 +207,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if t.Kind() == k {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}

// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
@@ -263,33 +234,16 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
decs = append(decs, interpreter.Observe(observers...))
}

return p.clone().initInterpretable(ast, decs)
return p.clone().initInterpretable(a, decs)
}
return newProgGen(factory)
}
return p.initInterpretable(ast, decorators)
return p.initInterpretable(a, decorators)
}

func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// Unchecked programs do not contain type and reference information and may be slower to execute.
if !ast.IsChecked() {
interpretable, err :=
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...)
if err != nil {
return nil, err
}
p.interpretable = interpretable
return p, nil
}

// When the AST has been checked it contains metadata that can be used to speed up program execution.
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
if err != nil {
return nil, err
}
@@ -559,8 +513,6 @@ func (p *evalActivationPool) Put(value any) {
}

var (
emptyEvalState = interpreter.NewEvalState()

// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()

51 changes: 19 additions & 32 deletions cel/validator.go
Original file line number Diff line number Diff line change
@@ -21,8 +21,6 @@ import (

"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

const (
@@ -69,7 +67,7 @@ type ASTValidator interface {
//
// See individual validators for more information on their configuration keys and configuration
// properties.
Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues)
Validate(*Env, ValidatorConfig, *ast.AST, *Issues)
}

// ValidatorConfig provides an accessor method for querying validator configuration state.
@@ -180,7 +178,7 @@ func ValidateComprehensionNestingLimit(limit int) ASTValidator {
return nestingLimitValidator{limit: limit}
}

type argChecker func(env *Env, call, arg ast.NavigableExpr) error
type argChecker func(env *Env, call, arg ast.Expr) error

func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
return formatValidator{
@@ -203,8 +201,8 @@ func (v formatValidator) Name() string {

// Validate searches the AST for uses of a given function name with a constant argument and performs a check
// on whether the argument is a valid literal value.
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
for _, call := range funcCalls {
callArgs := call.AsCall().Args()
@@ -221,8 +219,8 @@ func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST,
}
}

func evalCall(env *Env, call, arg ast.NavigableExpr) error {
ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()})
func evalCall(env *Env, call, arg ast.Expr) error {
ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))}
prg, err := env.Program(ast)
if err != nil {
return err
@@ -231,7 +229,7 @@ func evalCall(env *Env, call, arg ast.NavigableExpr) error {
return err
}

func compileRegex(_ *Env, _, arg ast.NavigableExpr) error {
func compileRegex(_ *Env, _, arg ast.Expr) error {
pattern := arg.AsLiteral().Value().(string)
_, err := regexp.Compile(pattern)
return err
@@ -244,25 +242,14 @@ func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}

// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}

// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
// string format calls.
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) {
var exemptedFunctions []string
exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
root := ast.NavigateCheckedAST(a)
root := ast.NavigateAST(a)
listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
for _, listExpr := range listExprs {
if inExemptFunction(listExpr, exemptedFunctions) {
@@ -273,7 +260,7 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
optIndices := l.OptionalIndices()
var elemType *Type
for i, e := range elements {
et := e.Type()
et := a.GetType(e.ID())
if isOptionalIndex(i, optIndices) {
et = et.Parameters()[0]
}
@@ -296,9 +283,10 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
entries := m.Entries()
var keyType, valType *Type
for _, e := range entries {
key, val := e.Key(), e.Value()
kt, vt := key.Type(), val.Type()
if e.IsOptional() {
mapEntry := e.AsMapEntry()
key, val := mapEntry.Key(), mapEntry.Value()
kt, vt := a.GetType(key.ID()), a.GetType(val.ID())
if mapEntry.IsOptional() {
vt = vt.Parameters()[0]
}
if keyType == nil && valType == nil {
@@ -316,7 +304,8 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig
}

func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
if parent, found := e.Parent(); found {
parent, found := e.Parent()
for found {
if parent.Kind() == ast.CallKind {
fnName := parent.AsCall().FunctionName()
for _, exempt := range exemptFunctions {
@@ -325,9 +314,7 @@ func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
}
}
}
if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind {
return inExemptFunction(parent, exemptFunctions)
}
parent, found = parent.Parent()
}
return false
}
@@ -353,8 +340,8 @@ func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
}

func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
if len(comprehensions) <= v.limit {
return
356 changes: 146 additions & 210 deletions checker/checker.go

Large diffs are not rendered by default.

34 changes: 22 additions & 12 deletions checker/checker_test.go
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
@@ -2350,19 +2351,15 @@ func TestCheck(t *testing.T) {
t.Errorf("Expected error not thrown: %s", tc.err)
}

actual := cAst.TypeMap[pAst.Expr.Id]
actual := cAst.GetType(pAst.Expr().ID())
if tc.err == "" {
if actual == nil || !actual.IsEquivalentType(tc.outType) {
t.Error(test.DiffMessage("Type Error", actual, tc.outType))
}
}

if tc.out != "" {
chkExpr, err := ast.CheckedASTToCheckedExpr(cAst)
if err != nil {
t.Fatalf("CheckedAstToCheckedExpr() failed: %v", err)
}
actualStr := Print(pAst.Expr, chkExpr)
actualStr := Print(pAst.Expr(), cAst)
if !test.Compare(actualStr, tc.out) {
t.Error(test.DiffMessage("Structure error", actualStr, tc.out))
}
@@ -2445,19 +2442,15 @@ func BenchmarkCheck(b *testing.B) {
b.Errorf("Expected error not thrown: %s", tc.err)
}

actual := cAst.TypeMap[pAst.Expr.Id]
actual := cAst.GetType(pAst.Expr().ID())
if tc.err == "" {
if actual == nil || !actual.IsEquivalentType(tc.outType) {
b.Error(test.DiffMessage("Type Error", actual, tc.outType))
}
}

if tc.out != "" {
chkExpr, err := ast.CheckedASTToCheckedExpr(cAst)
if err != nil {
b.Fatalf("CheckedAstToCheckedExpr() failed: %v", err)
}
actualStr := Print(pAst.Expr, chkExpr)
actualStr := Print(pAst.Expr(), cAst)
if !test.Compare(actualStr, tc.out) {
b.Error(test.DiffMessage("Structure error", actualStr, tc.out))
}
@@ -2551,6 +2544,23 @@ func TestCheckErrorData(t *testing.T) {
}
}

func TestCheckInvalidLiteral(t *testing.T) {
fac := ast.NewExprFactory()
durLiteral := fac.NewLiteral(1, types.Duration{Duration: time.Second})
// This is not valid syntax, just for illustration purposes.
src := common.NewTextSource(`1s`)
parsed := ast.NewAST(durLiteral, ast.NewSourceInfo(src))
reg := newTestRegistry(t)
env, err := NewEnv(containers.DefaultContainer, reg)
if err != nil {
t.Fatalf("NewEnv(cont, reg) failed: %v", err)
}
_, iss := Check(parsed, src, env)
if !strings.Contains(iss.ToDisplayString(), "unexpected literal type") {
t.Errorf("got %s, wanted 'unexpected literal type'", iss.ToDisplayString())
}
}

func testFunction(t testing.TB, name string, opts ...decls.FunctionOpt) *decls.FunctionDecl {
t.Helper()
fn, err := decls.NewFunction(name, opts...)
218 changes: 112 additions & 106 deletions checker/cost.go

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions checker/cost_test.go
Original file line number Diff line number Diff line change
@@ -261,13 +261,29 @@ func TestCost(t *testing.T) {
expr: `string(input)`,
wanted: CostEstimate{Min: 1, Max: 51},
},
{
name: "bytes to string conversion equality",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)},
hints: map[string]int64{"input": 500},
// equality check ensures that the resultSize calculation is included in cost
expr: `string(input) == string(input)`,
wanted: CostEstimate{Min: 3, Max: 152},
},
{
name: "string to bytes conversion",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)},
hints: map[string]int64{"input": 500},
expr: `bytes(input)`,
wanted: CostEstimate{Min: 1, Max: 51},
},
{
name: "string to bytes conversion equality",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)},
hints: map[string]int64{"input": 500},
// equality check ensures that the resultSize calculation is included in cost
expr: `bytes(input) == bytes(input)`,
wanted: CostEstimate{Min: 3, Max: 302},
},
{
name: "int to string conversion",
expr: `string(1)`,
18 changes: 7 additions & 11 deletions checker/errors.go
Original file line number Diff line number Diff line change
@@ -15,13 +15,9 @@
package checker

import (
"reflect"

"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

// typeErrors is a specialization of Errors.
@@ -34,9 +30,9 @@ func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string,
name, FormatCELType(field), FormatCELType(value))
}

func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) {
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex ast.Expr, prev, next *types.Type) {
e.errs.ReportErrorAtID(id, l,
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next)
}

func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) {
@@ -49,7 +45,7 @@ func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *type
FormatCELType(t))
}

func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) {
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field ast.Expr) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field)
}

@@ -61,9 +57,9 @@ func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName strin
e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName)
}

func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) {
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex ast.Expr, prev, next *ast.ReferenceInfo) {
e.errs.ReportErrorAtID(id, l,
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next)
}

func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) {
@@ -87,6 +83,6 @@ func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typ
e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName)
}

func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex))
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, kind, typeName string) {
e.errs.ReportErrorAtID(id, l, "unexpected %s type: %v", kind, typeName)
}
34 changes: 17 additions & 17 deletions checker/printer.go
Original file line number Diff line number Diff line change
@@ -17,40 +17,40 @@ package checker
import (
"sort"

"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/debug"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

type semanticAdorner struct {
checks *exprpb.CheckedExpr
checked *ast.AST
}

var _ debug.Adorner = &semanticAdorner{}

func (a *semanticAdorner) GetMetadata(elem any) string {
result := ""
e, isExpr := elem.(*exprpb.Expr)
e, isExpr := elem.(ast.Expr)
if !isExpr {
return result
}
t := a.checks.TypeMap[e.GetId()]
t := a.checked.TypeMap()[e.ID()]
if t != nil {
result += "~"
result += FormatCheckedType(t)
result += FormatCELType(t)
}

switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr,
*exprpb.Expr_CallExpr,
*exprpb.Expr_StructExpr,
*exprpb.Expr_SelectExpr:
if ref, found := a.checks.ReferenceMap[e.GetId()]; found {
if len(ref.GetOverloadId()) == 0 {
switch e.Kind() {
case ast.IdentKind,
ast.CallKind,
ast.ListKind,
ast.StructKind,
ast.SelectKind:
if ref, found := a.checked.ReferenceMap()[e.ID()]; found {
if len(ref.OverloadIDs) == 0 {
result += "^" + ref.Name
} else {
sort.Strings(ref.GetOverloadId())
for i, overload := range ref.GetOverloadId() {
sort.Strings(ref.OverloadIDs)
for i, overload := range ref.OverloadIDs {
if i == 0 {
result += "^"
} else {
@@ -68,7 +68,7 @@ func (a *semanticAdorner) GetMetadata(elem any) string {
// Print returns a string representation of the Expr message,
// annotated with types from the CheckedExpr. The Expr must
// be a sub-expression embedded in the CheckedExpr.
func Print(e *exprpb.Expr, checks *exprpb.CheckedExpr) string {
a := &semanticAdorner{checks: checks}
func Print(e ast.Expr, checked *ast.AST) string {
a := &semanticAdorner{checked: checked}
return debug.ToAdornedDebugString(e, a)
}
9 changes: 9 additions & 0 deletions common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -5,7 +5,9 @@ package(
"//cel:__subpackages__",
"//checker:__subpackages__",
"//common:__subpackages__",
"//ext:__subpackages__",
"//interpreter:__subpackages__",
"//parser:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
@@ -14,10 +16,14 @@ go_library(
name = "go_default_library",
srcs = [
"ast.go",
"conversion.go",
"expr.go",
"factory.go",
"navigable.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
@@ -29,7 +35,9 @@ go_test(
name = "go_default_test",
srcs = [
"ast_test.go",
"conversion_test.go",
"expr_test.go",
"navigable_test.go",
],
embed = [
":go_default_library",
@@ -48,5 +56,6 @@ go_test(
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
],
)
459 changes: 332 additions & 127 deletions common/ast/ast.go

Large diffs are not rendered by default.

269 changes: 168 additions & 101 deletions common/ast/ast_test.go
Original file line number Diff line number Diff line change
@@ -15,71 +15,192 @@
package ast_test

import (
"fmt"
"reflect"
"testing"
"time"

"google.golang.org/protobuf/proto"

chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
)

func TestConvertAST(t *testing.T) {
goAST := &ast.CheckedAST{
Expr: &exprpb.Expr{},
SourceInfo: &exprpb.SourceInfo{},
TypeMap: map[int64]*types.Type{
1: types.BoolType,
2: types.DynType,
},
ReferenceMap: map[int64]*ast.ReferenceInfo{
1: ast.NewFunctionReference(overloads.LogicalNot),
2: ast.NewIdentReference("TRUE", types.True),
},
func TestASTCopy(t *testing.T) {
tests := []string{
`'a' == 'b'`,
`'a'.size()`,
`size('a')`,
`has({'a': 1}.a)`,
`{'a': 1}`,
`{'a': 1}['a']`,
`[1, 2, 3].exists(i, i % 2 == 1)`,
`google.expr.proto3.test.TestAllTypes{}`,
`google.expr.proto3.test.TestAllTypes{repeated_int32: [1, 2]}`,
}

exprAST := &exprpb.CheckedExpr{
Expr: &exprpb.Expr{},
SourceInfo: &exprpb.SourceInfo{},
TypeMap: map[int64]*exprpb.Type{
1: chkdecls.Bool,
2: chkdecls.Dyn,
},
ReferenceMap: map[int64]*exprpb.Reference{
1: {OverloadId: []string{overloads.LogicalNot}},
2: {
Name: "TRUE",
Value: &exprpb.Constant{
ConstantKind: &exprpb.Constant_BoolValue{BoolValue: true},
},
},
},
for _, tst := range tests {
checked := mustTypeCheck(t, tst)
copyChecked := ast.Copy(checked)
if !reflect.DeepEqual(copyChecked.Expr(), checked.Expr()) {
t.Errorf("Copy() got expr %v, wanted %v", copyChecked.Expr(), checked.Expr())
}
if !reflect.DeepEqual(copyChecked.SourceInfo(), checked.SourceInfo()) {
t.Errorf("Copy() got source info %v, wanted %v", copyChecked.SourceInfo(), checked.SourceInfo())
}
copyParsed := ast.Copy(ast.NewAST(checked.Expr(), checked.SourceInfo()))
if !reflect.DeepEqual(copyParsed.Expr(), checked.Expr()) {
t.Errorf("Copy() got expr %v, wanted %v", copyParsed.Expr(), checked.Expr())
}
if !reflect.DeepEqual(copyParsed.SourceInfo(), checked.SourceInfo()) {
t.Errorf("Copy() got source info %v, wanted %v", copyParsed.SourceInfo(), checked.SourceInfo())
}
checkedPB, err := ast.ToProto(checked)
if err != nil {
t.Errorf("ast.ToProto() failed: %v", err)
}
copyCheckedPB, err := ast.ToProto(copyChecked)
if err != nil {
t.Errorf("ast.ToProto() failed: %v", err)
}
if !proto.Equal(checkedPB, copyCheckedPB) {
t.Errorf("Copy() produced different proto results, got %v, wanted %v",
prototext.Format(checkedPB), prototext.Format(copyCheckedPB))
}
checkedRoundtrip, err := ast.ToAST(checkedPB)
if err != nil {
t.Errorf("ast.ToAST() failed: %v", err)
}
if !reflect.DeepEqual(checked, checkedRoundtrip) {
t.Errorf("Roundtrip got %v, wanted %v", checkedRoundtrip, checked)
}
}
}

checkedAST, err := ast.CheckedExprToCheckedAST(exprAST)
func TestASTNilSafety(t *testing.T) {
ex, err := ast.ProtoToExpr(nil)
if err != nil {
t.Fatalf("CheckedExprToCheckedAST() failed: %v", err)
t.Fatalf("ast.ProtoToExpr() failed: %v", err)
}
info, err := ast.ProtoToSourceInfo(nil)
if err != nil {
t.Fatalf("ast.ProtoToSourceInfo() failed: %v", err)
}
tests := []*ast.AST{
nil,
ast.NewAST(nil, nil),
ast.NewCheckedAST(nil, nil, nil),
ast.NewCheckedAST(ast.NewAST(nil, nil), nil, nil),
ast.NewAST(ex, info),
ast.NewCheckedAST(ast.NewAST(ex, info), map[int64]*types.Type{}, map[int64]*ast.ReferenceInfo{}),
}
for _, tst := range tests {
a := tst
asts := []*ast.AST{a, ast.Copy(a)}
for _, testAST := range asts {
if testAST.Expr().ID() != 0 {
t.Errorf("Expr().ID() got %v, wanted 0", testAST.Expr().ID())
}
if testAST.SourceInfo().SyntaxVersion() != "" {
t.Errorf("SourceInfo().SyntaxVersion() got %s, wanted empty string", testAST.SourceInfo().SyntaxVersion())
}
if testAST.IsChecked() {
t.Error("IsChecked() returned true, wanted false")
}
if testAST.GetType(testAST.Expr().ID()) != types.DynType {
t.Errorf("GetType() got %v, wanted dyn", testAST.GetType(testAST.Expr().ID()))
}
if len(testAST.GetOverloadIDs(testAST.Expr().ID())) != 0 {
t.Errorf("GetOverloadIDs() got %v, wanted empty set", testAST.GetOverloadIDs(testAST.Expr().ID()))
}
}
}
}

func TestSourceInfo(t *testing.T) {
src := common.NewStringSource("a\n? b\n: c", "custom description")
info := ast.NewSourceInfo(src)
if info.Description() != "custom description" {
t.Errorf("Description() got %s, wanted 'custom description'", info.Description())
}
if len(info.LineOffsets()) != 3 {
t.Errorf("LineOffsets() got %v, wanted 3 offsets", info.LineOffsets())
}
info.SetOffsetRange(1, ast.OffsetRange{Start: 0, Stop: 1}) // a
info.SetOffsetRange(2, ast.OffsetRange{Start: 4, Stop: 5}) // b
info.SetOffsetRange(3, ast.OffsetRange{Start: 8, Stop: 9}) // c
if !reflect.DeepEqual(info.GetStartLocation(1), common.NewLocation(1, 0)) {
t.Errorf("info.GetStartLocation(1) got %v, wanted line 1, col 0", info.GetStartLocation(1))
}
if !reflect.DeepEqual(info.GetStopLocation(1), common.NewLocation(1, 1)) {
t.Errorf("info.GetStopLocation(1) got %v, wanted line 1, col 1", info.GetStopLocation(1))
}
if !reflect.DeepEqual(info.GetStartLocation(2), common.NewLocation(2, 2)) {
t.Errorf("info.GetStartLocation(2) got %v, wanted line 2, col 2", info.GetStartLocation(2))
}
if !reflect.DeepEqual(checkedAST.ReferenceMap, goAST.ReferenceMap) ||
!reflect.DeepEqual(checkedAST.TypeMap, goAST.TypeMap) {
t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST)
if !reflect.DeepEqual(info.GetStopLocation(2), common.NewLocation(2, 3)) {
t.Errorf("info.GetStopLocation(2) got %v, wanted line 2, col 3", info.GetStopLocation(2))
}
if !checkedAST.ReferenceMap[1].Equals(goAST.ReferenceMap[1]) ||
!checkedAST.ReferenceMap[2].Equals(goAST.ReferenceMap[2]) {
t.Error("converted reference info values not equal")
if !reflect.DeepEqual(info.GetStartLocation(3), common.NewLocation(3, 2)) {
t.Errorf("info.GetStartLocation(3) got %v, wanted line 2, col 2", info.GetStartLocation(3))
}
checkedExpr, err := ast.CheckedASTToCheckedExpr(goAST)
if !reflect.DeepEqual(info.GetStopLocation(3), common.NewLocation(3, 3)) {
t.Errorf("info.GetStopLocation(3) got %v, wanted line 2, col 3", info.GetStopLocation(3))
}
if info.ComputeOffset(3, 2) != 8 {
t.Errorf("info.ComputeOffset(3, 2) got %d, wanted 8", info.ComputeOffset(3, 2))
}
}

func TestSourceInfoNilSafety(t *testing.T) {
info, err := ast.ProtoToSourceInfo(nil)
if err != nil {
t.Fatalf("CheckedASTToCheckedExpr() failed: %v", err)
t.Fatalf("ast.ProtoToSourceInfo() failed: %v", err)
}
tests := []*ast.SourceInfo{
nil,
info,
ast.NewSourceInfo(nil),
}
if !proto.Equal(checkedExpr, exprAST) {
t.Errorf("conversion to protobuf did not produce identical results: got %v, wanted %v", checkedExpr, exprAST)
for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
testInfo := tc
if testInfo.SyntaxVersion() != "" {
t.Errorf("SyntaxVersion() got %s, wanted empty string", testInfo.SyntaxVersion())
}
if testInfo.Description() != "" {
t.Errorf("Description() got %s, wanted empty string", testInfo.Description())
}
if len(testInfo.LineOffsets()) != 0 {
t.Errorf("LineOffsets() got %v, wanted empty list", testInfo.LineOffsets())
}
if len(testInfo.MacroCalls()) != 0 {
t.Errorf("MacroCalls() got %v, wanted empty map", testInfo.MacroCalls())
}
if call, found := testInfo.GetMacroCall(0); found {
t.Errorf("GetMacroCall(0) got %v, wanted not found", call)
}
if r, found := testInfo.GetOffsetRange(0); found {
t.Errorf("GetOffsetRange(0) got %v, wanted not found", r)
}
if loc := testInfo.GetStartLocation(0); loc != common.NoLocation {
t.Errorf("GetStartLocation(0) got %v, wanted no location", loc)
}
if loc := testInfo.GetStopLocation(0); loc != common.NoLocation {
t.Errorf("GetStopLocation(0) got %v, wanted no location", loc)
}
if off := testInfo.ComputeOffset(1, 0); off != 0 {
t.Errorf("ComputeOffset(1, 0) got %d, wanted 0", off)
}
if off := testInfo.ComputeOffset(-2, 0); off != -1 {
t.Errorf("ComputeOffset(-2, 0) got %d, wanted -1", off)
}
if off := testInfo.ComputeOffset(2, 0); off != -1 {
t.Errorf("ComputeOffset(2, 0) got %d, wanted -1", off)
}
})
}
}

@@ -173,57 +294,3 @@ func TestReferenceInfoAddOverload(t *testing.T) {
t.Error("repeated AddOverload() did not produce equal references")
}
}

func TestReferenceInfoToReferenceExprError(t *testing.T) {
out, err := ast.ReferenceInfoToReferenceExpr(
ast.NewIdentReference("SECOND", types.Duration{Duration: time.Duration(1) * time.Second}))
if err == nil {
t.Errorf("ReferenceInfoToReferenceExpr() got %v, wanted error", out)
}
}

func TestReferenceExprToReferenceInfoError(t *testing.T) {
out, err := ast.ReferenceExprToReferenceInfo(&exprpb.Reference{Value: &exprpb.Constant{}})
if err == nil {
t.Errorf("ReferenceExprToReferenceInfo() got %v, wanted error", out)
}
}

func TestConvertVal(t *testing.T) {
tests := []ref.Val{
types.True,
types.Bytes("bytes"),
types.Double(3.2),
types.Int(-1),
types.NullValue,
types.String("string"),
types.Uint(27),
}
for _, tst := range tests {
c, err := ast.ValToConstant(tst)
if err != nil {
t.Errorf("ValToConstant(%v) failed: %v", tst, err)
}
v, err := ast.ConstantToVal(c)
if err != nil {
t.Errorf("ValToConstant(%v) failed: %v", c, err)
}
if tst.Equal(v) != types.True {
t.Errorf("roundtrip from %v to %v and back did not produce equal results, got %v, wanted %v", tst, c, v, tst)
}
}
}

func TestValToConstantError(t *testing.T) {
out, err := ast.ValToConstant(types.Duration{Duration: time.Duration(10)})
if err == nil {
t.Errorf("ValToConstant() got %v, wanted error", out)
}
}

func TestConstantToValError(t *testing.T) {
out, err := ast.ConstantToVal(&exprpb.Constant{})
if err == nil {
t.Errorf("ConstantToVal() got %v, wanted error", out)
}
}
632 changes: 632 additions & 0 deletions common/ast/conversion.go

Large diffs are not rendered by default.

469 changes: 469 additions & 0 deletions common/ast/conversion_test.go

Large diffs are not rendered by default.

967 changes: 559 additions & 408 deletions common/ast/expr.go

Large diffs are not rendered by default.

724 changes: 392 additions & 332 deletions common/ast/expr_test.go

Large diffs are not rendered by default.

303 changes: 303 additions & 0 deletions common/ast/factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ast

import "github.com/google/cel-go/common/types/ref"

// ExprFactory interfaces defines a set of methods necessary for building native expression values.
type ExprFactory interface {
// CopyExpr creates a deep copy of the input Expr value.
CopyExpr(Expr) Expr

// CopyEntryExpr creates a deep copy of the input EntryExpr value.
CopyEntryExpr(EntryExpr) EntryExpr

// NewCall creates an Expr value representing a global function call.
NewCall(id int64, function string, args ...Expr) Expr

// NewComprehension creates an Expr value representing a comprehension over a value range.
NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr

// NewMemberCall creates an Expr value representing a member function call.
NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr

// NewIdent creates an Expr value representing an identifier.
NewIdent(id int64, name string) Expr

// NewAccuIdent creates an Expr value representing an accumulator identifier within a
//comprehension.
NewAccuIdent(id int64) Expr

// NewLiteral creates an Expr value representing a literal value, such as a string or integer.
NewLiteral(id int64, value ref.Val) Expr

// NewList creates an Expr value representing a list literal expression with optional indices.
//
// Optional indicies will typically be empty unless the CEL optional types are enabled.
NewList(id int64, elems []Expr, optIndices []int32) Expr

// NewMap creates an Expr value representing a map literal expression
NewMap(id int64, entries []EntryExpr) Expr

// NewMapEntry creates a MapEntry with a given key, value, and a flag indicating whether
// the key is optionally set.
NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr

// NewPresenceTest creates an Expr representing a field presence test on an operand expression.
NewPresenceTest(id int64, operand Expr, field string) Expr

// NewSelect creates an Expr representing a field selection on an operand expression.
NewSelect(id int64, operand Expr, field string) Expr

// NewStruct creates an Expr value representing a struct literal with a given type name and a
// set of field initializers.
NewStruct(id int64, typeName string, fields []EntryExpr) Expr

// NewStructField creates a StructField with a given field name, value, and a flag indicating
// whether the field is optionally set.
NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr

// NewUnspecifiedExpr creates an empty expression node.
NewUnspecifiedExpr(id int64) Expr

isExprFactory()
}

type baseExprFactory struct{}

// NewExprFactory creates an ExprFactory instance.
func NewExprFactory() ExprFactory {
return &baseExprFactory{}
}

func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr {
if len(args) == 0 {
args = []Expr{}
}
return fac.newExpr(
id,
&baseCallExpr{
function: function,
target: nilExpr,
args: args,
isMember: false,
})
}

func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr, args ...Expr) Expr {
if len(args) == 0 {
args = []Expr{}
}
return fac.newExpr(
id,
&baseCallExpr{
function: function,
target: target,
args: args,
isMember: true,
})
}

func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
return fac.newExpr(
id,
&baseComprehensionExpr{
iterRange: iterRange,
iterVar: iterVar,
accuVar: accuVar,
accuInit: accuInit,
loopCond: loopCond,
loopStep: loopStep,
result: result,
})
}

func (fac *baseExprFactory) NewIdent(id int64, name string) Expr {
return fac.newExpr(id, baseIdentExpr(name))
}

func (fac *baseExprFactory) NewAccuIdent(id int64) Expr {
return fac.NewIdent(id, "__result__")
}

func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr {
return fac.newExpr(id, &baseLiteral{Val: value})
}

func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr {
optIndexMap := make(map[int32]struct{}, len(optIndices))
for _, idx := range optIndices {
optIndexMap[idx] = struct{}{}
}
return fac.newExpr(id,
&baseListExpr{
elements: elems,
optIndices: optIndices,
optIndexMap: optIndexMap,
})
}

func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr {
return fac.newExpr(id, &baseMapExpr{entries: entries})
}

func (fac *baseExprFactory) NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr {
return fac.newEntryExpr(
id,
&baseMapEntry{
key: key,
value: value,
isOptional: isOptional,
})
}

func (fac *baseExprFactory) NewPresenceTest(id int64, operand Expr, field string) Expr {
return fac.newExpr(
id,
&baseSelectExpr{
operand: operand,
field: field,
testOnly: true,
})
}

func (fac *baseExprFactory) NewSelect(id int64, operand Expr, field string) Expr {
return fac.newExpr(
id,
&baseSelectExpr{
operand: operand,
field: field,
})
}

func (fac *baseExprFactory) NewStruct(id int64, typeName string, fields []EntryExpr) Expr {
return fac.newExpr(
id,
&baseStructExpr{
typeName: typeName,
fields: fields,
})
}

func (fac *baseExprFactory) NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr {
return fac.newEntryExpr(
id,
&baseStructField{
field: field,
value: value,
isOptional: isOptional,
})
}

func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr {
return fac.newExpr(id, nil)
}

func (fac *baseExprFactory) CopyExpr(e Expr) Expr {
// unwrap navigable expressions to avoid unnecessary allocations during copying.
if nav, ok := e.(*navigableExprImpl); ok {
e = nav.Expr
}
switch e.Kind() {
case CallKind:
c := e.AsCall()
argsCopy := make([]Expr, len(c.Args()))
for i, arg := range c.Args() {
argsCopy[i] = fac.CopyExpr(arg)
}
if !c.IsMemberFunction() {
return fac.NewCall(e.ID(), c.FunctionName(), argsCopy...)
}
return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...)
case ComprehensionKind:
compre := e.AsComprehension()
return fac.NewComprehension(e.ID(),
fac.CopyExpr(compre.IterRange()),
compre.IterVar(),
compre.AccuVar(),
fac.CopyExpr(compre.AccuInit()),
fac.CopyExpr(compre.LoopCondition()),
fac.CopyExpr(compre.LoopStep()),
fac.CopyExpr(compre.Result()))
case IdentKind:
return fac.NewIdent(e.ID(), e.AsIdent())
case ListKind:
l := e.AsList()
elemsCopy := make([]Expr, l.Size())
for i, elem := range l.Elements() {
elemsCopy[i] = fac.CopyExpr(elem)
}
return fac.NewList(e.ID(), elemsCopy, l.OptionalIndices())
case LiteralKind:
return fac.NewLiteral(e.ID(), e.AsLiteral())
case MapKind:
m := e.AsMap()
entriesCopy := make([]EntryExpr, m.Size())
for i, entry := range m.Entries() {
entriesCopy[i] = fac.CopyEntryExpr(entry)
}
return fac.NewMap(e.ID(), entriesCopy)
case SelectKind:
s := e.AsSelect()
if s.IsTestOnly() {
return fac.NewPresenceTest(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName())
}
return fac.NewSelect(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName())
case StructKind:
s := e.AsStruct()
fieldsCopy := make([]EntryExpr, len(s.Fields()))
for i, field := range s.Fields() {
fieldsCopy[i] = fac.CopyEntryExpr(field)
}
return fac.NewStruct(e.ID(), s.TypeName(), fieldsCopy)
default:
return fac.NewUnspecifiedExpr(e.ID())
}
}

func (fac *baseExprFactory) CopyEntryExpr(e EntryExpr) EntryExpr {
switch e.Kind() {
case MapEntryKind:
entry := e.AsMapEntry()
return fac.NewMapEntry(e.ID(),
fac.CopyExpr(entry.Key()), fac.CopyExpr(entry.Value()), entry.IsOptional())
case StructFieldKind:
field := e.AsStructField()
return fac.NewStructField(e.ID(),
field.Name(), fac.CopyExpr(field.Value()), field.IsOptional())
default:
return fac.newEntryExpr(e.ID(), nil)
}
}

func (*baseExprFactory) isExprFactory() {}

func (fac *baseExprFactory) newExpr(id int64, e exprKindCase) Expr {
return &expr{
id: id,
exprKindCase: e,
}
}

func (fac *baseExprFactory) newEntryExpr(id int64, e entryExprKindCase) EntryExpr {
return &entryExpr{
id: id,
entryExprKindCase: e,
}
}

var (
defaultFactory = &baseExprFactory{}
)
652 changes: 652 additions & 0 deletions common/ast/navigable.go

Large diffs are not rendered by default.

601 changes: 601 additions & 0 deletions common/ast/navigable_test.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions common/containers/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/containers",
deps = [
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
],
)

@@ -26,6 +26,6 @@ go_test(
":go_default_library",
],
deps = [
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
],
)
22 changes: 11 additions & 11 deletions common/containers/container.go
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@ import (
"fmt"
"strings"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
)

var (
@@ -297,19 +297,19 @@ func Name(name string) ContainerOption {

// ToQualifiedName converts an expression AST into a qualified name if possible, with a boolean
// 'found' value that indicates if the conversion is successful.
func ToQualifiedName(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
id := e.GetIdentExpr()
return id.GetName(), true
case *exprpb.Expr_SelectExpr:
sel := e.GetSelectExpr()
func ToQualifiedName(e ast.Expr) (string, bool) {
switch e.Kind() {
case ast.IdentKind:
id := e.AsIdent()
return id, true
case ast.SelectKind:
sel := e.AsSelect()
// Test only expressions are not valid as qualified names.
if sel.GetTestOnly() {
if sel.IsTestOnly() {
return "", false
}
if qual, found := ToQualifiedName(sel.GetOperand()); found {
return qual + "." + sel.GetField(), true
if qual, found := ToQualifiedName(sel.Operand()); found {
return qual + "." + sel.FieldName(), true
}
}
return "", false
36 changes: 8 additions & 28 deletions common/containers/container_test.go
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ import (
"reflect"
"testing"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
)

func TestContainers_ResolveCandidateNames(t *testing.T) {
@@ -204,28 +204,16 @@ func TestContainers_Extend_Name(t *testing.T) {
}

func TestContainers_ToQualifiedName(t *testing.T) {
ident := &exprpb.Expr{
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "var",
},
},
}
fac := ast.NewExprFactory()
ident := fac.NewIdent(1, "var")
idName, found := ToQualifiedName(ident)
if !found {
t.Errorf("got not found from %v expr, wanted found", ident)
}
if idName != "var" {
t.Errorf("got %v, wanted 'var'", idName)
}
sel := &exprpb.Expr{
ExprKind: &exprpb.Expr_SelectExpr{
SelectExpr: &exprpb.Expr_Select{
Operand: ident,
Field: "qualifier",
},
},
}
sel := fac.NewSelect(2, ident, "qualifier")
qualName, found := ToQualifiedName(sel)
if !found {
t.Errorf("got not found from %v expr, wanted found", sel)
@@ -234,22 +222,14 @@ func TestContainers_ToQualifiedName(t *testing.T) {
t.Errorf("got %v, wanted 'var.qualifier'", qualName)
}

sel.GetSelectExpr().TestOnly = true
_, found = ToQualifiedName(sel)
pres := fac.NewPresenceTest(2, ident, "qualifier")
_, found = ToQualifiedName(pres)
if found {
t.Error("got found, wanted not found for test-only expression")
}

unary := &exprpb.Expr{
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: "!_",
Args: []*exprpb.Expr{ident},
},
},
}
sel.GetSelectExpr().TestOnly = false
sel.GetSelectExpr().Operand = unary
unary := fac.NewCall(2, "!_", ident)
sel = fac.NewSelect(3, unary, "qualifier")
_, found = ToQualifiedName(sel)
if found {
t.Errorf("got found, wanted not found for %v", sel)
4 changes: 3 additions & 1 deletion common/debug/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@ go_library(
importpath = "github.com/google/cel-go/common/debug",
deps = [
"//common:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"//common/ast:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
],
)
156 changes: 77 additions & 79 deletions common/debug/debug.go
Original file line number Diff line number Diff line change
@@ -22,7 +22,9 @@ import (
"strconv"
"strings"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)

// Adorner returns debug metadata that will be tacked on to the string
@@ -38,7 +40,7 @@ type Writer interface {

// Buffer pushes an expression into an internal queue of expressions to
// write to a string.
Buffer(e *exprpb.Expr)
Buffer(e ast.Expr)
}

type emptyDebugAdorner struct {
@@ -51,12 +53,12 @@ func (a *emptyDebugAdorner) GetMetadata(e any) string {
}

// ToDebugString gives the unadorned string representation of the Expr.
func ToDebugString(e *exprpb.Expr) string {
func ToDebugString(e ast.Expr) string {
return ToAdornedDebugString(e, emptyAdorner)
}

// ToAdornedDebugString gives the adorned string representation of the Expr.
func ToAdornedDebugString(e *exprpb.Expr, adorner Adorner) string {
func ToAdornedDebugString(e ast.Expr, adorner Adorner) string {
w := newDebugWriter(adorner)
w.Buffer(e)
return w.String()
@@ -78,49 +80,51 @@ func newDebugWriter(a Adorner) *debugWriter {
}
}

func (w *debugWriter) Buffer(e *exprpb.Expr) {
func (w *debugWriter) Buffer(e ast.Expr) {
if e == nil {
return
}
switch e.ExprKind.(type) {
case *exprpb.Expr_ConstExpr:
w.append(formatLiteral(e.GetConstExpr()))
case *exprpb.Expr_IdentExpr:
w.append(e.GetIdentExpr().Name)
case *exprpb.Expr_SelectExpr:
w.appendSelect(e.GetSelectExpr())
case *exprpb.Expr_CallExpr:
w.appendCall(e.GetCallExpr())
case *exprpb.Expr_ListExpr:
w.appendList(e.GetListExpr())
case *exprpb.Expr_StructExpr:
w.appendStruct(e.GetStructExpr())
case *exprpb.Expr_ComprehensionExpr:
w.appendComprehension(e.GetComprehensionExpr())
switch e.Kind() {
case ast.LiteralKind:
w.append(formatLiteral(e.AsLiteral()))
case ast.IdentKind:
w.append(e.AsIdent())
case ast.SelectKind:
w.appendSelect(e.AsSelect())
case ast.CallKind:
w.appendCall(e.AsCall())
case ast.ListKind:
w.appendList(e.AsList())
case ast.MapKind:
w.appendMap(e.AsMap())
case ast.StructKind:
w.appendStruct(e.AsStruct())
case ast.ComprehensionKind:
w.appendComprehension(e.AsComprehension())
}
w.adorn(e)
}

func (w *debugWriter) appendSelect(sel *exprpb.Expr_Select) {
w.Buffer(sel.GetOperand())
func (w *debugWriter) appendSelect(sel ast.SelectExpr) {
w.Buffer(sel.Operand())
w.append(".")
w.append(sel.GetField())
if sel.TestOnly {
w.append(sel.FieldName())
if sel.IsTestOnly() {
w.append("~test-only~")
}
}

func (w *debugWriter) appendCall(call *exprpb.Expr_Call) {
if call.Target != nil {
w.Buffer(call.GetTarget())
func (w *debugWriter) appendCall(call ast.CallExpr) {
if call.IsMemberFunction() {
w.Buffer(call.Target())
w.append(".")
}
w.append(call.GetFunction())
w.append(call.FunctionName())
w.append("(")
if len(call.GetArgs()) > 0 {
if len(call.Args()) > 0 {
w.addIndent()
w.appendLine()
for i, arg := range call.GetArgs() {
for i, arg := range call.Args() {
if i > 0 {
w.append(",")
w.appendLine()
@@ -133,12 +137,12 @@ func (w *debugWriter) appendCall(call *exprpb.Expr_Call) {
w.append(")")
}

func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) {
func (w *debugWriter) appendList(list ast.ListExpr) {
w.append("[")
if len(list.GetElements()) > 0 {
if len(list.Elements()) > 0 {
w.appendLine()
w.addIndent()
for i, elem := range list.GetElements() {
for i, elem := range list.Elements() {
if i > 0 {
w.append(",")
w.appendLine()
@@ -151,119 +155,113 @@ func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) {
w.append("]")
}

func (w *debugWriter) appendStruct(obj *exprpb.Expr_CreateStruct) {
if obj.MessageName != "" {
w.appendObject(obj)
} else {
w.appendMap(obj)
}
}

func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append(obj.GetMessageName())
func (w *debugWriter) appendStruct(obj ast.StructExpr) {
w.append(obj.TypeName())
w.append("{")
if len(obj.GetEntries()) > 0 {
if len(obj.Fields()) > 0 {
w.appendLine()
w.addIndent()
for i, entry := range obj.GetEntries() {
for i, f := range obj.Fields() {
field := f.AsStructField()
if i > 0 {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
if field.IsOptional() {
w.append("?")
}
w.append(entry.GetFieldKey())
w.append(field.Name())
w.append(":")
w.Buffer(entry.GetValue())
w.adorn(entry)
w.Buffer(field.Value())
w.adorn(f)
}
w.removeIndent()
w.appendLine()
}
w.append("}")
}

func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
func (w *debugWriter) appendMap(m ast.MapExpr) {
w.append("{")
if len(obj.GetEntries()) > 0 {
if m.Size() > 0 {
w.appendLine()
w.addIndent()
for i, entry := range obj.GetEntries() {
for i, e := range m.Entries() {
entry := e.AsMapEntry()
if i > 0 {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
if entry.IsOptional() {
w.append("?")
}
w.Buffer(entry.GetMapKey())
w.Buffer(entry.Key())
w.append(":")
w.Buffer(entry.GetValue())
w.adorn(entry)
w.Buffer(entry.Value())
w.adorn(e)
}
w.removeIndent()
w.appendLine()
}
w.append("}")
}

func (w *debugWriter) appendComprehension(comprehension *exprpb.Expr_Comprehension) {
func (w *debugWriter) appendComprehension(comprehension ast.ComprehensionExpr) {
w.append("__comprehension__(")
w.addIndent()
w.appendLine()
w.append("// Variable")
w.appendLine()
w.append(comprehension.GetIterVar())
w.append(comprehension.IterVar())
w.append(",")
w.appendLine()
w.append("// Target")
w.appendLine()
w.Buffer(comprehension.GetIterRange())
w.Buffer(comprehension.IterRange())
w.append(",")
w.appendLine()
w.append("// Accumulator")
w.appendLine()
w.append(comprehension.GetAccuVar())
w.append(comprehension.AccuVar())
w.append(",")
w.appendLine()
w.append("// Init")
w.appendLine()
w.Buffer(comprehension.GetAccuInit())
w.Buffer(comprehension.AccuInit())
w.append(",")
w.appendLine()
w.append("// LoopCondition")
w.appendLine()
w.Buffer(comprehension.GetLoopCondition())
w.Buffer(comprehension.LoopCondition())
w.append(",")
w.appendLine()
w.append("// LoopStep")
w.appendLine()
w.Buffer(comprehension.GetLoopStep())
w.Buffer(comprehension.LoopStep())
w.append(",")
w.appendLine()
w.append("// Result")
w.appendLine()
w.Buffer(comprehension.GetResult())
w.Buffer(comprehension.Result())
w.append(")")
w.removeIndent()
}

func formatLiteral(c *exprpb.Constant) string {
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return fmt.Sprintf("%t", c.GetBoolValue())
case *exprpb.Constant_BytesValue:
return fmt.Sprintf("b\"%s\"", string(c.GetBytesValue()))
case *exprpb.Constant_DoubleValue:
return fmt.Sprintf("%v", c.GetDoubleValue())
case *exprpb.Constant_Int64Value:
return fmt.Sprintf("%d", c.GetInt64Value())
case *exprpb.Constant_StringValue:
return strconv.Quote(c.GetStringValue())
case *exprpb.Constant_Uint64Value:
return fmt.Sprintf("%du", c.GetUint64Value())
case *exprpb.Constant_NullValue:
func formatLiteral(c ref.Val) string {
switch v := c.(type) {
case types.Bool:
return fmt.Sprintf("%t", v)
case types.Bytes:
return fmt.Sprintf("b\"%s\"", string(v))
case types.Double:
return fmt.Sprintf("%v", float64(v))
case types.Int:
return fmt.Sprintf("%d", int64(v))
case types.String:
return strconv.Quote(string(v))
case types.Uint:
return fmt.Sprintf("%du", uint64(v))
case types.Null:
return "null"
default:
panic("Unknown constant type")
3 changes: 2 additions & 1 deletion common/decls/decls.go
Original file line number Diff line number Diff line change
@@ -243,7 +243,8 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
// performs dynamic dispatch to the proper overload based on the argument types.
bindings := append([]*functions.Overload{}, overloads...)
funcDispatch := func(args ...ref.Val) ref.Val {
for _, o := range f.overloads {
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
// During dynamic dispatch over multiple functions, signature agreement checks
// are preserved in order to assist with the function resolution step.
switch len(args) {
2 changes: 1 addition & 1 deletion common/errors.go
Original file line number Diff line number Diff line change
@@ -64,7 +64,7 @@ func (e *Errors) GetErrors() []*Error {
// Append creates a new Errors object with the current and input errors.
func (e *Errors) Append(errs []*Error) *Errors {
return &Errors{
errors: append(e.errors, errs...),
errors: append(e.errors[:], errs...),
source: e.source,
numErrors: e.numErrors + len(errs),
maxErrorsToReport: e.maxErrorsToReport,
25 changes: 23 additions & 2 deletions common/types/provider.go
Original file line number Diff line number Diff line change
@@ -54,6 +54,10 @@ type Provider interface {
// Returns false if not found.
FindStructType(structType string) (*Type, bool)

// FindStructFieldNames returns thet field names associated with the type, if the type
// is found.
FindStructFieldNames(structType string) ([]string, bool)

// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
FindStructFieldType(structType, fieldName string) (*FieldType, bool)
@@ -154,7 +158,7 @@ func (p *Registry) EnumValue(enumName string) ref.Val {
return Int(enumVal.Value())
}

// FieldFieldType returns the field type for a checked type value. Returns false if
// FindFieldType returns the field type for a checked type value. Returns false if
// the field could not be found.
//
// Deprecated: use FindStructFieldType
@@ -173,7 +177,24 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType,
GetFrom: field.GetFrom}, true
}

// FieldStructFieldType returns the field type for a checked type value. Returns
// FindStructFieldNames returns the set of field names for the given struct type,
// if the type exists in the registry.
func (p *Registry) FindStructFieldNames(structType string) ([]string, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return []string{}, false
}
fieldMap := msgType.FieldMap()
fields := make([]string, len(fieldMap))
idx := 0
for f := range fieldMap {
fields[idx] = f
idx++
}
return fields, true
}

// FindStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) {
msgType, found := p.pbdb.DescribeType(structType)
34 changes: 34 additions & 0 deletions common/types/provider_test.go
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"reflect"
"sort"
"strings"
"testing"
"time"
@@ -132,6 +133,39 @@ func TestRegistryFindStructType(t *testing.T) {
}
}

func TestRegistryFindStructFieldNames(t *testing.T) {
reg := newTestRegistry(t, &exprpb.Decl{}, &exprpb.Reference{})
tests := []struct {
typeName string
fields []string
}{
{
typeName: "google.api.expr.v1alpha1.Reference",
fields: []string{"name", "overload_id", "value"},
},
{
typeName: "google.api.expr.v1alpha1.Decl",
fields: []string{"name", "ident", "function"},
},
{
typeName: "invalid.TypeName",
fields: []string{},
},
}

for _, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) {
fields, _ := reg.FindStructFieldNames(tc.typeName)
sort.Strings(fields)
sort.Strings(tc.fields)
if !reflect.DeepEqual(fields, tc.fields) {
t.Errorf("got %v, wanted %v", fields, tc.fields)
}
})
}
}

func TestRegistryFindStructFieldType(t *testing.T) {
reg := newTestRegistry(t)
err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile())
4 changes: 2 additions & 2 deletions ext/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ go_library(
name = "go_default_library",
srcs = [
"encoders.go",
"formatting.go",
"guards.go",
"lists.go",
"math.go",
@@ -21,13 +22,13 @@ go_library(
deps = [
"//cel:go_default_library",
"//checker/decls:go_default_library",
"//common/ast:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/known/structpb",
@@ -60,7 +61,6 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
14 changes: 14 additions & 0 deletions ext/README.md
Original file line number Diff line number Diff line change
@@ -414,3 +414,17 @@ Examples:

'TacoCat'.upperAscii() // returns 'TACOCAT'
'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'

### Reverse

Returns a new string whose characters are the same as the target string, only formatted in
reverse order.
This function relies on converting strings to rune arrays in order to reverse.
It can be located in Version 3 of strings.

<string>.reverse() -> <string>

Examples:

'gums'.reverse() // returns 'smug'
'John Smith'.reverse() // returns 'htimS nhoJ'
24 changes: 12 additions & 12 deletions ext/bindings.go
Original file line number Diff line number Diff line change
@@ -16,8 +16,8 @@ package ext

import (
"github.com/google/cel-go/cel"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
)

// Bindings returns a cel.EnvOption to configure support for local variable
@@ -61,7 +61,7 @@ func (celBindings) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// cel.bind(var, <init>, <expr>)
cel.NewReceiverMacro(bindMacro, 3, celBind),
cel.ReceiverMacro(bindMacro, 3, celBind),
),
}
}
@@ -70,27 +70,27 @@ func (celBindings) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}

func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(celNamespace, target) {
return nil, nil
}
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
switch varIdent.Kind() {
case ast.IdentKind:
varName = varIdent.AsIdent()
default:
return nil, meh.NewError(varIdent.GetId(), "cel.bind() variable names must be simple identifiers")
return nil, mef.NewError(varIdent.ID(), "cel.bind() variable names must be simple identifiers")
}
varInit := args[1]
resultExpr := args[2]
return meh.Fold(
return mef.NewComprehension(
mef.NewList(),
unusedIterVar,
meh.NewList(),
varName,
varInit,
meh.LiteralBool(false),
meh.Ident(varName),
mef.NewLiteral(types.False),
mef.NewIdent(varName),
resultExpr,
), nil
}
Loading