Skip to content

Commit

Permalink
Prune recursion fixes for nested logic (google#677)
Browse files Browse the repository at this point in the history
In expressions where the logic is nested and the residual
state would result in the production of a constant value
in the expression, ensure that the intermediate state for
the expression is updated to reflect the constant value.

Also, ensure that special cases of pruning for logical
operators happen after argument pruning has happened to
ensure that the prune steps are properly recursive.
  • Loading branch information
TristonianJones committed Apr 12, 2023
1 parent 3d2e878 commit 68d7302
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 22 deletions.
75 changes: 63 additions & 12 deletions interpreter/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ type astPruner struct {
// fold(and thus cache results of) some external calls, then they can prepare
// the overloads accordingly.
func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
pruneState := NewEvalState()
for _, id := range state.IDs() {
v, _ := state.Value(id)
pruneState.SetValue(id, v)
}
pruner := &astPruner{
expr: expr,
macroCalls: macroCalls,
state: state,
state: pruneState,
nextExprID: 1}
newExpr, _ := pruner.maybePrune(expr)
return &exprpb.ParsedExpr{
Expand All @@ -91,24 +96,31 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
switch val.Type() {
case types.BoolType:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true
case types.IntType:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true
case types.UintType:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true
case types.StringType:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true
case types.DoubleType:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true
case types.BytesType:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true
case types.NullType:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true
}
Expand All @@ -128,6 +140,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
}
elemExprs[i] = elemExpr
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
Expand Down Expand Up @@ -167,6 +180,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
entries[i] = entry
i++
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
Expand All @@ -182,6 +196,37 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}

func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}
call := node.GetCallExpr()
val, valueExists := p.value(call.GetArgs()[1].GetId())
if !valueExists {
return nil, false
}
if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return nil, false
}

func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}
call := node.GetCallExpr()
arg := call.GetArgs()[0]
v, exists := p.value(arg.GetId())
if !exists {
return nil, false
}
if b, ok := v.(types.Bool); ok {
return p.maybeCreateLiteral(node.GetId(), !b)
}
return nil, false
}

func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
Expand Down Expand Up @@ -224,7 +269,12 @@ func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
if call.Function == operators.Conditional {
return p.maybePruneConditional(node)
}

if call.Function == operators.In {
return p.maybePruneIn(node)
}
if call.Function == operators.LogicalNot {
return p.maybePruneLogicalNot(node)
}
return nil, false
}

Expand Down Expand Up @@ -266,10 +316,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}, true
}
case *exprpb.Expr_CallExpr:
if newExpr, pruned := p.maybePruneFunction(node); pruned {
newExpr, _ = p.maybePrune(newExpr)
return newExpr, true
}
var prunedCall bool
call := node.GetCallExpr()
args := call.GetArgs()
Expand All @@ -290,13 +336,18 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
prunedCall = true
newCall.Target = newTarget
}
newNode := &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
}
if newExpr, pruned := p.maybePruneFunction(newNode); pruned {
newExpr, _ = p.maybePrune(newExpr)
return newExpr, true
}
if prunedCall {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
}, true
return newNode, true
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
Expand Down
69 changes: 59 additions & 10 deletions interpreter/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,32 @@ var testCases = []testInfo{
expr: `a && [1, 1u, 1.0].exists(x, type(x) == uint)`,
out: `a`,
},
{
in: unknownActivation("this"),
expr: `this in []`,
out: `false`,
},
{
in: unknownActivation("this"),
expr: `this in {}`,
out: `false`,
},
{
in: partialActivation(map[string]any{"rules": []string{}}, "this"),
expr: `this in rules`,
out: `false`,
},
{
in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"),
expr: `this.size() > 0 ? this in rules.not_in : !(this in rules.not_in)`,
out: `(this.size() > 0) ? false : true`,
},
{
in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"),
expr: `this.size() > 0 ? this in rules.not_in :
!(this in rules.not_in) ? true : false`,
out: `(this.size() > 0) ? false : true`,
},
{
expr: `{'hello': 'world'.size()}`,
out: `{"hello": 5}`,
Expand Down Expand Up @@ -96,6 +122,11 @@ var testCases = []testInfo{
expr: `true ? b < 1.2 : c == ['hello']`,
out: `b < 1.2`,
},
{
in: unknownActivation("b", "c"),
expr: `false ? b < 1.2 : c == ['hello']`,
out: `c == ["hello"]`,
},
{
in: unknownActivation(),
expr: `[1+3, 2+2, 3+1, four]`,
Expand All @@ -121,18 +152,27 @@ var testCases = []testInfo{
expr: `test in {'a': 1, 'field': [test, 3]}.field`,
out: `test in {"a": 1, "field": [test, 3]}.field`,
},
// TODO(issues/) the output test relies on tracking macro expansions back to their original
// call patterns.
/* {
in: unknownActivation(),
expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`,
out: `[4, 4, 4, four].exists(x, x == four)`,
}, */
// TODO: the output of an expression like this relies on either
// a) doing replacements on the original macro call, or
// b) mutating the macro call tracking data rather than the core
// expression in order to render the partial correctly.
// {
// in: unknownActivation(),
// expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`,
// out: `[4, 4, 4, four].exists(x, x == four)`,
// },
}

func TestPrune(t *testing.T) {
p, err := parser.NewParser(
parser.PopulateMacroCalls(true),
parser.Macros(parser.AllMacros...),
)
if err != nil {
t.Fatalf("parser.NewParser() failed: %v", err)
}
for i, tst := range testCases {
ast, iss := parser.Parse(common.NewStringSource(tst.expr, "<input>"))
ast, iss := p.Parse(common.NewStringSource(tst.expr, "<input>"))
if len(iss.GetErrors()) > 0 {
t.Fatalf(iss.ToDisplayString())
}
Expand All @@ -142,10 +182,10 @@ func TestPrune(t *testing.T) {
interp := NewStandardInterpreter(containers.DefaultContainer, reg, reg, attrs)

interpretable, _ := interp.NewUncheckedInterpretable(
ast.Expr,
ast.GetExpr(),
ExhaustiveEval(), Observe(EvalStateObserver(state)))
interpretable.Eval(testActivation(t, tst.in))
newExpr := PruneAst(ast.Expr, ast.SourceInfo.GetMacroCalls(), state)
newExpr := PruneAst(ast.GetExpr(), ast.GetSourceInfo().GetMacroCalls(), state)
actual, err := parser.Unparse(newExpr.GetExpr(), newExpr.GetSourceInfo())
if err != nil {
t.Error(err)
Expand All @@ -165,6 +205,15 @@ func unknownActivation(vars ...string) PartialActivation {
return a
}

func partialActivation(in map[string]any, vars ...string) PartialActivation {
pats := make([]*AttributePattern, len(vars), len(vars))
for i, v := range vars {
pats[i] = NewAttributePattern(v)
}
a, _ := NewPartialActivation(in, pats...)
return a
}

func testActivation(t *testing.T, in any) Activation {
t.Helper()
if in == nil {
Expand Down

0 comments on commit 68d7302

Please sign in to comment.