Skip to content

Commit

Permalink
Add sum([]) optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed Apr 13, 2024
1 parent edb1b5a commit c2b609e
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 141 deletions.
141 changes: 0 additions & 141 deletions expr_test.go
Expand Up @@ -901,147 +901,6 @@ func TestExpr(t *testing.T) {
`all(1..3, {# > 0})`,
true,
},
{
`all(1..3, {# > 0}) && all(1..3, {# < 4})`,
true,
},
{
`all(1..3, {# > 2}) && all(1..3, {# < 4})`,
false,
},
{
`all(1..3, {# > 0}) && all(1..3, {# < 2})`,
false,
},
{
`all(1..3, {# > 2}) && all(1..3, {# < 2})`,
false,
},
{
`all(1..3, {# > 0}) || all(1..3, {# < 4})`,
true,
},
{
`all(1..3, {# > 0}) || all(1..3, {# != 2})`,
true,
},
{
`all(1..3, {# != 3}) || all(1..3, {# < 4})`,
true,
},
{
`all(1..3, {# != 3}) || all(1..3, {# != 2})`,
false,
},
{
`none(1..3, {# == 0})`,
true,
},
{
`none(1..3, {# == 0}) && none(1..3, {# == 4})`,
true,
},
{
`none(1..3, {# == 0}) && none(1..3, {# == 3})`,
false,
},
{
`none(1..3, {# == 1}) && none(1..3, {# == 4})`,
false,
},
{
`none(1..3, {# == 1}) && none(1..3, {# == 3})`,
false,
},
{
`none(1..3, {# == 0}) || none(1..3, {# == 4})`,
true,
},
{
`none(1..3, {# == 0}) || none(1..3, {# == 3})`,
true,
},
{
`none(1..3, {# == 1}) || none(1..3, {# == 4})`,
true,
},
{
`none(1..3, {# == 1}) || none(1..3, {# == 3})`,
false,
},
{
`any([1,1,0,1], {# == 0})`,
true,
},
{
`any(1..3, {# == 1}) && any(1..3, {# == 2})`,
true,
},
{
`any(1..3, {# == 0}) && any(1..3, {# == 2})`,
false,
},
{
`any(1..3, {# == 1}) && any(1..3, {# == 4})`,
false,
},
{
`any(1..3, {# == 0}) && any(1..3, {# == 4})`,
false,
},
{
`any(1..3, {# == 1}) || any(1..3, {# == 2})`,
true,
},
{
`any(1..3, {# == 0}) || any(1..3, {# == 2})`,
true,
},
{
`any(1..3, {# == 1}) || any(1..3, {# == 4})`,
true,
},
{
`any(1..3, {# == 0}) || any(1..3, {# == 4})`,
false,
},
{
`one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`,
true,
},
{
`one(1..3, {# == 1}) and one(1..3, {# == 2})`,
true,
},
{
`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`,
false,
},
{
`one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`,
false,
},
{
`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`,
false,
},
{
`one(1..3, {# == 1}) or one(1..3, {# == 2})`,
true,
},
{
`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`,
true,
},
{
`one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`,
true,
},
{
`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`,
false,
},

{
`count(1..30, {# % 3 == 0})`,
10,
Expand Down
1 change: 1 addition & 0 deletions optimizer/optimizer.go
Expand Up @@ -37,6 +37,7 @@ func Optimize(node *Node, config *conf.Config) error {
Walk(node, &filterLast{})
Walk(node, &filterFirst{})
Walk(node, &predicateCombination{})
Walk(node, &sumArray{})
Walk(node, &sumMap{})
return nil
}
73 changes: 73 additions & 0 deletions optimizer/optimizer_test.go
Expand Up @@ -17,6 +17,79 @@ import (
"github.com/expr-lang/expr/parser"
)

func TestOptimize(t *testing.T) {
env := map[string]any{
"a": 1,
"b": 2,
"c": 3,
}

tests := []struct {
expr string
want any
}{
{`1 + 2`, 3},
{`sum([])`, 0},
{`sum([a])`, 1},
{`sum([a, b])`, 3},
{`sum([a, b, c])`, 6},
{`sum([a, b, c, 4])`, 10},
{`all(1..3, {# > 0}) && all(1..3, {# < 4})`, true},
{`all(1..3, {# > 2}) && all(1..3, {# < 4})`, false},
{`all(1..3, {# > 0}) && all(1..3, {# < 2})`, false},
{`all(1..3, {# > 2}) && all(1..3, {# < 2})`, false},
{`all(1..3, {# > 0}) || all(1..3, {# < 4})`, true},
{`all(1..3, {# > 0}) || all(1..3, {# != 2})`, true},
{`all(1..3, {# != 3}) || all(1..3, {# < 4})`, true},
{`all(1..3, {# != 3}) || all(1..3, {# != 2})`, false},
{`none(1..3, {# == 0})`, true},
{`none(1..3, {# == 0}) && none(1..3, {# == 4})`, true},
{`none(1..3, {# == 0}) && none(1..3, {# == 3})`, false},
{`none(1..3, {# == 1}) && none(1..3, {# == 4})`, false},
{`none(1..3, {# == 1}) && none(1..3, {# == 3})`, false},
{`none(1..3, {# == 0}) || none(1..3, {# == 4})`, true},
{`none(1..3, {# == 0}) || none(1..3, {# == 3})`, true},
{`none(1..3, {# == 1}) || none(1..3, {# == 4})`, true},
{`none(1..3, {# == 1}) || none(1..3, {# == 3})`, false},
{`any([1, 1, 0, 1], {# == 0})`, true},
{`any(1..3, {# == 1}) && any(1..3, {# == 2})`, true},
{`any(1..3, {# == 0}) && any(1..3, {# == 2})`, false},
{`any(1..3, {# == 1}) && any(1..3, {# == 4})`, false},
{`any(1..3, {# == 0}) && any(1..3, {# == 4})`, false},
{`any(1..3, {# == 1}) || any(1..3, {# == 2})`, true},
{`any(1..3, {# == 0}) || any(1..3, {# == 2})`, true},
{`any(1..3, {# == 1}) || any(1..3, {# == 4})`, true},
{`any(1..3, {# == 0}) || any(1..3, {# == 4})`, false},
{`one([1, 1, 0, 1], {# == 0}) and not one([1, 0, 0, 1], {# == 0})`, true},
{`one(1..3, {# == 1}) and one(1..3, {# == 2})`, true},
{`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, false},
{`one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, false},
{`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, false},
{`one(1..3, {# == 1}) or one(1..3, {# == 2})`, true},
{`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, true},
{`one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, true},
{`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, false},
}

for _, tt := range tests {
t.Run(tt.expr, func(t *testing.T) {
program, err := expr.Compile(tt.expr, expr.Env(env))
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, tt.want, output)

unoptimizedProgram, err := expr.Compile(tt.expr, expr.Env(env), expr.Optimize(false))
require.NoError(t, err)

unoptimizedOutput, err := expr.Run(unoptimizedProgram, env)
require.NoError(t, err)
assert.Equal(t, tt.want, unoptimizedOutput)
})
}
}

func TestOptimize_constant_folding(t *testing.T) {
tree, err := parser.Parse(`[1,2,3][5*5-25]`)
require.NoError(t, err)
Expand Down
37 changes: 37 additions & 0 deletions optimizer/sum_array.go
@@ -0,0 +1,37 @@
package optimizer

import (
"fmt"

. "github.com/expr-lang/expr/ast"
)

type sumArray struct{}

func (*sumArray) Visit(node *Node) {
if sumBuiltin, ok := (*node).(*BuiltinNode); ok &&
sumBuiltin.Name == "sum" &&
len(sumBuiltin.Arguments) == 1 {
if array, ok := sumBuiltin.Arguments[0].(*ArrayNode); ok &&
len(array.Nodes) >= 2 {
Patch(node, sumArrayFold(array))
}
}
}

func sumArrayFold(array *ArrayNode) *BinaryNode {
if len(array.Nodes) > 2 {
return &BinaryNode{
Operator: "+",
Left: array.Nodes[0],
Right: sumArrayFold(&ArrayNode{Nodes: array.Nodes[1:]}),
}
} else if len(array.Nodes) == 2 {
return &BinaryNode{
Operator: "+",
Left: array.Nodes[0],
Right: array.Nodes[1],
}
}
panic(fmt.Errorf("sumArrayFold: invalid array length %d", len(array.Nodes)))
}
73 changes: 73 additions & 0 deletions optimizer/sum_array_test.go
@@ -0,0 +1,73 @@
package optimizer_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/optimizer"
"github.com/expr-lang/expr/parser"
"github.com/expr-lang/expr/vm"
)

func BenchmarkSumArray(b *testing.B) {
env := map[string]any{
"a": 1,
"b": 2,
"c": 3,
"d": 4,
}

program, err := expr.Compile(`sum([a, b, c, d])`, expr.Env(env))
require.NoError(b, err)

var out any
b.ResetTimer()
for n := 0; n < b.N; n++ {
out, err = vm.Run(program, env)
}
b.StopTimer()

require.NoError(b, err)
require.Equal(b, 10, out)

}

func TestOptimize_sum_array(t *testing.T) {
tree, err := parser.Parse(`sum([a, b])`)
require.NoError(t, err)

err = optimizer.Optimize(&tree.Node, nil)
require.NoError(t, err)

expected := &ast.BinaryNode{
Operator: "+",
Left: &ast.IdentifierNode{Value: "a"},
Right: &ast.IdentifierNode{Value: "b"},
}

assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
}

func TestOptimize_sum_array_3(t *testing.T) {
tree, err := parser.Parse(`sum([a, b, c])`)
require.NoError(t, err)

err = optimizer.Optimize(&tree.Node, nil)
require.NoError(t, err)

expected := &ast.BinaryNode{
Operator: "+",
Left: &ast.IdentifierNode{Value: "a"},
Right: &ast.BinaryNode{
Operator: "+",
Left: &ast.IdentifierNode{Value: "b"},
Right: &ast.IdentifierNode{Value: "c"},
},
}

assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
}

0 comments on commit c2b609e

Please sign in to comment.