Skip to content

Commit

Permalink
Migrate Cost calculations to Go-native Expr.
Browse files Browse the repository at this point in the history
Note: this is a breaking change, as the Expr() method of checker.AstNode
has been modified to return a Go-native ast.Expr value rather than its
protobuf equivalent (*exprpb.Expr).
  • Loading branch information
TristonianJones committed Aug 11, 2023
1 parent 7296860 commit 0645b92
Show file tree
Hide file tree
Showing 3 changed files with 7,763 additions and 102 deletions.
184 changes: 83 additions & 101 deletions checker/cost.go
Expand Up @@ -22,8 +22,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

// WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go
Expand Down Expand Up @@ -58,7 +56,7 @@ type AstNode interface {
// Type returns the deduced type of the AstNode.
Type() *types.Type
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
Expr() ast.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
// For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings
// and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no
Expand All @@ -69,7 +67,7 @@ type AstNode interface {
type astNode struct {
path []string
t *types.Type
expr *exprpb.Expr
expr ast.Expr
derivedSize *SizeEstimate
}

Expand All @@ -81,7 +79,7 @@ func (e astNode) Type() *types.Type {
return e.t
}

func (e astNode) Expr() *exprpb.Expr {
func (e astNode) Expr() ast.Expr {
return e.expr
}

Expand All @@ -90,29 +88,27 @@ func (e astNode) ComputedSize() *SizeEstimate {
return e.derivedSize
}
var v uint64
switch ek := e.expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
switch e.expr.Kind() {
case ast.LiteralKind:
switch ck := e.expr.AsLiteral().(type) {
case types.String:
// converting to runes here is an O(n) operation, but
// this is consistent with how size is computed at runtime,
// and how the language definition defines string size
v = uint64(len([]rune(ck.StringValue)))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
*exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value,
*exprpb.Constant_NullValue:
v = uint64(len([]rune(ck)))
case types.Bytes:
v = uint64(len(ck))
case types.Bool, types.Double, types.Duration,
types.Int, types.Timestamp, types.Uint,
types.Null:
v = uint64(1)
default:
return nil
}
case *exprpb.Expr_ListExpr:
v = uint64(len(ek.ListExpr.GetElements()))
case *exprpb.Expr_StructExpr:
if ek.StructExpr.GetMessageName() == "" {
v = uint64(len(ek.StructExpr.GetEntries()))
}
case ast.ListKind:
v = uint64(e.expr.AsList().Size())
case ast.MapKind:
v = uint64(e.expr.AsMap().Size())
default:
return nil
}
Expand Down Expand Up @@ -270,8 +266,8 @@ type coster struct {
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
type iterRangeScopes map[string][]int64

func (vs iterRangeScopes) push(varName string, expr *exprpb.Expr) {
vs[varName] = append(vs[varName], expr.GetId())
func (vs iterRangeScopes) push(varName string, expr ast.Expr) {
vs[varName] = append(vs[varName], expr.ID())
}

func (vs iterRangeScopes) pop(varName string) {
Expand Down Expand Up @@ -319,86 +315,82 @@ func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEs
return CostEstimate{}, err
}
}
epb, err := ast.ExprToProto(checked.Expr())
if err != nil {
return CostEstimate{}, err
}
return c.cost(epb), nil
return c.cost(checked.Expr()), nil
}

func (c *coster) cost(e *exprpb.Expr) CostEstimate {
func (c *coster) cost(e ast.Expr) CostEstimate {
if e == nil {
return CostEstimate{}
}
var cost CostEstimate
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
switch e.Kind() {
case ast.LiteralKind:
cost = constCost
case *exprpb.Expr_IdentExpr:
case ast.IdentKind:
cost = c.costIdent(e)
case *exprpb.Expr_SelectExpr:
case ast.SelectKind:
cost = c.costSelect(e)
case *exprpb.Expr_CallExpr:
case ast.CallKind:
cost = c.costCall(e)
case *exprpb.Expr_ListExpr:
case ast.ListKind:
cost = c.costCreateList(e)
case *exprpb.Expr_StructExpr:
case ast.MapKind:
cost = c.costCreateMap(e)
case ast.StructKind:
cost = c.costCreateStruct(e)
case *exprpb.Expr_ComprehensionExpr:
case ast.ComprehensionKind:
cost = c.costComprehension(e)
default:
return CostEstimate{}
}
return cost
}

func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
identExpr := e.GetIdentExpr()

func (c *coster) costIdent(e ast.Expr) CostEstimate {
identName := e.AsIdent()
// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
if iterRange, ok := c.iterRanges.peek(identName); ok {
switch c.checkedAST.GetType(iterRange).Kind() {
case types.ListKind:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case types.MapKind:
c.addPath(e, append(c.exprPath[iterRange], "@keys"))
}
} else {
c.addPath(e, []string{identExpr.GetName()})
c.addPath(e, []string{identName})
}

return selectAndIdentCost
}

func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr()
func (c *coster) costSelect(e ast.Expr) CostEstimate {
sel := e.AsSelect()
var sum CostEstimate
if sel.GetTestOnly() {
if sel.IsTestOnly() {
// recurse, but do not add any cost
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(c.presenceTestCost)
sum = sum.Add(c.cost(sel.GetOperand()))
sum = sum.Add(c.cost(sel.Operand()))
return sum
}
sum = sum.Add(c.cost(sel.GetOperand()))
targetType := c.getType(sel.GetOperand())
sum = sum.Add(c.cost(sel.Operand()))
targetType := c.getType(sel.Operand())
switch targetType.Kind() {
case types.MapKind, types.StructKind, types.TypeParamKind:
sum = sum.Add(selectAndIdentCost)
}

// build and track the field path
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.GetField()))
c.addPath(e, append(c.getPath(sel.Operand()), sel.FieldName()))

return sum
}

func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
func (c *coster) costCall(e ast.Expr) CostEstimate {
call := e.AsCall()
args := call.Args()

var sum CostEstimate

Expand All @@ -409,22 +401,20 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
argTypes[i] = c.newAstNode(arg)
}

overloadIDs := c.checkedAST.GetOverloadIDs(e.GetId())
overloadIDs := c.checkedAST.GetOverloadIDs(e.ID())
if len(overloadIDs) == 0 {
return CostEstimate{}
}
var targetType AstNode
if target != nil {
if call.Target != nil {
sum = sum.Add(c.cost(call.GetTarget()))
targetType = c.newAstNode(call.GetTarget())
}
if call.IsMemberFunction() {
sum = sum.Add(c.cost(call.Target()))
targetType = c.newAstNode(call.Target())
}
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range overloadIDs {
overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts)
overloadCost := c.functionCost(call.FunctionName(), overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
if resultSize == nil {
Expand All @@ -447,62 +437,54 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
}
}
if resultSize != nil {
c.computedSizes[e.GetId()] = *resultSize
c.computedSizes[e.ID()] = *resultSize
}
return sum.Add(fnCost)
}

func (c *coster) costCreateList(e *exprpb.Expr) CostEstimate {
create := e.GetListExpr()
func (c *coster) costCreateList(e ast.Expr) CostEstimate {
create := e.AsList()
var sum CostEstimate
for _, e := range create.GetElements() {
for _, e := range create.Elements() {
sum = sum.Add(c.cost(e))
}
return sum.Add(createListBaseCost)
}

func (c *coster) costCreateStruct(e *exprpb.Expr) CostEstimate {
str := e.GetStructExpr()
if str.MessageName != "" {
return c.costCreateMessage(e)
}
return c.costCreateMap(e)
}

func (c *coster) costCreateMap(e *exprpb.Expr) CostEstimate {
mapVal := e.GetStructExpr()
func (c *coster) costCreateMap(e ast.Expr) CostEstimate {
mapVal := e.AsMap()
var sum CostEstimate
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
sum = sum.Add(c.cost(key))

sum = sum.Add(c.cost(ent.GetValue()))
for _, ent := range mapVal.Entries() {
entry := ent.AsMapEntry()
sum = sum.Add(c.cost(entry.Key()))
sum = sum.Add(c.cost(entry.Value()))
}
return sum.Add(createMapBaseCost)
}

func (c *coster) costCreateMessage(e *exprpb.Expr) CostEstimate {
msgVal := e.GetStructExpr()
func (c *coster) costCreateStruct(e ast.Expr) CostEstimate {
msgVal := e.AsStruct()
var sum CostEstimate
for _, ent := range msgVal.GetEntries() {
sum = sum.Add(c.cost(ent.GetValue()))
for _, ent := range msgVal.Fields() {
field := ent.AsStructField()
sum = sum.Add(c.cost(field.Value()))
}
return sum.Add(createMessageBaseCost)
}

func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate {
comp := e.GetComprehensionExpr()
func (c *coster) costComprehension(e ast.Expr) CostEstimate {
comp := e.AsComprehension()
var sum CostEstimate
sum = sum.Add(c.cost(comp.GetIterRange()))
sum = sum.Add(c.cost(comp.GetAccuInit()))
sum = sum.Add(c.cost(comp.IterRange()))
sum = sum.Add(c.cost(comp.AccuInit()))

// Track the iterRange of each IterVar for field path construction
c.iterRanges.push(comp.GetIterVar(), comp.GetIterRange())
loopCost := c.cost(comp.GetLoopCondition())
stepCost := c.cost(comp.GetLoopStep())
c.iterRanges.pop(comp.GetIterVar())
sum = sum.Add(c.cost(comp.Result))
rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange()))
c.iterRanges.push(comp.IterVar(), comp.IterRange())
loopCost := c.cost(comp.LoopCondition())
stepCost := c.cost(comp.LoopStep())
c.iterRanges.pop(comp.IterVar())
sum = sum.Add(c.cost(comp.Result()))
rangeCnt := c.sizeEstimate(c.newAstNode(comp.IterRange()))
rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost))
sum = sum.Add(rangeCost)

Expand Down Expand Up @@ -647,26 +629,26 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
}

func (c *coster) getType(e *exprpb.Expr) *types.Type {
return c.checkedAST.GetType(e.GetId())
func (c *coster) getType(e ast.Expr) *types.Type {
return c.checkedAST.GetType(e.ID())
}

func (c *coster) getPath(e *exprpb.Expr) []string {
return c.exprPath[e.GetId()]
func (c *coster) getPath(e ast.Expr) []string {
return c.exprPath[e.ID()]
}

func (c *coster) addPath(e *exprpb.Expr, path []string) {
c.exprPath[e.GetId()] = path
func (c *coster) addPath(e ast.Expr, path []string) {
c.exprPath[e.ID()] = path
}

func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
func (c *coster) newAstNode(e ast.Expr) *astNode {
path := c.getPath(e)
if len(path) > 0 && path[0] == parser.AccumulatorName {
// only provide paths to root vars; omit accumulator vars
path = nil
}
var derivedSize *SizeEstimate
if size, ok := c.computedSizes[e.GetId()]; ok {
if size, ok := c.computedSizes[e.ID()]; ok {
derivedSize = &size
}
return &astNode{
Expand Down
2 changes: 1 addition & 1 deletion common/ast/conversion.go
Expand Up @@ -298,7 +298,7 @@ func EntryExprToProto(e EntryExpr) (*exprpb.Expr_CreateStruct_Entry, error) {

func protoCall(id int64, call CallExpr) (*exprpb.Expr, error) {
var err error
var target *exprpb.Expr = nil
var target *exprpb.Expr
if call.IsMemberFunction() {
target, err = ExprToProto(call.Target())
if err != nil {
Expand Down

0 comments on commit 0645b92

Please sign in to comment.