From c2b609e99b5c2229c0954beaf0eba1ed6fa50e94 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Sat, 13 Apr 2024 23:16:08 +0200 Subject: [PATCH] Add sum([]) optimization --- expr_test.go | 141 ------------------------------------ optimizer/optimizer.go | 1 + optimizer/optimizer_test.go | 73 +++++++++++++++++++ optimizer/sum_array.go | 37 ++++++++++ optimizer/sum_array_test.go | 73 +++++++++++++++++++ 5 files changed, 184 insertions(+), 141 deletions(-) create mode 100644 optimizer/sum_array.go create mode 100644 optimizer/sum_array_test.go diff --git a/expr_test.go b/expr_test.go index ac8eecf4..38b97eaa 100644 --- a/expr_test.go +++ b/expr_test.go @@ -901,147 +901,6 @@ func TestExpr(t *testing.T) { `all(1..3, {# > 0})`, true, }, - { - `all(1..3, {# > 0}) && all(1..3, {# < 4})`, - true, - }, - { - `all(1..3, {# > 2}) && all(1..3, {# < 4})`, - false, - }, - { - `all(1..3, {# > 0}) && all(1..3, {# < 2})`, - false, - }, - { - `all(1..3, {# > 2}) && all(1..3, {# < 2})`, - false, - }, - { - `all(1..3, {# > 0}) || all(1..3, {# < 4})`, - true, - }, - { - `all(1..3, {# > 0}) || all(1..3, {# != 2})`, - true, - }, - { - `all(1..3, {# != 3}) || all(1..3, {# < 4})`, - true, - }, - { - `all(1..3, {# != 3}) || all(1..3, {# != 2})`, - false, - }, - { - `none(1..3, {# == 0})`, - true, - }, - { - `none(1..3, {# == 0}) && none(1..3, {# == 4})`, - true, - }, - { - `none(1..3, {# == 0}) && none(1..3, {# == 3})`, - false, - }, - { - `none(1..3, {# == 1}) && none(1..3, {# == 4})`, - false, - }, - { - `none(1..3, {# == 1}) && none(1..3, {# == 3})`, - false, - }, - { - `none(1..3, {# == 0}) || none(1..3, {# == 4})`, - true, - }, - { - `none(1..3, {# == 0}) || none(1..3, {# == 3})`, - true, - }, - { - `none(1..3, {# == 1}) || none(1..3, {# == 4})`, - true, - }, - { - `none(1..3, {# == 1}) || none(1..3, {# == 3})`, - false, - }, - { - `any([1,1,0,1], {# == 0})`, - true, - }, - { - `any(1..3, {# == 1}) && any(1..3, {# == 2})`, - true, - }, - { - `any(1..3, {# == 0}) && any(1..3, {# == 2})`, - false, - }, - { - `any(1..3, {# == 1}) && any(1..3, {# == 4})`, - false, - }, - { - `any(1..3, {# == 0}) && any(1..3, {# == 4})`, - false, - }, - { - `any(1..3, {# == 1}) || any(1..3, {# == 2})`, - true, - }, - { - `any(1..3, {# == 0}) || any(1..3, {# == 2})`, - true, - }, - { - `any(1..3, {# == 1}) || any(1..3, {# == 4})`, - true, - }, - { - `any(1..3, {# == 0}) || any(1..3, {# == 4})`, - false, - }, - { - `one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`, - true, - }, - { - `one(1..3, {# == 1}) and one(1..3, {# == 2})`, - true, - }, - { - `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, - false, - }, - { - `one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, - false, - }, - { - `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, - false, - }, - { - `one(1..3, {# == 1}) or one(1..3, {# == 2})`, - true, - }, - { - `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, - true, - }, - { - `one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, - true, - }, - { - `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, - false, - }, - { `count(1..30, {# % 3 == 0})`, 10, diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index edb9c14f..4ceb3fa4 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -37,6 +37,7 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLast{}) Walk(node, &filterFirst{}) Walk(node, &predicateCombination{}) + Walk(node, &sumArray{}) Walk(node, &sumMap{}) return nil } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index 316b1718..e59d2c65 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -17,6 +17,79 @@ import ( "github.com/expr-lang/expr/parser" ) +func TestOptimize(t *testing.T) { + env := map[string]any{ + "a": 1, + "b": 2, + "c": 3, + } + + tests := []struct { + expr string + want any + }{ + {`1 + 2`, 3}, + {`sum([])`, 0}, + {`sum([a])`, 1}, + {`sum([a, b])`, 3}, + {`sum([a, b, c])`, 6}, + {`sum([a, b, c, 4])`, 10}, + {`all(1..3, {# > 0}) && all(1..3, {# < 4})`, true}, + {`all(1..3, {# > 2}) && all(1..3, {# < 4})`, false}, + {`all(1..3, {# > 0}) && all(1..3, {# < 2})`, false}, + {`all(1..3, {# > 2}) && all(1..3, {# < 2})`, false}, + {`all(1..3, {# > 0}) || all(1..3, {# < 4})`, true}, + {`all(1..3, {# > 0}) || all(1..3, {# != 2})`, true}, + {`all(1..3, {# != 3}) || all(1..3, {# < 4})`, true}, + {`all(1..3, {# != 3}) || all(1..3, {# != 2})`, false}, + {`none(1..3, {# == 0})`, true}, + {`none(1..3, {# == 0}) && none(1..3, {# == 4})`, true}, + {`none(1..3, {# == 0}) && none(1..3, {# == 3})`, false}, + {`none(1..3, {# == 1}) && none(1..3, {# == 4})`, false}, + {`none(1..3, {# == 1}) && none(1..3, {# == 3})`, false}, + {`none(1..3, {# == 0}) || none(1..3, {# == 4})`, true}, + {`none(1..3, {# == 0}) || none(1..3, {# == 3})`, true}, + {`none(1..3, {# == 1}) || none(1..3, {# == 4})`, true}, + {`none(1..3, {# == 1}) || none(1..3, {# == 3})`, false}, + {`any([1, 1, 0, 1], {# == 0})`, true}, + {`any(1..3, {# == 1}) && any(1..3, {# == 2})`, true}, + {`any(1..3, {# == 0}) && any(1..3, {# == 2})`, false}, + {`any(1..3, {# == 1}) && any(1..3, {# == 4})`, false}, + {`any(1..3, {# == 0}) && any(1..3, {# == 4})`, false}, + {`any(1..3, {# == 1}) || any(1..3, {# == 2})`, true}, + {`any(1..3, {# == 0}) || any(1..3, {# == 2})`, true}, + {`any(1..3, {# == 1}) || any(1..3, {# == 4})`, true}, + {`any(1..3, {# == 0}) || any(1..3, {# == 4})`, false}, + {`one([1, 1, 0, 1], {# == 0}) and not one([1, 0, 0, 1], {# == 0})`, true}, + {`one(1..3, {# == 1}) and one(1..3, {# == 2})`, true}, + {`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, false}, + {`one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, false}, + {`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, false}, + {`one(1..3, {# == 1}) or one(1..3, {# == 2})`, true}, + {`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, true}, + {`one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, true}, + {`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, false}, + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + + unoptimizedProgram, err := expr.Compile(tt.expr, expr.Env(env), expr.Optimize(false)) + require.NoError(t, err) + + unoptimizedOutput, err := expr.Run(unoptimizedProgram, env) + require.NoError(t, err) + assert.Equal(t, tt.want, unoptimizedOutput) + }) + } +} + func TestOptimize_constant_folding(t *testing.T) { tree, err := parser.Parse(`[1,2,3][5*5-25]`) require.NoError(t, err) diff --git a/optimizer/sum_array.go b/optimizer/sum_array.go new file mode 100644 index 00000000..0a05d1f2 --- /dev/null +++ b/optimizer/sum_array.go @@ -0,0 +1,37 @@ +package optimizer + +import ( + "fmt" + + . "github.com/expr-lang/expr/ast" +) + +type sumArray struct{} + +func (*sumArray) Visit(node *Node) { + if sumBuiltin, ok := (*node).(*BuiltinNode); ok && + sumBuiltin.Name == "sum" && + len(sumBuiltin.Arguments) == 1 { + if array, ok := sumBuiltin.Arguments[0].(*ArrayNode); ok && + len(array.Nodes) >= 2 { + Patch(node, sumArrayFold(array)) + } + } +} + +func sumArrayFold(array *ArrayNode) *BinaryNode { + if len(array.Nodes) > 2 { + return &BinaryNode{ + Operator: "+", + Left: array.Nodes[0], + Right: sumArrayFold(&ArrayNode{Nodes: array.Nodes[1:]}), + } + } else if len(array.Nodes) == 2 { + return &BinaryNode{ + Operator: "+", + Left: array.Nodes[0], + Right: array.Nodes[1], + } + } + panic(fmt.Errorf("sumArrayFold: invalid array length %d", len(array.Nodes))) +} diff --git a/optimizer/sum_array_test.go b/optimizer/sum_array_test.go new file mode 100644 index 00000000..e3dc8341 --- /dev/null +++ b/optimizer/sum_array_test.go @@ -0,0 +1,73 @@ +package optimizer_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/optimizer" + "github.com/expr-lang/expr/parser" + "github.com/expr-lang/expr/vm" +) + +func BenchmarkSumArray(b *testing.B) { + env := map[string]any{ + "a": 1, + "b": 2, + "c": 3, + "d": 4, + } + + program, err := expr.Compile(`sum([a, b, c, d])`, expr.Env(env)) + require.NoError(b, err) + + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, err = vm.Run(program, env) + } + b.StopTimer() + + require.NoError(b, err) + require.Equal(b, 10, out) + +} + +func TestOptimize_sum_array(t *testing.T) { + tree, err := parser.Parse(`sum([a, b])`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BinaryNode{ + Operator: "+", + Left: &ast.IdentifierNode{Value: "a"}, + Right: &ast.IdentifierNode{Value: "b"}, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_sum_array_3(t *testing.T) { + tree, err := parser.Parse(`sum([a, b, c])`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BinaryNode{ + Operator: "+", + Left: &ast.IdentifierNode{Value: "a"}, + Right: &ast.BinaryNode{ + Operator: "+", + Left: &ast.IdentifierNode{Value: "b"}, + Right: &ast.IdentifierNode{Value: "c"}, + }, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +}