Skip to content

Commit

Permalink
Merge pull request #62 from Songmu/preprepare
Browse files Browse the repository at this point in the history
add PrePrepare, Prepare and PostPrepare to hook prepare
  • Loading branch information
shogo82148 committed Jan 11, 2021
2 parents b039787 + 61cb1ab commit a60ed77
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 10 deletions.
34 changes: 24 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,44 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) {

// PrepareContext returns a prepared statement which is wrapped by Stmt.
func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, error) {
var stmt driver.Stmt
var ctx interface{}
var stmt = &Stmt{
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}
var err error
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { hooks.postPrepare(c, ctx, stmt, err) }()
if ctx, err = hooks.prePrepare(c, stmt); err != nil {
return nil, err
}
}

if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok {
stmt, err = connCtx.PrepareContext(c, query)
stmt.Stmt, err = connCtx.PrepareContext(c, stmt.QueryString)
} else {
stmt, err = conn.Conn.Prepare(query)
stmt.Stmt, err = conn.Conn.Prepare(stmt.QueryString)
if err == nil {
select {
default:
case <-c.Done():
stmt.Close()
stmt.Stmt.Close()
return nil, c.Err()
}
}
}
if err != nil {
return nil, err
}
return &Stmt{
Stmt: stmt,
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}, nil

if hooks != nil {
if err = hooks.prepare(c, ctx, stmt); err != nil {
return nil, err
}
}
return stmt, nil
}

// Close calls the original Close method.
Expand Down
84 changes: 84 additions & 0 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type hooks interface {
preOpen(c context.Context, name string) (interface{}, error)
open(c context.Context, ctx interface{}, conn *Conn) error
postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error
prePrepare(c context.Context, stmt *Stmt) (interface{}, error)
prepare(c context.Context, ctx interface{}, stmt *Stmt) error
postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error
preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error)
exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error
postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error
Expand Down Expand Up @@ -109,6 +112,36 @@ type HooksContext struct {
// `Hooks.PreOpen` method, and may be nil.
PostOpen func(c context.Context, ctx interface{}, conn *Conn, err error) error

// PrePrepare is a callback that gets called prior to calling
// `db.Prepare`, and is ALWAYS called. If this callback returns an
// error, the underlying driver's `db.Exec` and `Hooks.Prepare` methods
// are not called.
//
// The first return value is passed to both `Hooks.Prepare` and
// `Hooks.PostPrepare` callbacks. You may specify anything you want.
// Return nil if you do not need to use it.
//
// The second return value is indicates the error found while
// executing this hook.
PrePrepare func(c context.Context, stmt *Stmt) (interface{}, error)

// Prepare is called after the underlying driver's `db.Prepare` method
// returns without any errors.
//
// The `ctx` parameter is the return value supplied from the
// `Hooks.PrePrepare` method, and may be nil.
//
// If this callback returns an error, then the error from this
// callback is returned by the `db.Prepare` method.
Prepare func(c context.Context, ctx interface{}, stmt *Stmt) error

// PostPrepare is a callback that gets called at the end of
// the call to `db.Prepare`. It is ALWAYS called.
//
// The `ctx` parameter is the return value supplied from the
// `Hooks.PrePrepare` method, and may be nil.
PostPrepare func(c context.Context, ctx interface{}, stmt *Stmt, err error) error

// PreExec is a callback that gets called prior to calling
// `Stmt.Exec`, and is ALWAYS called. If this callback returns an
// error, the underlying driver's `Stmt.Exec` and `Hooks.Exec` methods
Expand Down Expand Up @@ -405,6 +438,27 @@ func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn,
return h.PostOpen(c, ctx, conn, err)
}

func (h *HooksContext) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
if h == nil || h.PrePrepare == nil {
return nil, nil
}
return h.PrePrepare(c, stmt)
}

func (h *HooksContext) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
if h == nil || h.Prepare == nil {
return nil
}
return h.Prepare(c, ctx, stmt)
}

func (h *HooksContext) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
if h == nil || h.PostPrepare == nil {
return nil
}
return h.PostPrepare(c, ctx, stmt, err)
}

func (h *HooksContext) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
if h == nil || h.PreExec == nil {
return nil, nil
Expand Down Expand Up @@ -929,6 +983,18 @@ func (h *Hooks) postOpen(c context.Context, ctx interface{}, conn *Conn, err err
return h.PostOpen(ctx, conn)
}

func (h *Hooks) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
return nil, nil
}

func (h *Hooks) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
return nil
}

func (h *Hooks) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
return nil
}

func (h *Hooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
if h == nil || h.PreExec == nil {
return nil, nil
Expand Down Expand Up @@ -1187,6 +1253,24 @@ func (h multipleHooks) postOpen(c context.Context, ctx interface{}, conn *Conn,
})
}

func (h multipleHooks) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
return h.preDo(func(h hooks) (interface{}, error) {
return h.prePrepare(c, stmt)
})
}

func (h multipleHooks) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
return h.do(ctx, func(h hooks, ctx interface{}) error {
return h.prepare(c, ctx, stmt)
})
}

func (h multipleHooks) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error {
return h.postPrepare(c, ctx, stmt, err)
})
}

func (h multipleHooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
return h.preDo(func(h hooks) (interface{}, error) {
return h.preExec(c, stmt, args)
Expand Down
21 changes: 21 additions & 0 deletions logging_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,27 @@ func (h *loggingHook) postOpen(c context.Context, ctx interface{}, conn *Conn, e
return nil
}

func (h *loggingHook) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PrePrepare]")
return nil, nil
}

func (h *loggingHook) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[Prepare]")
return nil
}

func (h *loggingHook) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostPrepare]")
return nil
}

func (h *loggingHook) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
h.mu.Lock()
defer h.mu.Unlock()
Expand Down
9 changes: 9 additions & 0 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func TestFakeDB(t *testing.T) {
Name: "execAll",
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreExec]\n[Exec]\n[PostExec]\n" +
"[PreClose]\n[Close]\n[PostClose]\n",
f: func(db *sql.DB) error {
Expand All @@ -49,6 +50,7 @@ func TestFakeDB(t *testing.T) {
FailExec: true,
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreExec]\n[PostExec]\n" +
"[PreClose]\n[Close]\n[PostClose]\n",
f: func(db *sql.DB) error {
Expand All @@ -64,6 +66,7 @@ func TestFakeDB(t *testing.T) {
Name: "execError-NamedValue",
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreExec]\n[PostExec]\n" +
"[PreClose]\n[Close]\n[PostClose]\n",
f: func(db *sql.DB) error {
Expand All @@ -80,6 +83,7 @@ func TestFakeDB(t *testing.T) {
Name: "queryAll",
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreQuery]\n[Query]\n[PostQuery]\n",
f: func(db *sql.DB) error {
_, err := db.Query("SELECT * FROM test WHERE id = ?", 123456789)
Expand All @@ -92,6 +96,7 @@ func TestFakeDB(t *testing.T) {
FailQuery: true,
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreQuery]\n[PostQuery]\n" +
"[PreClose]\n[Close]\n[PostClose]\n",
f: func(db *sql.DB) error {
Expand All @@ -107,6 +112,7 @@ func TestFakeDB(t *testing.T) {
Name: "prepare",
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreClose]\n[Close]\n[PostClose]\n",
f: func(db *sql.DB) error {
stmt, err := db.Prepare("SELECT * FROM test WHERE id = ?")
Expand Down Expand Up @@ -255,6 +261,7 @@ func TestFakeDB(t *testing.T) {
ConnType: "fakeConnExt",
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreExec]\n[Exec]\n[PostExec]\n" +
"[PreClose]\n[Close]\n[PostClose]\n",
f: func(db *sql.DB) error {
Expand Down Expand Up @@ -325,6 +332,7 @@ func TestFakeDB(t *testing.T) {
ConnType: "fakeConnCtx",
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreExec]\n[Exec]\n[PostExec]\n" +
"[PreClose]\n[Close]\n[PostClose]\n",
f: func(db *sql.DB) error {
Expand All @@ -343,6 +351,7 @@ func TestFakeDB(t *testing.T) {
ConnType: "fakeConnCtx",
},
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
"[PreQuery]\n[Query]\n[PostQuery]\n",
f: func(db *sql.DB) error {
stmt, err := db.Prepare("SELECT * FROM test WHERE id = ?")
Expand Down

0 comments on commit a60ed77

Please sign in to comment.