Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: google/cel-go
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: v0.17.6
Choose a base ref
...
head repository: google/cel-go
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: v0.17.7
Choose a head ref
  • 1 commit
  • 11 files changed
  • 1 contributor

Commits on Oct 30, 2023

  1. 3
    Copy the full SHA
    cfefae0 View commit details
Showing with 603 additions and 124 deletions.
  1. +35 −19 cel/cel_test.go
  2. +9 −1 cel/env.go
  3. +19 −0 cel/options.go
  4. +28 −7 cel/program.go
  5. +44 −12 checker/cost.go
  6. +55 −23 checker/cost_test.go
  7. +1 −0 ext/BUILD.bazel
  8. +60 −1 ext/sets.go
  9. +288 −50 ext/sets_test.go
  10. +36 −11 interpreter/runtimecost.go
  11. +28 −0 interpreter/runtimecost_test.go
54 changes: 35 additions & 19 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
@@ -91,7 +91,12 @@ func Test_ExampleWithBuiltins(t *testing.T) {
}

func TestEval(t *testing.T) {
env, err := NewEnv(Variable("input", ListType(IntType)))
env, err := NewEnv(
Variable("input", ListType(IntType)),
CostEstimatorOptions(
checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
),
)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
@@ -114,6 +119,9 @@ func TestEval(t *testing.T) {
ctx := context.Background()
prgOpts := []ProgramOption{
CostTracking(testRuntimeCostEstimator{}),
CostTrackerOptions(
interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
),
EvalOptions(OptOptimize, OptTrackCost),
InterruptCheckFrequency(100),
}
@@ -1338,7 +1346,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
name string
expr string
decls []EnvOption
hints map[string]int64
hints map[string]uint64
want checker.CostEstimate
in any
}{
@@ -1362,7 +1370,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
Variable("str1", StringType),
Variable("str2", StringType),
},
hints: map[string]int64{"str1": 10, "str2": 10},
hints: map[string]uint64{"str1": 10, "str2": 10},
want: checker.CostEstimate{Min: 2, Max: 6},
in: map[string]any{"str1": "val1111111", "str2": "val2222222"},
},
@@ -1373,9 +1381,15 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if tc.hints == nil {
tc.hints = map[string]int64{}
tc.hints = map[string]uint64{}
}
env := testEnv(t, tc.decls...)
envOpts := []EnvOption{
CostEstimatorOptions(
checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
),
}
envOpts = append(envOpts, tc.decls...)
env := testEnv(t, envOpts...)
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err())
@@ -1394,7 +1408,12 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err())
}
// Evaluate expression.
program, err := env.Program(checkedAst, CostTracking(testRuntimeCostEstimator{}))
program, err := env.Program(checkedAst,
CostTracking(testRuntimeCostEstimator{}),
CostTrackerOptions(
interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
),
)
if err != nil {
t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err)
}
@@ -2631,27 +2650,26 @@ func BenchmarkDynamicDispatch(b *testing.B) {

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]int64
hints map[string]uint64
}

func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok {
return &checker.SizeEstimate{Min: 0, Max: uint64(l)}
return &checker.SizeEstimate{Min: 0, Max: l}
}
return nil
}

func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
switch overloadID {
case overloads.TimestampToYear:
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
}
return nil
}

type testRuntimeCostEstimator struct {
func estimateTimestampToYear(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
}

type testRuntimeCostEstimator struct{}

var timeToYearCost uint64 = 7

func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 {
@@ -2667,13 +2685,11 @@ func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []r
argsSize[i] = 1
}
}
return nil
}

switch overloadID {
case overloads.TimestampToYear:
return &timeToYearCost
default:
return nil
}
func trackTimestampToYear(args []ref.Val, result ref.Val) *uint64 {
return &timeToYearCost
}

func testEnv(t testing.TB, opts ...EnvOption) *Env {
10 changes: 9 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
@@ -119,6 +119,7 @@ type Env struct {
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator
costOptions []checker.CostOption

// Internal parser representation
prsr *parser.Parser
@@ -181,6 +182,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
}).configure(opts)
}

@@ -356,6 +358,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)

ext := &Env{
Container: e.Container,
@@ -371,6 +375,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
costOptions: costOptsCopy,
}
return ext.configure(opts)
}
@@ -557,7 +562,10 @@ func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...ch
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
return checker.Cost(checked, estimator, opts...)
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(checked, estimator, extendedOpts...)
}

// configure applies a series of EnvOptions to the current environment.
19 changes: 19 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"

"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
@@ -469,6 +470,24 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}

// CostEstimatorOptions configure type-check time options for estimating expression cost.
func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption {
return func(e *Env) (*Env, error) {
e.costOptions = append(e.costOptions, costOpts...)
return e, nil
}
}

// CostTrackerOptions configures a set of options for cost-tracking.
//
// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled.
func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption {
return func(p *prog) (*prog, error) {
p.costOptions = append(p.costOptions, costOpts...)
return p, nil
}
}

// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
35 changes: 28 additions & 7 deletions cel/program.go
Original file line number Diff line number Diff line change
@@ -106,7 +106,7 @@ func (ed *EvalDetails) State() interpreter.EvalState {
// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled.
// Otherwise, returns nil if the cost was not enabled.
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
if ed == nil || ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
@@ -130,10 +130,14 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}

func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)

return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
@@ -155,9 +159,10 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}

// Configure the program via the ProgramOption values.
@@ -242,6 +247,12 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
@@ -371,7 +382,11 @@ type progGen struct {
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
@@ -384,7 +399,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
@@ -412,7 +430,10 @@ func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalD
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
56 changes: 44 additions & 12 deletions checker/cost.go
Original file line number Diff line number Diff line change
@@ -230,7 +230,7 @@ func addUint64NoOverflow(x, y uint64) uint64 {
// multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func multiplyUint64NoOverflow(x, y uint64) uint64 {
if x > 0 && y > 0 && x > math.MaxUint64/y {
if y != 0 && x > math.MaxUint64/y {
return math.MaxUint64
}
return x * y
@@ -242,7 +242,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 {
if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y {
return math.MaxUint64
}
return uint64(math.Ceil(xFloat * y))
ceil := math.Ceil(xFloat * y)
if ceil >= doubleTwoTo64 {
return math.MaxUint64
}
return uint64(ceil)
}

var (
@@ -260,9 +264,10 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
@@ -291,6 +296,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
type CostOption func(*coster) error

// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
//
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
@@ -303,15 +309,30 @@ func PresenceTestHasCost(hasCost bool) CostOption {
}
}

// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate

// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID.
//
// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to
// the Cost() call.
func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption {
return func(c *coster) error {
c.overloadEstimators[overloadID] = functionCoster
return nil
}
}

// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
checkedAST: checker,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
@@ -532,7 +553,14 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}

if len(c.overloadEstimators) != 0 {
if estimator, found := c.overloadEstimators[overloadID]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
@@ -682,3 +710,7 @@ func isScalar(t *types.Type) bool {
}
return false
}

var (
doubleTwoTo64 = math.Ldexp(1.0, 64)
)
78 changes: 55 additions & 23 deletions checker/cost_test.go
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
package checker

import (
"math"
"strings"
"testing"

@@ -43,7 +44,7 @@ func TestCost(t *testing.T) {
name string
expr string
vars []*decls.VariableDecl
hints map[string]int64
hints map[string]uint64
options []CostOption
wanted CostEstimate
}{
@@ -128,14 +129,14 @@ func TestCost(t *testing.T) {
{
name: "all comprehension",
vars: []*decls.VariableDecl{decls.NewVariable("input", allList)},
hints: map[string]int64{"input": 100},
hints: map[string]uint64{"input": 100},
expr: `input.all(x, true)`,
wanted: CostEstimate{Min: 2, Max: 302},
},
{
name: "nested all comprehension",
vars: []*decls.VariableDecl{decls.NewVariable("input", nestedList)},
hints: map[string]int64{"input": 50, "input.@items": 10},
hints: map[string]uint64{"input": 50, "input.@items": 10},
expr: `input.all(x, x.all(y, true))`,
wanted: CostEstimate{Min: 2, Max: 1752},
},
@@ -147,7 +148,7 @@ func TestCost(t *testing.T) {
{
name: "variable cost function",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)},
hints: map[string]int64{"input": 500},
hints: map[string]uint64{"input": 500},
expr: `input.matches('[0-9]')`,
wanted: CostEstimate{Min: 3, Max: 103},
},
@@ -256,29 +257,29 @@ func TestCost(t *testing.T) {
{
name: "bytes to string conversion",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)},
hints: map[string]int64{"input": 500},
hints: map[string]uint64{"input": 500},
expr: `string(input)`,
wanted: CostEstimate{Min: 1, Max: 51},
},
{
name: "bytes to string conversion equality",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)},
hints: map[string]int64{"input": 500},
hints: map[string]uint64{"input": 500},
// equality check ensures that the resultSize calculation is included in cost
expr: `string(input) == string(input)`,
wanted: CostEstimate{Min: 3, Max: 152},
},
{
name: "string to bytes conversion",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)},
hints: map[string]int64{"input": 500},
hints: map[string]uint64{"input": 500},
expr: `bytes(input)`,
wanted: CostEstimate{Min: 1, Max: 51},
},
{
name: "string to bytes conversion equality",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)},
hints: map[string]int64{"input": 500},
hints: map[string]uint64{"input": 500},
// equality check ensures that the resultSize calculation is included in cost
expr: `bytes(input) == bytes(input)`,
wanted: CostEstimate{Min: 3, Max: 302},
@@ -295,7 +296,7 @@ func TestCost(t *testing.T) {
decls.NewVariable("input", types.StringType),
decls.NewVariable("arg1", types.StringType),
},
hints: map[string]int64{"input": 500, "arg1": 500},
hints: map[string]uint64{"input": 500, "arg1": 500},
wanted: CostEstimate{Min: 2, Max: 2502},
},
{
@@ -304,7 +305,7 @@ func TestCost(t *testing.T) {
vars: []*decls.VariableDecl{
decls.NewVariable("input", types.StringType),
},
hints: map[string]int64{"input": 500},
hints: map[string]uint64{"input": 500},
wanted: CostEstimate{Min: 3, Max: 103},
},
{
@@ -314,7 +315,7 @@ func TestCost(t *testing.T) {
decls.NewVariable("input", types.StringType),
decls.NewVariable("arg1", types.StringType),
},
hints: map[string]int64{"arg1": 500},
hints: map[string]uint64{"arg1": 500},
wanted: CostEstimate{Min: 2, Max: 52},
},
{
@@ -324,7 +325,7 @@ func TestCost(t *testing.T) {
decls.NewVariable("input", types.StringType),
decls.NewVariable("arg1", types.StringType),
},
hints: map[string]int64{"arg1": 500},
hints: map[string]uint64{"arg1": 500},
wanted: CostEstimate{Min: 2, Max: 52},
},
{
@@ -351,7 +352,7 @@ func TestCost(t *testing.T) {
decls.NewVariable("input1", allList),
decls.NewVariable("input2", allList),
},
hints: map[string]int64{"input1": 1, "input2": 1},
hints: map[string]uint64{"input1": 1, "input2": 1},
wanted: CostEstimate{Min: 4, Max: 7},
},
{
@@ -360,7 +361,7 @@ func TestCost(t *testing.T) {
vars: []*decls.VariableDecl{
decls.NewVariable("input", allMap),
},
hints: map[string]int64{"input": 10},
hints: map[string]uint64{"input": 10},
wanted: CostEstimate{Min: 2, Max: 82},
},
{
@@ -369,7 +370,7 @@ func TestCost(t *testing.T) {
vars: []*decls.VariableDecl{
decls.NewVariable("input", nestedMap),
},
hints: map[string]int64{"input": 5, "input.@values": 10},
hints: map[string]uint64{"input": 5, "input.@values": 10},
wanted: CostEstimate{Min: 2, Max: 187},
},
{
@@ -378,7 +379,7 @@ func TestCost(t *testing.T) {
vars: []*decls.VariableDecl{
decls.NewVariable("input", nestedMap),
},
hints: map[string]int64{"input": 5, "input.@keys": 10},
hints: map[string]uint64{"input": 5, "input.@keys": 10},
wanted: CostEstimate{Min: 2, Max: 32},
},
{
@@ -387,7 +388,7 @@ func TestCost(t *testing.T) {
vars: []*decls.VariableDecl{
decls.NewVariable("input", nestedMap),
},
hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5},
hints: map[string]uint64{"input": 2, "input.@values": 2, "input.@keys": 5},
wanted: CostEstimate{Min: 2, Max: 34},
},
{
@@ -396,7 +397,7 @@ func TestCost(t *testing.T) {
vars: []*decls.VariableDecl{
decls.NewVariable("input", nestedMap),
},
hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5},
hints: map[string]uint64{"input": 2, "input.@values": 2, "input.@keys": 5},
wanted: CostEstimate{Min: 2, Max: 34},
},
{
@@ -406,7 +407,7 @@ func TestCost(t *testing.T) {
decls.NewVariable("list1", types.NewListType(types.IntType)),
decls.NewVariable("list2", types.NewListType(types.IntType)),
},
hints: map[string]int64{"list1": 10, "list2": 10},
hints: map[string]uint64{"list1": 10, "list2": 10},
wanted: CostEstimate{Min: 4, Max: 64},
},
{
@@ -416,9 +417,30 @@ func TestCost(t *testing.T) {
decls.NewVariable("str1", types.StringType),
decls.NewVariable("str2", types.StringType),
},
hints: map[string]int64{"str1": 10, "str2": 10},
hints: map[string]uint64{"str1": 10, "str2": 10},
wanted: CostEstimate{Min: 2, Max: 6},
},
{
name: "str concat custom cost estimate",
expr: `"abcdefg".contains(str1 + str2)`,
vars: []*decls.VariableDecl{
decls.NewVariable("str1", types.StringType),
decls.NewVariable("str2", types.StringType),
},
hints: map[string]uint64{"str1": 10, "str2": 10},
options: []CostOption{
OverloadCostEstimate(overloads.ContainsString,
func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate {
if target != nil && len(args) == 1 {
strSize := estimateSize(estimator, *target).MultiplyByCostFactor(0.2)
subSize := estimateSize(estimator, args[0]).MultiplyByCostFactor(0.2)
return &CallEstimate{CostEstimate: strSize.Multiply(subSize)}
}
return nil
}),
},
wanted: CostEstimate{Min: 2, Max: 12},
},
{
name: "list size comparison",
expr: `list1.size() == list2.size()`,
@@ -485,7 +507,7 @@ func TestCost(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if tc.hints == nil {
tc.hints = map[string]int64{}
tc.hints = map[string]uint64{}
}
p, err := parser.NewParser(parser.Macros(parser.AllMacros...))
if err != nil {
@@ -530,12 +552,12 @@ func TestCost(t *testing.T) {
}

type testCostEstimator struct {
hints map[string]int64
hints map[string]uint64
}

func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate {
if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok {
return &SizeEstimate{Min: 0, Max: uint64(l)}
return &SizeEstimate{Min: 0, Max: l}
}
return nil
}
@@ -547,3 +569,13 @@ func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target
}
return nil
}

func estimateSize(estimator CostEstimator, node AstNode) SizeEstimate {
if l := node.ComputedSize(); l != nil {
return *l
}
if l := estimator.EstimateSize(node); l != nil {
return *l
}
return SizeEstimate{Min: 0, Max: math.MaxUint64}
}
1 change: 1 addition & 0 deletions ext/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
61 changes: 60 additions & 1 deletion ext/sets.go
Original file line number Diff line number Diff line change
@@ -15,10 +15,14 @@
package ext

import (
"math"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)

// Sets returns a cel.EnvOption to configure namespaced set relationship
@@ -95,12 +99,24 @@ func (setsLib) CompileOptions() []cel.EnvOption {
cel.Function("sets.intersects",
cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsIntersects))),
cel.CostEstimatorOptions(
checker.OverloadCostEstimate("list_sets_contains_list", estimateSetsCost(1)),
checker.OverloadCostEstimate("list_sets_intersects_list", estimateSetsCost(1)),
// equivalence requires potentially two m*n comparisons to ensure each list is contained by the other
checker.OverloadCostEstimate("list_sets_equivalent_list", estimateSetsCost(2)),
),
}
}

// ProgramOptions implements the Library interface method.
func (setsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
return []cel.ProgramOption{
cel.CostTrackerOptions(
interpreter.OverloadCostTracker("list_sets_contains_list", trackSetsCost(1)),
interpreter.OverloadCostTracker("list_sets_intersects_list", trackSetsCost(1)),
interpreter.OverloadCostTracker("list_sets_equivalent_list", trackSetsCost(2)),
),
}
}

func setsIntersects(listA, listB ref.Val) ref.Val {
@@ -136,3 +152,46 @@ func setsEquivalent(listA, listB ref.Val) ref.Val {
}
return setsContains(listB, listA)
}

func estimateSetsCost(costFactor float64) checker.FunctionEstimator {
return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) == 2 {
arg0Size := estimateSize(estimator, args[0])
arg1Size := estimateSize(estimator, args[1])
costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate)
return &checker.CallEstimate{CostEstimate: costEstimate}
}
return nil
}
}

func estimateSize(estimator checker.CostEstimator, node checker.AstNode) checker.SizeEstimate {
if l := node.ComputedSize(); l != nil {
return *l
}
if l := estimator.EstimateSize(node); l != nil {
return *l
}
return checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
}

func trackSetsCost(costFactor float64) interpreter.FunctionTracker {
return func(args []ref.Val, _ ref.Val) *uint64 {
lhsSize := actualSize(args[0])
rhsSize := actualSize(args[1])
cost := callCost + uint64(float64(lhsSize*rhsSize)*costFactor)
return &cost
}
}

func actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
return 1
}

var (
callCostEstimate = checker.CostEstimate{Min: 1, Max: 1}
callCost = uint64(1)
)
338 changes: 288 additions & 50 deletions ext/sets_test.go
Original file line number Diff line number Diff line change
@@ -15,67 +15,267 @@
package ext

import (
"fmt"
"math"
"reflect"
"strings"
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
)

func TestSets(t *testing.T) {
setsTests := []struct {
expr string
expr string
vars []cel.EnvOption
in map[string]any
hints map[string]uint64
estimatedCost checker.CostEstimate
actualCost uint64
}{
// set containment
{expr: `sets.contains([], [])`},
{expr: `sets.contains([1], [])`},
{expr: `sets.contains([1], [1])`},
{expr: `sets.contains([1], [1, 1])`},
{expr: `sets.contains([1, 1], [1])`},
{expr: `sets.contains([2, 1], [1])`},
{expr: `sets.contains([1, 2, 3, 4], [2, 3])`},
{expr: `sets.contains([1], [1.0, 1])`},
{expr: `sets.contains([1, 2], [2u, 2.0])`},
{expr: `sets.contains([1, 2u], [2, 2.0])`},
{expr: `sets.contains([1, 2.0, 3u], [1.0, 2u, 3])`},
{expr: `sets.contains([[1], [2, 3]], [[2, 3.0]])`},
{expr: `!sets.contains([1], [2])`},
{expr: `!sets.contains([1], [1, 2])`},
{expr: `!sets.contains([1], ["1", 1])`},
{expr: `!sets.contains([1], [1.1, 1u])`},
// set equivalence
{expr: `sets.equivalent([], [])`},
{expr: `sets.equivalent([1], [1])`},
{expr: `sets.equivalent([1], [1, 1])`},
{expr: `sets.equivalent([1, 1], [1])`},
{expr: `sets.equivalent([1], [1u, 1.0])`},
{expr: `sets.equivalent([1], [1u, 1.0])`},
{expr: `sets.equivalent([1, 2, 3], [3u, 2.0, 1])`},
{expr: `sets.equivalent([[1.0], [2, 3]], [[1], [2, 3.0]])`},
{expr: `!sets.equivalent([2, 1], [1])`},
{expr: `!sets.equivalent([1], [1, 2])`},
{expr: `!sets.equivalent([1, 2], [2u, 2, 2.0])`},
{expr: `!sets.equivalent([1, 2], [1u, 2, 2.3])`},
{
expr: `sets.contains(x, [1, 2, 3])`,
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))},
in: map[string]any{"x": []int64{5, 4, 3, 2, 1}},
hints: map[string]uint64{"x": 10},
// min cost is input 'x' length 0, 10 for list creation, 2 for arg costs
// max cost is input 'x' lenght 10, 10 for list creation, 2 for arg costs
estimatedCost: checker.CostEstimate{Min: 12, Max: 42},
// actual cost is 'x' length 5 * list literal length 3, 10 for list creation, 2 for arg cost
actualCost: 27,
},
{
expr: `sets.contains(x, [1, 1, 1, 1, 1])`,
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))},
in: map[string]any{"x": []int64{5, 4, 3, 2, 1}},
// min cost is input 'x' length 0, 10 for list creation, 2 for arg costs
// max cost is effectively infinite due to missing size hint for 'x'
estimatedCost: checker.CostEstimate{Min: 12, Max: math.MaxUint64},
// actual cost is 'x' length 5 * list literal length 5, 10 for list creation, 2 for arg cost
actualCost: 37,
},
{
expr: `sets.contains([], [])`,
estimatedCost: checker.CostEstimate{Min: 21, Max: 21},
actualCost: 21,
},
{
expr: `sets.contains([1], [])`,
estimatedCost: checker.CostEstimate{Min: 21, Max: 21},
actualCost: 21,
},
{
expr: `sets.contains([1], [1])`,
estimatedCost: checker.CostEstimate{Min: 22, Max: 22},
actualCost: 22,
},
{
expr: `sets.contains([1], [1, 1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.contains([1, 1], [1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.contains([2, 1], [1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.contains([1, 2, 3, 4], [2, 3])`,
estimatedCost: checker.CostEstimate{Min: 29, Max: 29},
actualCost: 29,
},
{
expr: `sets.contains([1], [1.0, 1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.contains([1, 2], [2u, 2.0])`,
estimatedCost: checker.CostEstimate{Min: 25, Max: 25},
actualCost: 25,
},
{
expr: `sets.contains([1, 2u], [2, 2.0])`,
estimatedCost: checker.CostEstimate{Min: 25, Max: 25},
actualCost: 25,
},
{
expr: `sets.contains([1, 2.0, 3u], [1.0, 2u, 3])`,
estimatedCost: checker.CostEstimate{Min: 30, Max: 30},
actualCost: 30,
},
{
expr: `sets.contains([[1], [2, 3]], [[2, 3.0]])`,
// 10 for each list creation, top-level list sizes are 2, 1
estimatedCost: checker.CostEstimate{Min: 53, Max: 53},
actualCost: 53,
},
{
expr: `!sets.contains([1], [2])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `!sets.contains([1], [1, 2])`,
estimatedCost: checker.CostEstimate{Min: 24, Max: 24},
actualCost: 24,
},
{
expr: `!sets.contains([1], ["1", 1])`,
estimatedCost: checker.CostEstimate{Min: 24, Max: 24},
actualCost: 24,
},
{
expr: `!sets.contains([1], [1.1, 1u])`,
estimatedCost: checker.CostEstimate{Min: 24, Max: 24},
actualCost: 24,
},

// set equivalence (note the cost factor is higher as it's basically two contains checks)
{
expr: `sets.equivalent([], [])`,
estimatedCost: checker.CostEstimate{Min: 21, Max: 21},
actualCost: 21,
},
{
expr: `sets.equivalent([1], [1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.equivalent([1], [1, 1])`,
estimatedCost: checker.CostEstimate{Min: 25, Max: 25},
actualCost: 25,
},
{
expr: `sets.equivalent([1, 1], [1])`,
estimatedCost: checker.CostEstimate{Min: 25, Max: 25},
actualCost: 25,
},
{
expr: `sets.equivalent([1], [1u, 1.0])`,
estimatedCost: checker.CostEstimate{Min: 25, Max: 25},
actualCost: 25,
},
{
expr: `sets.equivalent([1], [1u, 1.0])`,
estimatedCost: checker.CostEstimate{Min: 25, Max: 25},
actualCost: 25,
},
{
expr: `sets.equivalent([1, 2, 3], [3u, 2.0, 1])`,
estimatedCost: checker.CostEstimate{Min: 39, Max: 39},
actualCost: 39,
},
{
expr: `sets.equivalent([[1.0], [2, 3]], [[1], [2, 3.0]])`,
estimatedCost: checker.CostEstimate{Min: 69, Max: 69},
actualCost: 69,
},
{
expr: `!sets.equivalent([2, 1], [1])`,
estimatedCost: checker.CostEstimate{Min: 26, Max: 26},
actualCost: 26,
},
{
expr: `!sets.equivalent([1], [1, 2])`,
estimatedCost: checker.CostEstimate{Min: 26, Max: 26},
actualCost: 26,
},
{
expr: `!sets.equivalent([1, 2], [2u, 2, 2.0])`,
estimatedCost: checker.CostEstimate{Min: 34, Max: 34},
actualCost: 34,
},
{
expr: `!sets.equivalent([1, 2], [1u, 2, 2.3])`,
estimatedCost: checker.CostEstimate{Min: 34, Max: 34},
actualCost: 34,
},

// set intersection
{expr: `sets.intersects([1], [1])`},
{expr: `sets.intersects([1], [1, 1])`},
{expr: `sets.intersects([1, 1], [1])`},
{expr: `sets.intersects([2, 1], [1])`},
{expr: `sets.intersects([1], [1, 2])`},
{expr: `sets.intersects([1], [1.0, 2])`},
{expr: `sets.intersects([1, 2], [2u, 2, 2.0])`},
{expr: `sets.intersects([1, 2], [1u, 2, 2.3])`},
{expr: `sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]])`},
{expr: `!sets.intersects([], [])`},
{expr: `!sets.intersects([1], [])`},
{expr: `!sets.intersects([1], [2])`},
{expr: `!sets.intersects([1], ["1", 2])`},
{expr: `!sets.intersects([1], [1.1, 2u])`},
{
expr: `sets.intersects([1], [1])`,
estimatedCost: checker.CostEstimate{Min: 22, Max: 22},
actualCost: 22,
},
{
expr: `sets.intersects([1], [1, 1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.intersects([1, 1], [1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.intersects([2, 1], [1])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.intersects([1], [1, 2])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.intersects([1], [1.0, 2])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `sets.intersects([1, 2], [2u, 2, 2.0])`,
estimatedCost: checker.CostEstimate{Min: 27, Max: 27},
actualCost: 27,
},
{
expr: `sets.intersects([1, 2], [1u, 2, 2.3])`,
estimatedCost: checker.CostEstimate{Min: 27, Max: 27},
actualCost: 27,
},
{
expr: `sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]])`,
estimatedCost: checker.CostEstimate{Min: 65, Max: 65},
actualCost: 65,
},
{
expr: `!sets.intersects([], [])`,
estimatedCost: checker.CostEstimate{Min: 22, Max: 22},
actualCost: 22,
},
{
expr: `!sets.intersects([1], [])`,
estimatedCost: checker.CostEstimate{Min: 22, Max: 22},
actualCost: 22,
},
{
expr: `!sets.intersects([1], [2])`,
estimatedCost: checker.CostEstimate{Min: 23, Max: 23},
actualCost: 23,
},
{
expr: `!sets.intersects([1], ["1", 2])`,
estimatedCost: checker.CostEstimate{Min: 24, Max: 24},
actualCost: 24,
},
{
expr: `!sets.intersects([1], [1.1, 2u])`,
estimatedCost: checker.CostEstimate{Min: 24, Max: 24},
actualCost: 24,
},
}

env := testSetsEnv(t)
for i, tst := range setsTests {
for _, tst := range setsTests {
tc := tst
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Run(tc.expr, func(t *testing.T) {
env := testSetsEnv(t, tc.vars...)
var asts []*cel.Ast
pAst, iss := env.Parse(tc.expr)
if iss.Err() != nil {
@@ -86,20 +286,43 @@ func TestSets(t *testing.T) {
if iss.Err() != nil {
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
}

hints := map[string]uint64{}
if len(tc.hints) != 0 {
hints = tc.hints
}
est, err := env.EstimateCost(cAst, testSetsCostEstimator{hints: hints})
if err != nil {
t.Fatalf("env.EstimateCost() failed: %v", err)
}
if !reflect.DeepEqual(est, tc.estimatedCost) {
t.Errorf("env.EstimateCost() got %v, wanted %v", est, tc.estimatedCost)
}
asts = append(asts, cAst)

for _, ast := range asts {
prg, err := env.Program(ast)
prgOpts := []cel.ProgramOption{}
if ast.IsChecked() {
prgOpts = append(prgOpts, cel.CostTracking(nil))
}
prg, err := env.Program(ast, prgOpts...)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(cel.NoVars())
in := tc.in
if in == nil {
in = map[string]any{}
}
out, det, err := prg.Eval(in)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out.Value() != true {
t.Errorf("prg.Eval() got %v, wanted true for expr: %s", out.Value(), tc.expr)
}
if det.ActualCost() != nil && *det.ActualCost() != tc.actualCost {
t.Errorf("prg.Eval() had cost %v, wanted %v", *det.ActualCost(), tc.actualCost)
}
}
})
}
@@ -114,3 +337,18 @@ func testSetsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
}
return env
}

type testSetsCostEstimator struct {
hints map[string]uint64
}

func (tc testSetsCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok {
return &checker.SizeEstimate{Min: 0, Max: l}
}
return nil
}

func (testSetsCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return nil
}
47 changes: 36 additions & 11 deletions interpreter/runtimecost.go
Original file line number Diff line number Diff line change
@@ -133,6 +133,7 @@ func PresenceTestHasCost(hasCost bool) CostTrackerOption {
func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) {
tracker := &CostTracker{
Estimator: estimator,
overloadTrackers: map[string]FunctionTracker{},
presenceTestHasCost: true,
}
for _, opt := range opts {
@@ -144,9 +145,24 @@ func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*
return tracker, nil
}

// OverloadCostTracker binds an overload ID to a runtime FunctionTracker implementation.
//
// OverloadCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or
// optional cost tracking changes.
func OverloadCostTracker(overloadID string, fnTracker FunctionTracker) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.overloadTrackers[overloadID] = fnTracker
return nil
}
}

// FunctionTracker computes the actual cost of evaluating the functions with the given arguments and result.
type FunctionTracker func(args []ref.Val, result ref.Val) *uint64

// CostTracker represents the information needed for tracking runtime cost.
type CostTracker struct {
Estimator ActualCostEstimator
overloadTrackers map[string]FunctionTracker
Limit *uint64
presenceTestHasCost bool

@@ -159,10 +175,19 @@ func (c *CostTracker) ActualCost() uint64 {
return c.cost
}

func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 {
func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 {
var cost uint64
if len(c.overloadTrackers) != 0 {
if tracker, found := c.overloadTrackers[call.OverloadID()]; found {
callCost := tracker(args, result)
if callCost != nil {
cost += *callCost
return cost
}
}
}
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result)
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result)
if callCost != nil {
cost += *callCost
return cost
@@ -173,20 +198,20 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
switch call.OverloadID() {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
cost += c.actualSize(argValues[1])
cost += c.actualSize(args[1])
// O(min(m, n)) functions
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
overloads.Equals, overloads.NotEquals:
// When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.),
// the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost
// of 1.
lhsSize := c.actualSize(argValues[0])
rhsSize := c.actualSize(argValues[1])
lhsSize := c.actualSize(args[0])
rhsSize := c.actualSize(args[1])
minSize := lhsSize
if rhsSize < minSize {
minSize = rhsSize
@@ -195,23 +220,23 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu
// O(m+n) functions
case overloads.AddString, overloads.AddBytes:
// In the worst case scenario, we would need to reallocate a new backing store and copy both operands over.
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor))
// O(nm) functions
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor))
regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor))
cost += strCost * regexCost
case overloads.ContainsString:
strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor))
cost += strCost * substrCost

default:
28 changes: 28 additions & 0 deletions interpreter/runtimecost_test.go
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@ package interpreter

import (
"fmt"
"math"
"math/rand"
"reflect"
"strings"
@@ -29,6 +30,7 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"

proto3pb "github.com/google/cel-go/test/proto3pb"
@@ -727,6 +729,25 @@ func TestRuntimeCost(t *testing.T) {
want: 6,
in: map[string]any{"str1": "val1", "str2": "val2222222"},
},
{
name: "str concat custom cost tracker",
expr: `"abcdefg".contains(str1 + str2)`,
vars: []*decls.VariableDecl{
decls.NewVariable("str1", types.StringType),
decls.NewVariable("str2", types.StringType),
},
options: []CostTrackerOption{
OverloadCostTracker(overloads.ContainsString,
func(args []ref.Val, result ref.Val) *uint64 {
strCost := uint64(math.Ceil(float64(actualSize(args[0])) * 0.2))
substrCost := uint64(math.Ceil(float64(actualSize(args[1])) * 0.2))
cost := strCost * substrCost
return &cost
}),
},
want: 10,
in: map[string]any{"str1": "val1", "str2": "val2222222"},
},
{
name: "at limit",
expr: `"abcdefg".contains(str1 + str2)`,
@@ -803,3 +824,10 @@ func TestRuntimeCost(t *testing.T) {
})
}
}

func actualSize(val ref.Val) uint64 {
if sz, ok := val.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
return 1
}