diff --git a/compiler/compiler.go b/compiler/compiler.go index 05a92f2f..1aa5ce18 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -92,6 +92,13 @@ type scope struct { index int } +func (c *compiler) nodeParent() ast.Node { + if len(c.nodes) > 1 { + return c.nodes[len(c.nodes)-2] + } + return nil +} + func (c *compiler) emitLocation(loc file.Location, op Opcode, arg int) int { c.bytecode = append(c.bytecode, op) current := len(c.bytecode) @@ -594,9 +601,21 @@ func isSimpleType(node ast.Node) bool { func (c *compiler) ChainNode(node *ast.ChainNode) { c.chains = append(c.chains, []int{}) c.compile(node.Node) - // Chain activate (got nit somewhere) for _, ph := range c.chains[len(c.chains)-1] { - c.patchJump(ph) + c.patchJump(ph) // If chain activated jump here (got nit somewhere). + } + parent := c.nodeParent() + if binary, ok := parent.(*ast.BinaryNode); ok && binary.Operator == "??" { + // If chain is used in nil coalescing operator, we can omit + // nil push at the end of the chain. The ?? operator will + // handle it. + } else { + // We need to put the nil on the stack, otherwise "typed" + // nil will be used as a result of the chain. + j := c.emit(OpJumpIfNotNil, placeholder) + c.emit(OpPop) + c.emit(OpNil) + c.patchJump(j) } c.chains = c.chains[:len(c.chains)-1] } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index fbd83ec8..6fe58e93 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -198,8 +198,11 @@ func TestCompile(t *testing.T) { vm.OpLoadField, vm.OpJumpIfNil, vm.OpFetchField, + vm.OpJumpIfNotNil, + vm.OpPop, + vm.OpNil, }, - Arguments: []int{0, 1, 1}, + Arguments: []int{0, 1, 1, 2, 0, 0}, }, }, { @@ -219,8 +222,60 @@ func TestCompile(t *testing.T) { vm.OpLoadField, vm.OpJumpIfNil, vm.OpFetchField, + vm.OpJumpIfNotNil, + vm.OpPop, + vm.OpNil, + }, + Arguments: []int{0, 1, 1, 2, 0, 0}, + }, + }, + { + `A?.B`, + vm.Program{ + Constants: []any{ + &runtime.Field{ + Index: []int{0}, + Path: []string{"A"}, + }, + &runtime.Field{ + Index: []int{1}, + Path: []string{"B"}, + }, + }, + Bytecode: []vm.Opcode{ + vm.OpLoadField, + vm.OpJumpIfNil, + vm.OpFetchField, + vm.OpJumpIfNotNil, + vm.OpPop, + vm.OpNil, }, - Arguments: []int{0, 1, 1}, + Arguments: []int{0, 1, 1, 2, 0, 0}, + }, + }, + { + `A?.B ?? 42`, + vm.Program{ + Constants: []any{ + &runtime.Field{ + Index: []int{0}, + Path: []string{"A"}, + }, + &runtime.Field{ + Index: []int{1}, + Path: []string{"B"}, + }, + 42, + }, + Bytecode: []vm.Opcode{ + vm.OpLoadField, + vm.OpJumpIfNil, + vm.OpFetchField, + vm.OpJumpIfNotNil, + vm.OpPop, + vm.OpPush, + }, + Arguments: []int{0, 1, 1, 2, 0, 2}, }, }, { @@ -392,187 +447,176 @@ func TestCompile_optimizes_jumps(t *testing.T) { "c": true, "d": true, } - type op struct { - Bytecode vm.Opcode - Arg int - } tests := []struct { code string - want []op + want string }{ { `let foo = true; let bar = false; let baz = true; foo || bar || baz`, - []op{ - {vm.OpTrue, 0}, - {vm.OpStore, 0}, - {vm.OpFalse, 0}, - {vm.OpStore, 1}, - {vm.OpTrue, 0}, - {vm.OpStore, 2}, - {vm.OpLoadVar, 0}, - {vm.OpJumpIfTrue, 5}, - {vm.OpPop, 0}, - {vm.OpLoadVar, 1}, - {vm.OpJumpIfTrue, 2}, - {vm.OpPop, 0}, - {vm.OpLoadVar, 2}, - }, + `0 OpTrue +1 OpStore <0> foo +2 OpFalse +3 OpStore <1> bar +4 OpTrue +5 OpStore <2> baz +6 OpLoadVar <0> foo +7 OpJumpIfTrue <5> (13) +8 OpPop +9 OpLoadVar <1> bar +10 OpJumpIfTrue <2> (13) +11 OpPop +12 OpLoadVar <2> baz +`, }, { `a && b && c`, - []op{ - {vm.OpLoadFast, 0}, - {vm.OpJumpIfFalse, 5}, - {vm.OpPop, 0}, - {vm.OpLoadFast, 1}, - {vm.OpJumpIfFalse, 2}, - {vm.OpPop, 0}, - {vm.OpLoadFast, 2}, - }, + `0 OpLoadFast <0> a +1 OpJumpIfFalse <5> (7) +2 OpPop +3 OpLoadFast <1> b +4 OpJumpIfFalse <2> (7) +5 OpPop +6 OpLoadFast <2> c +`, }, { `a && b || c && d`, - []op{ - {vm.OpLoadFast, 0}, - {vm.OpJumpIfFalse, 2}, - {vm.OpPop, 0}, - {vm.OpLoadFast, 1}, - {vm.OpJumpIfTrue, 5}, - {vm.OpPop, 0}, - {vm.OpLoadFast, 2}, - {vm.OpJumpIfFalse, 2}, - {vm.OpPop, 0}, - {vm.OpLoadFast, 3}, - }, + `0 OpLoadFast <0> a +1 OpJumpIfFalse <2> (4) +2 OpPop +3 OpLoadFast <1> b +4 OpJumpIfTrue <5> (10) +5 OpPop +6 OpLoadFast <2> c +7 OpJumpIfFalse <2> (10) +8 OpPop +9 OpLoadFast <3> d +`, }, { `filter([1, 2, 3, 4, 5], # > 3 && # != 4 && # != 5)`, - []op{ - {vm.OpPush, 0}, - {vm.OpBegin, 0}, - {vm.OpJumpIfEnd, 26}, - {vm.OpPointer, 0}, - {vm.OpDeref, 0}, - {vm.OpPush, 1}, - {vm.OpMore, 0}, - {vm.OpJumpIfFalse, 18}, - {vm.OpPop, 0}, - {vm.OpPointer, 0}, - {vm.OpDeref, 0}, - {vm.OpPush, 2}, - {vm.OpEqual, 0}, - {vm.OpNot, 0}, - {vm.OpJumpIfFalse, 11}, - {vm.OpPop, 0}, - {vm.OpPointer, 0}, - {vm.OpDeref, 0}, - {vm.OpPush, 3}, - {vm.OpEqual, 0}, - {vm.OpNot, 0}, - {vm.OpJumpIfFalse, 4}, - {vm.OpPop, 0}, - {vm.OpIncrementCount, 0}, - {vm.OpPointer, 0}, - {vm.OpJump, 1}, - {vm.OpPop, 0}, - {vm.OpIncrementIndex, 0}, - {vm.OpJumpBackward, 27}, - {vm.OpGetCount, 0}, - {vm.OpEnd, 0}, - {vm.OpArray, 0}, - }, + `0 OpPush <0> [1 2 3 4 5] +1 OpBegin +2 OpJumpIfEnd <26> (29) +3 OpPointer +4 OpDeref +5 OpPush <1> 3 +6 OpMore +7 OpJumpIfFalse <18> (26) +8 OpPop +9 OpPointer +10 OpDeref +11 OpPush <2> 4 +12 OpEqual +13 OpNot +14 OpJumpIfFalse <11> (26) +15 OpPop +16 OpPointer +17 OpDeref +18 OpPush <3> 5 +19 OpEqual +20 OpNot +21 OpJumpIfFalse <4> (26) +22 OpPop +23 OpIncrementCount +24 OpPointer +25 OpJump <1> (27) +26 OpPop +27 OpIncrementIndex +28 OpJumpBackward <27> (2) +29 OpGetCount +30 OpEnd +31 OpArray +`, }, { `let foo = true; let bar = false; let baz = true; foo && bar || baz`, - []op{ - {vm.OpTrue, 0}, - {vm.OpStore, 0}, - {vm.OpFalse, 0}, - {vm.OpStore, 1}, - {vm.OpTrue, 0}, - {vm.OpStore, 2}, - {vm.OpLoadVar, 0}, - {vm.OpJumpIfFalse, 2}, - {vm.OpPop, 0}, - {vm.OpLoadVar, 1}, - {vm.OpJumpIfTrue, 2}, - {vm.OpPop, 0}, - {vm.OpLoadVar, 2}, - }, + `0 OpTrue +1 OpStore <0> foo +2 OpFalse +3 OpStore <1> bar +4 OpTrue +5 OpStore <2> baz +6 OpLoadVar <0> foo +7 OpJumpIfFalse <2> (10) +8 OpPop +9 OpLoadVar <1> bar +10 OpJumpIfTrue <2> (13) +11 OpPop +12 OpLoadVar <2> baz +`, }, { `true ?? nil ?? nil ?? nil`, - []op{ - {vm.OpTrue, 0}, - {vm.OpJumpIfNotNil, 8}, - {vm.OpPop, 0}, - {vm.OpNil, 0}, - {vm.OpJumpIfNotNil, 5}, - {vm.OpPop, 0}, - {vm.OpNil, 0}, - {vm.OpJumpIfNotNil, 2}, - {vm.OpPop, 0}, - {vm.OpNil, 0}, - }, + `0 OpTrue +1 OpJumpIfNotNil <8> (10) +2 OpPop +3 OpNil +4 OpJumpIfNotNil <5> (10) +5 OpPop +6 OpNil +7 OpJumpIfNotNil <2> (10) +8 OpPop +9 OpNil +`, }, { `let m = {"a": {"b": {"c": 1}}}; m?.a?.b?.c`, - []op{ - {vm.OpPush, 0}, - {vm.OpPush, 1}, - {vm.OpPush, 2}, - {vm.OpPush, 3}, - {vm.OpPush, 3}, - {vm.OpMap, 0}, - {vm.OpPush, 3}, - {vm.OpMap, 0}, - {vm.OpPush, 3}, - {vm.OpMap, 0}, - {vm.OpStore, 0}, - {vm.OpLoadVar, 0}, - {vm.OpJumpIfNil, 8}, - {vm.OpPush, 0}, - {vm.OpFetch, 0}, - {vm.OpJumpIfNil, 5}, - {vm.OpPush, 1}, - {vm.OpFetch, 0}, - {vm.OpJumpIfNil, 2}, - {vm.OpPush, 2}, - {vm.OpFetch, 0}, - }, + `0 OpPush <0> a +1 OpPush <1> b +2 OpPush <2> c +3 OpPush <3> 1 +4 OpPush <3> 1 +5 OpMap +6 OpPush <3> 1 +7 OpMap +8 OpPush <3> 1 +9 OpMap +10 OpStore <0> m +11 OpLoadVar <0> m +12 OpJumpIfNil <8> (21) +13 OpPush <0> a +14 OpFetch +15 OpJumpIfNil <5> (21) +16 OpPush <1> b +17 OpFetch +18 OpJumpIfNil <2> (21) +19 OpPush <2> c +20 OpFetch +21 OpJumpIfNotNil <2> (24) +22 OpPop +23 OpNil +`, }, { `-1 not in [1, 2, 5]`, - []op{ - {vm.OpPush, 0}, - {vm.OpPush, 1}, - {vm.OpIn, 0}, - {vm.OpNot, 0}, - }, + `0 OpPush <0> -1 +1 OpPush <1> map[1:{} 2:{} 5:{}] +2 OpIn +3 OpNot +`, }, { `1 + 8 not in [1, 2, 5]`, - []op{ - {vm.OpPush, 0}, - {vm.OpPush, 1}, - {vm.OpIn, 0}, - {vm.OpNot, 0}, - }, + `0 OpPush <0> 9 +1 OpPush <1> map[1:{} 2:{} 5:{}] +2 OpIn +3 OpNot +`, }, { `true ? false : 8 not in [1, 2, 5]`, - []op{ - {vm.OpTrue, 0}, - {vm.OpJumpIfFalse, 3}, - {vm.OpPop, 0}, - {vm.OpFalse, 0}, - {vm.OpJump, 5}, - {vm.OpPop, 0}, - {vm.OpPush, 0}, - {vm.OpPush, 1}, - {vm.OpIn, 0}, - {vm.OpNot, 0}, - }, + `0 OpTrue +1 OpJumpIfFalse <3> (5) +2 OpPop +3 OpFalse +4 OpJump <5> (10) +5 OpPop +6 OpPush <0> 8 +7 OpPush <1> map[1:{} 2:{} 5:{}] +8 OpIn +9 OpNot +`, }, } @@ -580,12 +624,7 @@ func TestCompile_optimizes_jumps(t *testing.T) { t.Run(test.code, func(t *testing.T) { program, err := expr.Compile(test.code, expr.Env(env)) require.NoError(t, err) - - require.Equal(t, len(test.want), len(program.Bytecode)) - for i, op := range test.want { - require.Equal(t, op.Bytecode, program.Bytecode[i]) - require.Equalf(t, op.Arg, program.Arguments[i], "at %d", i) - } + require.Equal(t, test.want, program.Disassemble()) }) } } diff --git a/expr_test.go b/expr_test.go index 23f4c496..df4eba50 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2602,3 +2602,20 @@ func TestArrayComparison(t *testing.T) { }) } } + +func TestIssue_570(t *testing.T) { + type Student struct { + Name string + } + + env := map[string]any{ + "student": (*Student)(nil), + } + + program, err := expr.Compile("student?.Name", expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.IsType(t, nil, out) +} diff --git a/testdata/examples.txt b/testdata/examples.txt index 5b9d2cdd..b02094a7 100644 --- a/testdata/examples.txt +++ b/testdata/examples.txt @@ -3070,11 +3070,8 @@ all(array, not true) all(array, ok) all(false ? add : "bar", ok) all(filter(array, true), ok) -all(groupBy(array, #)?.f32, #) all(groupBy(array, f32).score, 0.5 in #?.f64) -all(groupBy(array, foo)?.i64, none(.div, true == #)) all(groupBy(list, #).i, #?.half() not in # + #) -all(groupBy(list, #)?.add, #?.f64.i64(f32 startsWith #, div)) all(i32 .. 1, ok) all(list, !false) all(list, "bar" in #) @@ -3242,7 +3239,6 @@ any(array, one(array, ok)) any(array, score != add) any(array, score == half) any(array, score == nil) -any(groupBy(list, #)?.Bar, #) any(i32 .. i, half == half) any(list, !ok) any(list, !true) @@ -3992,9 +3988,7 @@ count(array, reduce(list, false)) count(array, true == ok) count(array, true) / f64 count(filter(list, true), ok) -count(groupBy(array, #)?.String, #) count(groupBy(list, #).String, .div?.array) -count(groupBy(list, #)?.list, #?.Qux) count(i .. 1, # == i) count(i .. i, ok) count(i32 .. i, # >= #) @@ -5652,12 +5646,7 @@ filter(array, true == true) filter(array, true and true) filter(false ? ok : array, # == nil) filter(filter(list, true), ok) -filter(groupBy(array, #).greet, # matches #?.score) -filter(groupBy(array, #)?.div, .Qux(#).array) -filter(groupBy(array, f64)?.f32, #) filter(groupBy(list, #).f64, # not matches #?.foo) -filter(groupBy(list, #)?.f64, .foo) -filter(groupBy(list, 0.5)?.i, #) filter(i .. i64, ok) filter(list, !ok) filter(list, !true) @@ -6033,13 +6022,9 @@ findIndex(filter(array, false), # > 1) findIndex(filter(array, true), # not in array) findIndex(filter(list, false), 1 <= i64) findIndex(groupBy(array, #).f64, .Bar(half(add, half, div, greet, f64))) -findIndex(groupBy(array, #)?.greet, .f64(foo(array, #), # != "bar")) findIndex(groupBy(array, false).String, #?.i64()) findIndex(groupBy(array, foo).String, #) findIndex(groupBy(list, #).Qux, #) -findIndex(groupBy(list, #)?.Qux, #) -findIndex(groupBy(list, #)?.i64, #) -findIndex(groupBy(list, true)?.list, get(false ? # : #, score(#, Bar, half, #))) findIndex(i32 .. 1, # >= #) findIndex(list, !false) findIndex(list, !ok) @@ -6198,13 +6183,9 @@ findLast(array, score != div) findLast(array, score != nil) findLast(array, true) % i64 findLast(groupBy(array, #).i32, #) -findLast(groupBy(array, f64)?.score, #) findLast(groupBy(list, "bar").div, #?.f32?.half()) findLast(groupBy(list, #)[ok], .ok) -findLast(groupBy(list, 0.5)?.String, #?.half(Bar)) -findLast(groupBy(list, 0.5)?.f32, ok) findLast(groupBy(list, f64).foo, .f32(#, #)?.Qux(#?.i64(add, nil))) -findLast(groupBy(list, f64)?.Bar, #) findLast(i .. 1, # < #) findLast(i .. i32, # != i64) findLast(i .. i32, ok) @@ -6379,12 +6360,7 @@ findLastIndex(array, true == false) findLastIndex(array, true) / f32 findLastIndex(array, true) not in array findLastIndex(filter(array, false), ok) -findLastIndex(groupBy(array, "foo")?.add, #[list]) -findLastIndex(groupBy(array, #)?.foo, #) -findLastIndex(groupBy(array, #)?.i32, not #) -findLastIndex(groupBy(array, #)?.score, #) findLastIndex(groupBy(list, #).ok, #) -findLastIndex(groupBy(list, #)?.String, # + #.add) findLastIndex(groupBy(list, 0.5).Bar, #?.String endsWith .f32(list)) findLastIndex(i32 .. i32, ok) findLastIndex(i64 .. 1, ok) @@ -7208,7 +7184,7 @@ fromJSON(toJSON(nil)) fromJSON(toJSON(true)) fromPairs(filter(array, false)) fromPairs(filter(list, false)) -fromPairs(groupBy(array, #)?.String) +fromPairs(groupBy(array, #).String) get(["foo"], i32) get(array, -1) get(array, -i32) @@ -8388,8 +8364,6 @@ groupBy(filter(array, true), f64) groupBy(filter(list, ok), #) groupBy(groupBy(array, "foo").list, #) groupBy(groupBy(array, #).i32, add) -groupBy(groupBy(array, 0.5)?.div, get(.foo(add, false), none(#, foo))) -groupBy(groupBy(array, f64)?.ok, .Bar()?.Bar) groupBy(groupBy(array, ok).Bar, #?.array not matches #) groupBy(groupBy(array, true).f64, #[#]) groupBy(i .. 1, !ok)?.f64 @@ -12816,8 +12790,6 @@ map(groupBy(array, #).String, i32) map(groupBy(array, #).greet, foo.Qux(.f32)) map(groupBy(array, #).greet, score) map(groupBy(array, #).score, #?.list()) -map(groupBy(array, 0.5)?.i64, .Qux()) -map(groupBy(array, f64)?.array, first(.div())) map(groupBy(list, i32).i, #) map(i .. 1, -#) map(i .. 1, 0.5 ^ #) @@ -13465,7 +13437,6 @@ mean(array) >= i64 mean(array) ^ i mean(array) ^ i32 mean(filter(array, true)) -mean(groupBy(array, #)?.half) mean(groupBy(array, i64).score) mean(i .. 1) mean(i .. i) @@ -13981,9 +13952,6 @@ none(array, ok) none(array, reduce(array, true)) none(filter(array, # >= f64), ok) none(groupBy(array, #).f64, # or #.Bar) -none(groupBy(array, #)?.array, fromJSON(#)) -none(groupBy(array, #)?.f32, #) -none(groupBy(array, #)?.greet, #) none(groupBy(list, #).div, ok) none(groupBy(list, #).greet, #) none(groupBy(list, 0.5).i, ok) @@ -15298,10 +15266,8 @@ one(array, true == nil) one(array, true) ? greet : add one(array, true) or 1 != nil one(false ? greet : "bar", f64 != #) -one(groupBy(array, 0.5)?.String, #) one(groupBy(array, f64).array, .i(nil).ok) one(groupBy(list, ok).foo, .add?.array) -one(groupBy(list, ok)?.score, .i64()) one(i32 .. 1, # == #) one(i32 .. i32, not ok) one(i64 .. 1, # >= 1) @@ -17297,7 +17263,6 @@ sum(array) ^ f64 sum(array) not in array sum(filter(array, ok)) sum(groupBy(array, i32).String) -sum(groupBy(list, #)?.greet) sum(i32 .. 1) sum(i64 .. i32) sum(i64 .. i64)