Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sets cost estimation and tracking options #850

Merged
merged 8 commits into from
Oct 27, 2023
54 changes: 35 additions & 19 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ func Test_ExampleWithBuiltins(t *testing.T) {
}

func TestEval(t *testing.T) {
env, err := NewEnv(Variable("input", ListType(IntType)))
env, err := NewEnv(
Variable("input", ListType(IntType)),
CostEstimatorOptions(
checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
),
)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
Expand All @@ -115,6 +120,9 @@ func TestEval(t *testing.T) {
ctx := context.Background()
prgOpts := []ProgramOption{
CostTracking(testRuntimeCostEstimator{}),
CostTrackerOptions(
interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
),
EvalOptions(OptOptimize, OptTrackCost),
InterruptCheckFrequency(100),
}
Expand Down Expand Up @@ -1475,7 +1483,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
name string
expr string
decls []EnvOption
hints map[string]int64
hints map[string]uint64
want checker.CostEstimate
in any
}{
Expand All @@ -1499,7 +1507,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
Variable("str1", StringType),
Variable("str2", StringType),
},
hints: map[string]int64{"str1": 10, "str2": 10},
hints: map[string]uint64{"str1": 10, "str2": 10},
want: checker.CostEstimate{Min: 2, Max: 6},
in: map[string]any{"str1": "val1111111", "str2": "val2222222"},
},
Expand All @@ -1510,9 +1518,15 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if tc.hints == nil {
tc.hints = map[string]int64{}
tc.hints = map[string]uint64{}
}
env := testEnv(t, tc.decls...)
envOpts := []EnvOption{
CostEstimatorOptions(
checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
),
}
envOpts = append(envOpts, tc.decls...)
env := testEnv(t, envOpts...)
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err())
Expand All @@ -1531,7 +1545,12 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err())
}
// Evaluate expression.
program, err := env.Program(checkedAst, CostTracking(testRuntimeCostEstimator{}))
program, err := env.Program(checkedAst,
CostTracking(testRuntimeCostEstimator{}),
CostTrackerOptions(
interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
),
)
if err != nil {
t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err)
}
Expand Down Expand Up @@ -2768,27 +2787,26 @@ func BenchmarkDynamicDispatch(b *testing.B) {

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]int64
hints map[string]uint64
}

func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok {
return &checker.SizeEstimate{Min: 0, Max: uint64(l)}
return &checker.SizeEstimate{Min: 0, Max: l}
}
return nil
}

func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
switch overloadID {
case overloads.TimestampToYear:
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
}
return nil
}

type testRuntimeCostEstimator struct {
func estimateTimestampToYear(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
}

type testRuntimeCostEstimator struct{}

var timeToYearCost uint64 = 7

func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 {
Expand All @@ -2804,13 +2822,11 @@ func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []r
argsSize[i] = 1
}
}
return nil
}

switch overloadID {
case overloads.TimestampToYear:
return &timeToYearCost
default:
return nil
}
func trackTimestampToYear(args []ref.Val, result ref.Val) *uint64 {
return &timeToYearCost
}

func testEnv(t testing.TB, opts ...EnvOption) *Env {
Expand Down
10 changes: 9 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type Env struct {
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator
costOptions []checker.CostOption

// Internal parser representation
prsr *parser.Parser
Expand Down Expand Up @@ -191,6 +192,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
}).configure(opts)
}

Expand Down Expand Up @@ -365,6 +367,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)

ext := &Env{
Container: e.Container,
Expand All @@ -380,6 +384,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
costOptions: costOptsCopy,
}
return ext.configure(opts)
}
Expand Down Expand Up @@ -556,7 +561,10 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
return checker.Cost(ast.impl, estimator, opts...)
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
}

// configure applies a series of EnvOptions to the current environment.
Expand Down
19 changes: 19 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"

"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -471,6 +472,24 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}

// CostEstimatorOptions configure type-check time options for estimating expression cost.
func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption {
return func(e *Env) (*Env, error) {
e.costOptions = append(e.costOptions, costOpts...)
return e, nil
}
}

// CostTrackerOptions configures a set of options for cost-tracking.
//
// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled.
func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption {
return func(p *prog) (*prog, error) {
p.costOptions = append(p.costOptions, costOpts...)
return p, nil
}
}

// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
Expand Down
35 changes: 28 additions & 7 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (ed *EvalDetails) State() interpreter.EvalState {
// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled.
// Otherwise, returns nil if the cost was not enabled.
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
if ed == nil || ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
Expand All @@ -129,10 +129,14 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}

func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)

return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
Expand All @@ -154,9 +158,10 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}

// Configure the program via the ProgramOption values.
Expand Down Expand Up @@ -213,6 +218,12 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
Expand Down Expand Up @@ -325,7 +336,11 @@ type progGen struct {
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
Expand All @@ -338,7 +353,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
Expand Down Expand Up @@ -366,7 +384,10 @@ func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalD
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
Expand Down
56 changes: 44 additions & 12 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func addUint64NoOverflow(x, y uint64) uint64 {
// multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func multiplyUint64NoOverflow(x, y uint64) uint64 {
if x > 0 && y > 0 && x > math.MaxUint64/y {
if y != 0 && x > math.MaxUint64/y {
return math.MaxUint64
}
return x * y
Expand All @@ -238,7 +238,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 {
if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y {
return math.MaxUint64
}
return uint64(math.Ceil(xFloat * y))
ceil := math.Ceil(xFloat * y)
if ceil >= doubleTwoTo64 {
return math.MaxUint64
}
return uint64(ceil)
}

var (
Expand All @@ -256,9 +260,10 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedAST *ast.AST
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.AST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
Expand Down Expand Up @@ -287,6 +292,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
type CostOption func(*coster) error

// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
//
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
Expand All @@ -299,15 +305,30 @@ func PresenceTestHasCost(hasCost bool) CostOption {
}
}

// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate

// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID.
//
// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to
// the Cost() call.
func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption {
return func(c *coster) error {
c.overloadEstimators[overloadID] = functionCoster
return nil
}
}

// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checked,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
checkedAST: checked,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
Expand Down Expand Up @@ -518,7 +539,14 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}

if len(c.overloadEstimators) != 0 {
if estimator, found := c.overloadEstimators[overloadID]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
Expand Down Expand Up @@ -668,3 +696,7 @@ func isScalar(t *types.Type) bool {
}
return false
}

var (
doubleTwoTo64 = math.Ldexp(1.0, 64)
)