diff --git a/db.go b/db.go index 3c7b422..4fbc7eb 100644 --- a/db.go +++ b/db.go @@ -181,10 +181,7 @@ func (c *conn) Close() (err error) { c.opened-- if c.opened == 0 { if c.tx != nil { - err = c.tx.Rollback() - if err != nil { - return - } + c.tx.Rollback() c.tx = nil } c.drv.deleteConn(c.dsn) diff --git a/db_test.go b/db_test.go index fa7c749..a4a7007 100644 --- a/db_test.go +++ b/db_test.go @@ -2,10 +2,12 @@ package txdb import ( "database/sql" + "errors" "fmt" "strings" "sync" "testing" + "time" ) func drivers() []string { @@ -381,3 +383,55 @@ func TestShouldReopenAfterClose(t *testing.T) { } } } + +type canceledContext struct{} + +func (canceledContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, true } +func (canceledContext) Done() <-chan struct{} { + done := make(chan struct{}, 0) + close(done) + return done +} +func (canceledContext) Err() error { return errors.New("canceled") } +func (canceledContext) Value(key interface{}) interface{} { return nil } + +func TestShouldDiscardConnectionWhenClosedBecauseOfError(t *testing.T) { + for _, driver := range drivers() { + t.Run(fmt.Sprintf("using driver %s", driver), func(t *testing.T) { + { + db, err := sql.Open(driver, "first") + if err != nil { + t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + } + defer db.Close() + + tx, err := db.Begin() + defer tx.Rollback() + if err != nil { + t.Fatalf(driver+": failed to begin transaction err: %s", err) + } + + // TODO: we somehow need to poison the DB connection here so that Rollback fails + + _, err = tx.PrepareContext(canceledContext{}, "SELECT * FROM users") + if err == nil { + t.Fatalf(driver + ": should have returned error for prepare") + } + } + + fmt.Println("Opening db...") + + { + db, err := sql.Open(driver, "second") + if err != nil { + t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + } + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatalf(driver+": failed to ping, have you run 'make test'? err: %s", err) + } + } + }) + } +}