PruneAST overwrites existing nodes and produces an invalid AST #699

Closed
1 of 3 tasks
charithe opened this issue May 5, 2023 · 7 comments · Fixed by #700 or #703

Comments

charithe commented May 5, 2023

Describe the bug
Calling the interpreter.PruneAst function produces an incorrect AST when it replaces an expression with a constant. The ID assigned to the new constant node clashes with an existing node ID, so the new node ends up overwriting the existing one.

For example, consider the expression foo == "bar" && R.attr.loc in ["GB", "US"]. If I set the variable foo to "bar" and partially evaluate the expression, I'd expect to get back R.attr.loc in ["GB", "US"]. However, with the change introduced in #677 and released in version 0.15.0, the output is "US".loc in ["GB", "US"].
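
To make the clash concrete, here is a minimal sketch that uses a toy node type rather than cel-go's real protobuf Expr. The names node, findByID, and sampleTree, and the ID values themselves, are hypothetical and chosen for illustration. Only the idea comes from the report: a fresh-ID counter that starts at 1 hands out IDs the parser has already assigned.

```go
package main

import "fmt"

// node is a hypothetical stand-in for an AST node; it is not cel-go's exprpb.Expr.
type node struct {
	id       int64
	label    string
	children []*node
}

// findByID returns the label of the node in the tree that already uses id, if any.
func findByID(n *node, id int64) (string, bool) {
	if n == nil {
		return "", false
	}
	if n.id == id {
		return n.label, true
	}
	for _, c := range n.children {
		if label, ok := findByID(c, id); ok {
			return label, true
		}
	}
	return "", false
}

// sampleTree models an expression shaped like `foo == "bar" && loc`
// with made-up IDs 1 through 5.
func sampleTree() *node {
	return &node{id: 4, label: "&&", children: []*node{
		{id: 2, label: "==", children: []*node{
			{id: 1, label: "foo"},
			{id: 3, label: `"bar"`},
		}},
		{id: 5, label: "loc"},
	}}
}

func main() {
	root := sampleTree()
	// A counter that starts at 1 hands out an ID the tree already uses.
	if label, ok := findByID(root, 1); ok {
		fmt.Printf("new ID 1 clashes with existing node %s\n", label)
	}
	// An ID above every existing one (6 here) is free.
	_, taken := findByID(root, 6)
	fmt.Println("ID 6 taken:", taken)
}
```

In this toy tree any new node numbered 1 to 5 would be indistinguishable, by ID, from a node that is already there, which is the kind of overwrite described above.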

To Reproduce
Check which components this affects:

  • parser
  • checker
  • interpreter

Sample expression and input that reproduces the issue:

foo == "bar" && R.attr.loc in ["GB", "US"]

Test setup:

package main

import (
	"fmt"
	"log"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/checker/decls"
	"github.com/google/cel-go/interpreter"
	expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

func main() {
	env, err := cel.NewEnv(cel.Declarations(
		decls.NewVar("R", decls.NewMapType(decls.String, decls.Dyn)),
		decls.NewVar("foo", decls.String),
	))
	exitOnErr(err)

	pvars, err := cel.PartialVars(map[string]any{"foo": "bar"}, cel.AttributePattern("R").QualString("attr").Wildcard())
	exitOnErr(err)

	ast, iss := env.Compile(`foo == "bar" && R.attr.loc in ["GB", "US"]`)
	exitOnErr(iss.Err())

	prg, err := env.Program(ast, cel.EvalOptions(cel.OptTrackState, cel.OptPartialEval))
	exitOnErr(err)

	_, details, err := prg.Eval(pvars)
	exitOnErr(err)

	pruned := interpreter.PruneAst(ast.Expr(), ast.SourceInfo().GetMacroCalls(), details.State()).Expr
	prunedAST := cel.ParsedExprToAst(&expr.ParsedExpr{Expr: pruned})
	fmt.Println(cel.AstToString(prunedAST))

}

func exitOnErr(err error) {
	if err != nil {
		log.Fatalf("Error: %v", err)
	}
}

Expected behavior

The output should be:

R.attr.loc in ["GB", "US"]

Additional context

I believe the problem stems from the fact that nextExprID is set to 1 in the astPruner.

pruner := &astPruner{
	expr:       expr,
	macroCalls: macroCalls,
	state:      pruneState,
	nextExprID: 1}

If I apply the following patch, everything works as expected.

diff --git a/interpreter/prune.go b/interpreter/prune.go
index 24c7e79..342de6e 100644
--- a/interpreter/prune.go
+++ b/interpreter/prune.go
@@ -68,15 +68,19 @@ type astPruner struct {
 // the overloads accordingly.
 func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
 	pruneState := NewEvalState()
+	var maxID int64 = 0
 	for _, id := range state.IDs() {
 		v, _ := state.Value(id)
 		pruneState.SetValue(id, v)
+		if id > maxID {
+			maxID = id
+		}
 	}
 	pruner := &astPruner{
 		expr:       expr,
 		macroCalls: macroCalls,
 		state:      pruneState,
-		nextExprID: 1}
+		nextExprID: maxID}
 	newExpr, _ := pruner.maybePrune(expr)
 	return &exprpb.ParsedExpr{
 		Expr:       newExpr,

@TristonianJones (Collaborator)

@charithe Thank you for the high-quality report and triage. I'll have a fix out soon and will do a point release.

@TristonianJones (Collaborator)

@charithe You should be all set to upgrade to v0.15.1. Thanks again for the report!

charithe commented May 5, 2023

Thank you for the quick fix!

charithe commented May 9, 2023

@TristonianJones It turns out that the value of nextExprID should be derived from the maximum ID in the expression tree rather than the maximum ID in the state.

I was testing a much more complicated expression and realised that nextExprID had produced a state value with the same ID as a node in the expression. When PruneAst reached that node, it replaced it with the state value that shared its ID, which was the wrong thing to do.

My original suggested solution was wrong. Sorry about that. I think the correct solution is to traverse the expression and find its largest node ID.

@TristonianJones (Collaborator)

@charithe Thanks for following up. Honestly, the fix looked good to me too, but I can see how the IDs still need adjustment.

@charithe (Author)

@TristonianJones the following is a contrived expression that triggers the problem:

sets.intersects(users.filter(u, u.role=="MANAGER").map(u, u.name), R.attr.authorized["managers"])

It produces the incorrect, pruned expression sets.intersects(["bob"], R.attr.authorized["bob"]). The correct pruned expression should be sets.intersects(["bob"], R.attr.authorized["managers"]).

Test harness:

package main

import (
	"fmt"
	"log"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/checker/decls"
	"github.com/google/cel-go/ext"
	"github.com/google/cel-go/interpreter"
	expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

func main() {
	env, err := cel.NewEnv(cel.Declarations(
		decls.NewVar("R", decls.NewMapType(decls.String, decls.Dyn)),
		decls.NewVar("users", decls.NewListType(decls.NewMapType(decls.String, decls.String))),
	), ext.Sets())
	exitOnErr(err)

	pvars, err := cel.PartialVars(
		map[string]any{
			"users": []map[string]string{
				{"name": "alice", "role": "EMPLOYEE"},
				{"name": "bob", "role": "MANAGER"},
				{"name": "eve", "role": "CUSTOMER"},
			},
		},
		cel.AttributePattern("R").QualString("attr").Wildcard())
	exitOnErr(err)

	ast, iss := env.Compile(`sets.intersects(users.filter(u, u.role=="MANAGER").map(u, u.name), R.attr.authorized["managers"])`)
	exitOnErr(iss.Err())

	prg, err := env.Program(ast, cel.EvalOptions(cel.OptTrackState, cel.OptPartialEval))
	exitOnErr(err)

	_, details, err := prg.Eval(pvars)
	exitOnErr(err)

	pruned := interpreter.PruneAst(ast.Expr(), ast.SourceInfo().GetMacroCalls(), details.State()).Expr
	prunedAST := cel.ParsedExprToAst(&expr.ParsedExpr{Expr: pruned})
	fmt.Println(cel.AstToString(prunedAST))

}

func exitOnErr(err error) {
	if err != nil {
		log.Fatalf("Error: %v", err)
	}
}

I made the following patch which seems to fix the issue.

diff --git a/interpreter/prune.go b/interpreter/prune.go
index 85b3b06..1d46e2f 100644
--- a/interpreter/prune.go
+++ b/interpreter/prune.go
@@ -68,19 +68,16 @@ type astPruner struct {
 // the overloads accordingly.
 func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
 	pruneState := NewEvalState()
-	maxID := int64(1)
 	for _, id := range state.IDs() {
 		v, _ := state.Value(id)
 		pruneState.SetValue(id, v)
-		if id > maxID {
-			maxID = id + 1
-		}
 	}
+	nextExprID := maxExprID(expr) + 1
 	pruner := &astPruner{
 		expr:       expr,
 		macroCalls: macroCalls,
 		state:      pruneState,
-		nextExprID: maxID}
+		nextExprID: nextExprID}
 	newExpr, _ := pruner.maybePrune(expr)
 	return &exprpb.ParsedExpr{
 		Expr:       newExpr,
@@ -88,6 +85,50 @@ func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalSt
 	}
 }
 
+func maxExprID(expr *exprpb.Expr) int64 {
+	max := expr.GetId()
+	maybeSetMax := func(e *exprpb.Expr) {
+		if e == nil {
+			return
+		}
+
+		if m := maxExprID(e); m > max {
+			max = m
+		}
+	}
+
+	switch kind := expr.GetExprKind().(type) {
+	case *exprpb.Expr_ConstExpr, *exprpb.Expr_IdentExpr:
+	case *exprpb.Expr_CallExpr:
+		maybeSetMax(kind.CallExpr.GetTarget())
+		for _, e := range kind.CallExpr.GetArgs() {
+			maybeSetMax(e)
+		}
+	case *exprpb.Expr_ComprehensionExpr:
+		maybeSetMax(kind.ComprehensionExpr.GetIterRange())
+		maybeSetMax(kind.ComprehensionExpr.GetAccuInit())
+		maybeSetMax(kind.ComprehensionExpr.GetLoopCondition())
+		maybeSetMax(kind.ComprehensionExpr.GetLoopStep())
+		maybeSetMax(kind.ComprehensionExpr.GetResult())
+	case *exprpb.Expr_ListExpr:
+		for _, e := range kind.ListExpr.GetElements() {
+			maybeSetMax(e)
+		}
+	case *exprpb.Expr_SelectExpr:
+		maybeSetMax(kind.SelectExpr.GetOperand())
+	case *exprpb.Expr_StructExpr:
+		for _, e := range kind.StructExpr.GetEntries() {
+			maybeSetMax(e.GetMapKey())
+			maybeSetMax(e.GetValue())
+			if m := e.GetId(); m > max {
+				max = m
+			}
+		}
+	}
+
+	return max
+}
+
 func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
 	return &exprpb.Expr{
 		Id: id,

I don't know whether it's an acceptable solution because it's recursive. Another small concern I have is that if more expression kinds are added to the Expr protobuf in the future, someone needs to remember to update this function as well. However, the latter can be caught with a linter so it's probably not a big issue. What are your thoughts?
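
On the recursion concern, one option is to walk the tree with an explicit stack. The sketch below is only an illustration on a hypothetical, simplified node type (node and maxNodeID are made-up names, not part of cel-go). The real exprpb.Expr keeps its children in kind-specific fields, so this does not remove the need to handle each expression kind.

```go
package main

import "fmt"

// node is a hypothetical, simplified AST node (not cel-go's exprpb.Expr).
type node struct {
	id       int64
	children []*node
}

// maxNodeID walks the tree with an explicit stack instead of recursion
// and returns the largest ID it finds. A nil root yields 0.
func maxNodeID(root *node) int64 {
	var largest int64
	stack := []*node{root}
	for len(stack) > 0 {
		// Pop the most recently pushed node.
		n := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if n == nil {
			continue
		}
		if n.id > largest {
			largest = n.id
		}
		// Push the children so they are visited on later iterations.
		stack = append(stack, n.children...)
	}
	return largest
}

func main() {
	tree := &node{id: 4, children: []*node{
		{id: 2, children: []*node{{id: 1}, {id: 3}}},
		{id: 7, children: []*node{{id: 5}, nil}},
	}}
	// The first safe ID for a newly created node is one past the maximum.
	next := maxNodeID(tree) + 1
	fmt.Println(next) // prints 8
}
```

The stack grows with the width and depth of the tree rather than with the call depth, so very deep expressions would not exhaust the goroutine stack in this variant.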

@TristonianJones (Collaborator)

@charithe In theory, the type-map keys from the checked expression would contain a max ID as well. I don't believe there are any expression nodes which skip this check, so that might be an alternative.
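
A rough sketch of that alternative, assuming the type map is available as a Go map keyed by int64 expression IDs. The helper maxKey is a hypothetical name, and the string values below merely stand in for real type messages.

```go
package main

import "fmt"

// maxKey returns the largest key in m, or 0 if m is empty.
// It is a hypothetical helper, not a cel-go function.
func maxKey[V any](m map[int64]V) int64 {
	var largest int64
	for id := range m {
		if id > largest {
			largest = id
		}
	}
	return largest
}

func main() {
	// Stand-in for a checked expression's type map: expression ID -> type.
	typeMap := map[int64]string{1: "string", 2: "bool", 5: "list(string)"}
	// One past the largest key would be the first unused expression ID,
	// provided every expression node has an entry in the map.
	fmt.Println(maxKey(typeMap) + 1) // prints 6
}
```

This avoids a tree walk, but it depends on the assumption stated in the comment above that no expression node is missing from the type map, and it would only apply to an AST that has been type-checked.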
