Skip to content

Commit

Permalink
Merge pull request #328 from tonyhb/perf/precompile-lua-functions
Browse files Browse the repository at this point in the history
Precompile lua functions to prevent wasted CPU cycles
  • Loading branch information
alicebob committed May 20, 2023
2 parents b2ccd76 + 536e717 commit a3211de
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions cmd_scripting.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"strconv"
"strings"
"sync"

luajson "github.com/alicebob/gopher-json"
lua "github.com/yuin/gopher-lua"
Expand All @@ -21,6 +22,10 @@ func commandsScripting(m *Miniredis) {
m.srv.Register("SCRIPT", m.cmdScript)
}

var (
parsedScripts = sync.Map{}
)

// Execute lua. Needs to run m.Lock()ed, from within withTx().
// Returns true if the lua was OK (and hence should be cached).
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
Expand Down Expand Up @@ -91,20 +96,56 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri
return 1
}))

l.DoString(protectGlobals)
_ = doScript(l, protectGlobals)

l.Push(lua.LString("redis"))
l.Call(1, 0)

if err := l.DoString(script); err != nil {
c.WriteError(errLuaParseError(err))
if err := doScript(l, script); err != nil {
c.WriteError(err.Error())
return false
}

luaToRedis(l, c, l.Get(1))
return true
}

// doScript pre-compiiles the given script into a Lua prototype,
// then executes the pre-compiled function against the given lua state.
//
// This is thread-safe.
func doScript(l *lua.LState, script string) error {
proto, err := compile(script)
if err != nil {
return fmt.Errorf(errLuaParseError(err))
}

lfunc := l.NewFunctionFromProto(proto)
l.Push(lfunc)
if err := l.PCall(0, lua.MultRet, nil); err != nil {
// ensure we wrap with the correct format.
return fmt.Errorf(errLuaParseError(err))
}

return nil
}

func compile(script string) (*lua.FunctionProto, error) {
if val, ok := parsedScripts.Load(script); ok {
return val.(*lua.FunctionProto), nil
}
chunk, err := parse.Parse(strings.NewReader(script), "<string>")
if err != nil {
return nil, err
}
proto, err := lua.Compile(chunk, "")
if err != nil {
return nil, err
}
parsedScripts.Store(script, proto)
return proto, nil
}

func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
Expand Down

0 comments on commit a3211de

Please sign in to comment.