Skip to content

Commit

Permalink
Merge pull request #61 from Songmu/overwrite-query
Browse files Browse the repository at this point in the history
pass stmt.QueryString to QueryContext and ExecContext calls
  • Loading branch information
shogo82148 committed Jan 8, 2021
2 parents de14c4e + 9abdebe commit b039787
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,16 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
}

// set the hooks.
var stmt *Stmt
var stmt = &Stmt{
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}
var ctx interface{}
var err error
var result driver.Result
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
stmt = &Stmt{
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}
defer func() { hooks.postExec(c, ctx, stmt, args, result, err) }()
if ctx, err = hooks.preExec(c, stmt, args); err != nil {
return nil, err
Expand All @@ -205,7 +204,7 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam

// call the original method.
if execerCtx, ok := execer.(driver.ExecerContext); ok {
result, err = execerCtx.ExecContext(c, query, args)
result, err = execerCtx.ExecContext(c, stmt.QueryString, args)
} else {
select {
default:
Expand All @@ -216,7 +215,7 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
if err0 != nil {
return nil, err0
}
result, err = execer.Exec(query, dargs)
result, err = execer.Exec(stmt.QueryString, dargs)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -250,17 +249,16 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
return nil, driver.ErrSkip
}

var stmt *Stmt
var stmt = &Stmt{
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}
var ctx interface{}
var err error
var rows driver.Rows
hooks := conn.Proxy.getHooks(c)
if hooks != nil {
stmt = &Stmt{
QueryString: query,
Proxy: conn.Proxy,
Conn: conn,
}
defer func() { hooks.postQuery(c, ctx, stmt, args, rows, err) }()
if ctx, err = hooks.preQuery(c, stmt, args); err != nil {
return nil, err
Expand All @@ -269,7 +267,7 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na

// call the original method.
if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok {
rows, err = queryerCtx.QueryContext(c, query, args)
rows, err = queryerCtx.QueryContext(c, stmt.QueryString, args)
} else {
select {
default:
Expand All @@ -280,7 +278,7 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
if err0 != nil {
return nil, err0
}
rows, err = queryer.Query(query, dargs)
rows, err = queryer.Query(stmt.QueryString, dargs)
}
if err != nil {
return nil, err
Expand Down

0 comments on commit b039787

Please sign in to comment.