Skip to content

Commit

Permalink
Merge pull request #89 from shogo82148/revert-84-main
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo82148 committed Aug 3, 2023
2 parents 300c00d + 918628b commit 7ea01d1
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 45 deletions.
61 changes: 36 additions & 25 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ type Conn struct {
// It will trigger PrePing, Ping, PostPing hooks.
//
// If the original connection does not satisfy "database/sql/driver".Pinger, it does nothing.
func (conn *Conn) Ping(c context.Context) (err error) {
func (conn *Conn) Ping(c context.Context) error {
var err error
var ctx interface{}
hooks := conn.Proxy.getHooks(c)

if hooks != nil {
defer func() { err = hooks.postPing(c, ctx, conn, err) }()
defer func() { hooks.postPing(c, ctx, conn, err) }()
if ctx, err = hooks.prePing(c, conn); err != nil {
return err
}
Expand All @@ -48,30 +49,31 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) {
}

// PrepareContext returns a prepared statement which is wrapped by Stmt.
func (conn *Conn) PrepareContext(c context.Context, query string) (stmt driver.Stmt, err error) {
func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, error) {
var ctx interface{}
var stmtAux = &Stmt{
var stmt = &Stmt{
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}
var err error
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { err = hooks.postPrepare(c, ctx, stmtAux, err) }()
if ctx, err = hooks.prePrepare(c, stmtAux); err != nil {
defer func() { hooks.postPrepare(c, ctx, stmt, err) }()
if ctx, err = hooks.prePrepare(c, stmt); err != nil {
return nil, err
}
}

if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok {
stmtAux.Stmt, err = connCtx.PrepareContext(c, stmtAux.QueryString)
stmt.Stmt, err = connCtx.PrepareContext(c, stmt.QueryString)
} else {
stmtAux.Stmt, err = conn.Conn.Prepare(stmtAux.QueryString)
stmt.Stmt, err = conn.Conn.Prepare(stmt.QueryString)
if err == nil {
select {
default:
case <-c.Done():
stmtAux.Stmt.Close()
stmt.Stmt.Close()
return nil, c.Err()
}
}
Expand All @@ -81,20 +83,21 @@ func (conn *Conn) PrepareContext(c context.Context, query string) (stmt driver.S
}

if hooks != nil {
if err = hooks.prepare(c, ctx, stmtAux); err != nil {
if err = hooks.prepare(c, ctx, stmt); err != nil {
return nil, err
}
}
return stmtAux, nil
return stmt, nil
}

// Close calls the original Close method.
func (conn *Conn) Close() (err error) {
func (conn *Conn) Close() error {
ctx := context.Background()
var err error
var myctx interface{}

if hooks := conn.Proxy.hooks; hooks != nil {
defer func() { err = hooks.postClose(ctx, myctx, conn, err) }()
defer func() { hooks.postClose(ctx, myctx, conn, err) }()
if myctx, err = hooks.preClose(ctx, conn); err != nil {
return err
}
Expand All @@ -120,12 +123,14 @@ func (conn *Conn) Begin() (driver.Tx, error) {

// BeginTx starts and returns a new transaction which is wrapped by Tx.
// It will trigger PreBegin, Begin, PostBegin hooks.
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
func (conn *Conn) BeginTx(c context.Context, opts driver.TxOptions) (driver.Tx, error) {
// set the hooks.
var err error
var ctx interface{}
var tx driver.Tx
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { err = hooks.postBegin(c, ctx, conn, err) }()
defer func() { hooks.postBegin(c, ctx, conn, err) }()
if ctx, err = hooks.preBegin(c, conn); err != nil {
return nil, err
}
Expand Down Expand Up @@ -188,7 +193,7 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
// It will trigger PreExec, Exec, PostExec hooks.
//
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (drv driver.Result, err error) {
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
execer, exOk := conn.Conn.(driver.Execer)
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
if !exOk && !exCtxOk {
Expand All @@ -202,17 +207,19 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
Conn: conn,
}
var ctx interface{}
var err error
var result driver.Result
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { err = hooks.postExec(c, ctx, stmt, args, drv, err) }()
defer func() { hooks.postExec(c, ctx, stmt, args, result, err) }()
if ctx, err = hooks.preExec(c, stmt, args); err != nil {
return nil, err
}
}

// call the original method.
if execerCtx != nil {
drv, err = execerCtx.ExecContext(c, stmt.QueryString, args)
result, err = execerCtx.ExecContext(c, stmt.QueryString, args)
} else {
select {
default:
Expand All @@ -223,18 +230,19 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
if err0 != nil {
return nil, err0
}
drv, err = execer.Exec(stmt.QueryString, dargs)
result, err = execer.Exec(stmt.QueryString, dargs)
}
if err != nil {
return nil, err
}

if hooks != nil {
if err = hooks.exec(c, ctx, stmt, args, drv); err != nil {
if err = hooks.exec(c, ctx, stmt, args, result); err != nil {
return nil, err
}
}
return drv, err

return result, nil
}

// Query executes a query that may return rows.
Expand All @@ -250,7 +258,7 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
// It wil trigger PreQuery, Query, PostQuery hooks.
//
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
queryer, qok := conn.Conn.(driver.Queryer)
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
if !qok && !qCtxOk {
Expand All @@ -263,9 +271,11 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
Conn: conn,
}
var ctx interface{}
var err error
var rows driver.Rows
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
defer func() { err = hooks.postQuery(c, ctx, stmt, args, rows, err) }()
defer func() { hooks.postQuery(c, ctx, stmt, args, rows, err) }()
if ctx, err = hooks.preQuery(c, stmt, args); err != nil {
return nil, err
}
Expand Down Expand Up @@ -333,12 +343,13 @@ type sessionResetter interface {
}

// ResetSession resets the state of Conn.
func (conn *Conn) ResetSession(ctx context.Context) (err error) {
func (conn *Conn) ResetSession(ctx context.Context) error {
var err error
var myctx interface{}
hooks := conn.Proxy.getHooks(ctx)

if hooks != nil {
defer func() { err = hooks.postResetSession(ctx, myctx, conn, err) }()
defer func() { hooks.postResetSession(ctx, myctx, conn, err) }()
if myctx, err = hooks.preResetSession(ctx, conn); err != nil {
return err
}
Expand Down
20 changes: 10 additions & 10 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ func (h *HooksContext) ping(c context.Context, ctx interface{}, conn *Conn) erro

func (h *HooksContext) postPing(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostPing == nil {
return err
return nil
}
return h.PostPing(c, ctx, conn, err)
}
Expand All @@ -433,7 +433,7 @@ func (h *HooksContext) open(c context.Context, ctx interface{}, conn *Conn) erro

func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostOpen == nil {
return err
return nil
}
return h.PostOpen(c, ctx, conn, err)
}
Expand All @@ -454,7 +454,7 @@ func (h *HooksContext) prepare(c context.Context, ctx interface{}, stmt *Stmt) e

func (h *HooksContext) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
if h == nil || h.PostPrepare == nil {
return err
return nil
}
return h.PostPrepare(c, ctx, stmt, err)
}
Expand All @@ -475,7 +475,7 @@ func (h *HooksContext) exec(c context.Context, ctx interface{}, stmt *Stmt, args

func (h *HooksContext) postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error {
if h == nil || h.PostExec == nil {
return err
return nil
}
return h.PostExec(c, ctx, stmt, args, result, err)
}
Expand All @@ -496,7 +496,7 @@ func (h *HooksContext) query(c context.Context, ctx interface{}, stmt *Stmt, arg

func (h *HooksContext) postQuery(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, rows driver.Rows, err error) error {
if h == nil || h.PostQuery == nil {
return err
return nil
}
return h.PostQuery(c, ctx, stmt, args, rows, err)
}
Expand All @@ -517,7 +517,7 @@ func (h *HooksContext) begin(c context.Context, ctx interface{}, conn *Conn) err

func (h *HooksContext) postBegin(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostBegin == nil {
return err
return nil
}
return h.PostBegin(c, ctx, conn, err)
}
Expand All @@ -538,7 +538,7 @@ func (h *HooksContext) commit(c context.Context, ctx interface{}, tx *Tx) error

func (h *HooksContext) postCommit(c context.Context, ctx interface{}, tx *Tx, err error) error {
if h == nil || h.PostCommit == nil {
return err
return nil
}
return h.PostCommit(c, ctx, tx, err)
}
Expand All @@ -559,7 +559,7 @@ func (h *HooksContext) rollback(c context.Context, ctx interface{}, tx *Tx) erro

func (h *HooksContext) postRollback(c context.Context, ctx interface{}, tx *Tx, err error) error {
if h == nil || h.PostRollback == nil {
return err
return nil
}
return h.PostRollback(c, ctx, tx, err)
}
Expand All @@ -580,7 +580,7 @@ func (h *HooksContext) close(c context.Context, ctx interface{}, conn *Conn) err

func (h *HooksContext) postClose(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostClose == nil {
return err
return nil
}
return h.PostClose(c, ctx, conn, err)
}
Expand All @@ -601,7 +601,7 @@ func (h *HooksContext) resetSession(c context.Context, ctx interface{}, conn *Co

func (h *HooksContext) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
if h == nil || h.PostResetSession == nil {
return err
return nil
}
return h.PostResetSession(c, ctx, conn, err)
}
Expand Down
20 changes: 10 additions & 10 deletions logging_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (h *loggingHook) postPing(c context.Context, ctx interface{}, conn *Conn, e
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostPing]")
return err
return nil
}

func (h *loggingHook) preOpen(c context.Context, name string) (interface{}, error) {
Expand All @@ -58,7 +58,7 @@ func (h *loggingHook) postOpen(c context.Context, ctx interface{}, conn *Conn, e
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostOpen]")
return err
return nil
}

func (h *loggingHook) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
Expand All @@ -79,7 +79,7 @@ func (h *loggingHook) postPrepare(c context.Context, ctx interface{}, stmt *Stmt
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostPrepare]")
return err
return nil
}

func (h *loggingHook) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
Expand All @@ -100,7 +100,7 @@ func (h *loggingHook) postExec(c context.Context, ctx interface{}, stmt *Stmt, a
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostExec]")
return err
return nil
}

func (h *loggingHook) preQuery(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
Expand All @@ -121,7 +121,7 @@ func (h *loggingHook) postQuery(c context.Context, ctx interface{}, stmt *Stmt,
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostQuery]")
return err
return nil
}

func (h *loggingHook) preBegin(c context.Context, conn *Conn) (interface{}, error) {
Expand All @@ -142,7 +142,7 @@ func (h *loggingHook) postBegin(c context.Context, ctx interface{}, conn *Conn,
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostBegin]")
return err
return nil
}

func (h *loggingHook) preCommit(c context.Context, tx *Tx) (interface{}, error) {
Expand All @@ -163,7 +163,7 @@ func (h *loggingHook) postCommit(c context.Context, ctx interface{}, tx *Tx, err
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostCommit]")
return err
return nil
}

func (h *loggingHook) preRollback(c context.Context, tx *Tx) (interface{}, error) {
Expand All @@ -184,7 +184,7 @@ func (h *loggingHook) postRollback(c context.Context, ctx interface{}, tx *Tx, e
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostRollback]")
return err
return nil
}

func (h *loggingHook) preClose(c context.Context, conn *Conn) (interface{}, error) {
Expand All @@ -205,7 +205,7 @@ func (h *loggingHook) postClose(c context.Context, ctx interface{}, conn *Conn,
h.mu.Lock()
defer h.mu.Unlock()
fmt.Fprintln(h, "[PostClose]")
return err
return nil
}

func (h *loggingHook) preResetSession(c context.Context, conn *Conn) (interface{}, error) {
Expand All @@ -217,7 +217,7 @@ func (h *loggingHook) resetSession(c context.Context, ctx interface{}, conn *Con
}

func (h *loggingHook) postResetSession(c context.Context, ctx interface{}, conn *Conn, err error) error {
return err
return nil
}

func (h *loggingHook) preIsValid(conn *Conn) (interface{}, error) {
Expand Down

0 comments on commit 7ea01d1

Please sign in to comment.