diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2831b92..c98cc68 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -50,12 +50,8 @@ jobs: - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - - name: Create Table - run: | - make mysql - make psql - name: Run integration tests env: - MYSQL_DSN: "root:pass@/txdb_test" - PSQL_DSN: "postgres://postgres:pass@localhost/txdb_test" + MYSQL_DSN: "root:pass@/" + PSQL_DSN: "postgres://postgres:pass@localhost/" run: go test ./... diff --git a/Makefile b/Makefile index 30e1320..68d2742 100644 --- a/Makefile +++ b/Makefile @@ -1,54 +1,4 @@ -define MYSQL_SQL -CREATE TABLE users ( - id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL, - username VARCHAR(32) NOT NULL, - email VARCHAR(255) NOT NULL, - PRIMARY KEY (id), - UNIQUE INDEX uniq_email (email) -) DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci ENGINE = InnoDB; -endef - -define PSQL_SQL -CREATE TABLE users ( - id SERIAL PRIMARY KEY, - username VARCHAR(32) NOT NULL, - email VARCHAR(255) UNIQUE NOT NULL -); -endef - -export MYSQL_SQL -MYSQL := "$$MYSQL_SQL" - -export PSQL_SQL -PSQL := "$$PSQL_SQL" - -INSERTS := "INSERT INTO users (username, email) VALUES ('gopher', 'gopher@go.com'), ('john', 'john@doe.com'), ('jane', 'jane@doe.com');" - -MYSQLCMD=mysql -ifndef CI - MYSQLCMD=docker compose exec mysql mysql -endif - -PSQLCMD=psql -ifndef CI - PSQLCMD=docker compose exec postgres psql -endif - test: MYSQL_DSN=root:pass@/txdb_test test: PSQL_DSN=postgres://postgres:pass@localhost/txdb_test test: mysql psql @go test -race - -mysql: - @$(MYSQLCMD) -h 127.0.0.1 -u root -ppass -e 'DROP DATABASE IF EXISTS txdb_test' - @$(MYSQLCMD) -h 127.0.0.1 -u root -ppass -e 'CREATE DATABASE txdb_test' - @$(MYSQLCMD) -h 127.0.0.1 -u root -ppass txdb_test -e $(MYSQL) - @$(MYSQLCMD) -h 127.0.0.1 -u root -ppass txdb_test -e $(INSERTS) - -psql: - @$(PSQLCMD) "postgresql://postgres:pass@127.0.0.1" -c 'DROP DATABASE IF EXISTS txdb_test' - @$(PSQLCMD) "postgresql://postgres:pass@127.0.0.1" -c 'CREATE DATABASE txdb_test' - @$(PSQLCMD) "postgresql://postgres:pass@127.0.0.1/txdb_test" -c $(PSQL) - @$(PSQLCMD) "postgresql://postgres:pass@127.0.0.1/txdb_test" -c $(INSERTS) - -.PHONY: test mysql psql diff --git a/bootstrap_test.go b/bootstrap_test.go new file mode 100644 index 0000000..f4d897b --- /dev/null +++ b/bootstrap_test.go @@ -0,0 +1,62 @@ +package txdb_test + +import ( + "database/sql" + "testing" +) + +const ( + mysql_sql = `CREATE TABLE users ( + id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL, + username VARCHAR(32) NOT NULL, + email VARCHAR(255) NOT NULL, + PRIMARY KEY (id), + UNIQUE INDEX uniq_email (email) + ) DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci ENGINE = InnoDB` + + psql_sql = `CREATE TABLE users ( + id SERIAL PRIMARY KEY, + username VARCHAR(32) NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL + )` + + inserts = `INSERT INTO users (username, email) VALUES ('gopher', 'gopher@go.com'), ('john', 'john@doe.com'), ('jane', 'jane@doe.com')` + + testDB = "txdb_test" +) + +// bootstrap bootstraps the database with the nfor tests. +func bootstrap(t *testing.T, driver, dsn string) { + db, err := sql.Open(driver, dsn) + if err != nil { + t.Fatal(err) + } + switch driver { + case "mysql": + if _, err := db.Exec(mysql_sql); err != nil { + t.Fatal(err) + } + case "postgres": + if _, err := db.Exec(psql_sql); err != nil { + t.Fatal(err) + } + default: + panic("unrecognized driver: " + driver) + } + if _, err := db.Exec(inserts); err != nil { + t.Fatal(err) + } +} + +func createDB(t *testing.T, driver, dsn string) { + db, err := sql.Open(driver, dsn) + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec("DROP DATABASE IF EXISTS txdb_test"); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("CREATE DATABASE txdb_test"); err != nil { + t.Fatal(err) + } +} diff --git a/db_test.go b/db_test.go index 8b035a0..9185e30 100644 --- a/db_test.go +++ b/db_test.go @@ -44,17 +44,19 @@ type testDrivers []*testDriver var registerMu sync.Mutex -// dsn returns the full dsn for the test driver, or calls t.Skip if it is -// unset or disabled. -func (d *testDriver) dsn(t *testing.T) string { - dsn := os.Getenv(d.dsnEnvKey) - if dsn == "" { +// dsn returns the base dsn (without DB name) and the full dsn (with dbname) +// for the test driver, or calls t.Skip if it is unset or disabled. +func (d *testDriver) dsn(t *testing.T) (base string, full string) { + t.Helper() + base = os.Getenv(d.dsnEnvKey) + if base == "" { t.Skipf("%s not set, skipping tests for %s", d.dsnEnvKey, d.driver) } + full = strings.TrimSuffix(base, "/") + "/" + testDB if d.options == "" { - return dsn + return base, full } - return dsn + "?" + d.options + return base + "?" + d.options, full + "?" + d.options } func (d *testDriver) register(t *testing.T) { @@ -62,24 +64,26 @@ func (d *testDriver) register(t *testing.T) { registerMu.Lock() defer registerMu.Unlock() if !d.registered { - dsn := d.dsn(t) + base, full := d.dsn(t) d.registered = true - txdb.Register(d.name, d.driver, dsn) + createDB(t, d.driver, base) + bootstrap(t, d.driver, full) + txdb.Register(d.name, d.driver, full) } } // Run registers the driver, if not already registered, then calls f with the // driver name. -func (d *testDriver) Run(t *testing.T, f func(t *testing.T, driver string)) { +func (d *testDriver) Run(t *testing.T, f func(t *testing.T, driver *testDriver)) { t.Helper() t.Run(d.name, func(t *testing.T) { d.register(t) - f(t, d.name) + f(t, d) }) } // Run iterates over the configured drivers, and calls [testDriver.Run] on each. -func (d testDrivers) Run(t *testing.T, f func(t *testing.T, driver string)) { +func (d testDrivers) Run(t *testing.T, f func(t *testing.T, driver *testDriver)) { t.Helper() for _, driver := range d { driver.Run(t, f) @@ -103,8 +107,9 @@ func (d testDrivers) drivers(names ...string) testDrivers { func TestShouldWorkWithOpenDB(t *testing.T) { t.Parallel() for _, d := range txDrivers { - d.Run(t, func(t *testing.T, _ string) { - db := sql.OpenDB(txdb.New(d.driver, d.dsn(t))) + d.Run(t, func(t *testing.T, driver *testDriver) { + _, dsn := driver.dsn(t) + db := sql.OpenDB(txdb.New(d.driver, dsn)) defer db.Close() _, err := db.Exec("SELECT 1") if err != nil { @@ -116,11 +121,11 @@ func TestShouldWorkWithOpenDB(t *testing.T) { func TestShouldRunWithNestedTransaction(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { var count int - db, err := sql.Open(driver, "five") + db, err := sql.Open(driver.name, "five") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } func(db *sql.DB) { @@ -128,88 +133,88 @@ func TestShouldRunWithNestedTransaction(t *testing.T) { _, err = db.Exec(`INSERT INTO users (username, email) VALUES('txdb', 'txdb@test1.com')`) if err != nil { - t.Fatalf(driver+": failed to insert an user: %s", err) + t.Fatalf("failed to insert a user: %s", err) } err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 4 { - t.Fatalf(driver+": expected 4 users to be in database, but got %d", count) + t.Fatalf("expected 4 users to be in database, but got %d", count) } tx, err := db.Begin() if err != nil { - t.Fatalf(driver+": failed to begin transaction: %s", err) + t.Fatalf("failed to begin transaction: %s", err) } { _, err = tx.Exec(`INSERT INTO users (username, email) VALUES('txdb', 'txdb@test2.com')`) if err != nil { - t.Fatalf(driver+": failed to insert an user: %s", err) + t.Fatalf("failed to insert an user: %s", err) } err = tx.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 5 { - t.Fatalf(driver+": expected 5 users to be in database, but got %d", count) + t.Fatalf("expected 5 users to be in database, but got %d", count) } if err := tx.Rollback(); err != nil { - t.Fatalf(driver+": failed to rollback transaction: %s", err) + t.Fatalf("failed to rollback transaction: %s", err) } } err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 4 { - t.Fatalf(driver+": expected 4 users to be in database, but got %d", count) + t.Fatalf("expected 4 users to be in database, but got %d", count) } tx, err = db.Begin() if err != nil { - t.Fatalf(driver+": failed to begin transaction: %s", err) + t.Fatalf("failed to begin transaction: %s", err) } { _, err = tx.Exec(`INSERT INTO users (username, email) VALUES('txdb', 'txdb@test2.com')`) if err != nil { - t.Fatalf(driver+": failed to insert an user: %s", err) + t.Fatalf("failed to insert an user: %s", err) } err = tx.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 5 { - t.Fatalf(driver+": expected 5 users to be in database, but got %d", count) + t.Fatalf("expected 5 users to be in database, but got %d", count) } if err := tx.Commit(); err != nil { - t.Fatalf(driver+": failed to commit transaction: %s", err) + t.Fatalf("failed to commit transaction: %s", err) } } err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 5 { - t.Fatalf(driver+": expected 5 users to be in database, but got %d", count) + t.Fatalf("expected 5 users to be in database, but got %d", count) } }(db) - db, err = sql.Open(driver, "six") + db, err = sql.Open(driver.name, "six") if err != nil { - t.Fatalf(driver+": failed to reopen a mysql connection: %s", err) + t.Fatalf("failed to reopen a mysql connection: %s", err) } func(db *sql.DB) { defer db.Close() err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 3 { - t.Fatalf(driver+": expected 3 users to be in database, but got %d", count) + t.Fatalf("expected 3 users to be in database, but got %d", count) } }(db) }) @@ -217,11 +222,11 @@ func TestShouldRunWithNestedTransaction(t *testing.T) { func TestShouldRunWithinTransaction(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { var count int - db, err := sql.Open(driver, "one") + db, err := sql.Open(driver.name, "one") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } func(db *sql.DB) { @@ -229,30 +234,30 @@ func TestShouldRunWithinTransaction(t *testing.T) { _, err = db.Exec(`INSERT INTO users (username, email) VALUES('txdb', 'txdb@test.com')`) if err != nil { - t.Fatalf(driver+": failed to insert an user: %s", err) + t.Fatalf("failed to insert an user: %s", err) } err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 4 { - t.Fatalf(driver+": expected 4 users to be in database, but got %d", count) + t.Fatalf("expected 4 users to be in database, but got %d", count) } }(db) - db, err = sql.Open(driver, "two") + db, err = sql.Open(driver.name, "two") if err != nil { - t.Fatalf(driver+": failed to reopen a mysql connection: %s", err) + t.Fatalf("failed to reopen a mysql connection: %s", err) } func(db *sql.DB) { defer db.Close() err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 3 { - t.Fatalf(driver+": expected 3 users to be in database, but got %d", count) + t.Fatalf("expected 3 users to be in database, but got %d", count) } }(db) }) @@ -260,32 +265,32 @@ func TestShouldRunWithinTransaction(t *testing.T) { func TestShouldNotHoldConnectionForRows(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "three") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "three") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() rows, err := db.Query("SELECT username FROM users") if err != nil { - t.Fatalf(driver+": failed to query users: %s", err) + t.Fatalf("failed to query users: %s", err) } defer rows.Close() _, err = db.Exec(`INSERT INTO users(username, email) VALUES('txdb', 'txdb@test.com')`) if err != nil { - t.Fatalf(driver+": failed to insert an user: %s", err) + t.Fatalf("failed to insert an user: %s", err) } }) } func TestShouldPerformParallelActions(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "four") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "four") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() @@ -296,19 +301,19 @@ func TestShouldPerformParallelActions(t *testing.T) { defer wg.Done() rows, err := d.Query("SELECT username FROM users") if err != nil { - t.Errorf(driver+": failed to query users: %s", err) + t.Errorf("failed to query users: %s", err) } defer rows.Close() insertSQL := "INSERT INTO users(username, email) VALUES(?, ?)" - if strings.Index(driver, "psql_") == 0 { + if strings.Index(driver.name, "psql_") == 0 { insertSQL = "INSERT INTO users(username, email) VALUES($1, $2)" } username := fmt.Sprintf("parallel%d", idx) email := fmt.Sprintf("parallel%d@test.com", idx) _, err = d.Exec(insertSQL, username, email) if err != nil { - t.Errorf(driver+": failed to insert an user: %s", err) + t.Errorf("failed to insert an user: %s", err) } }(db, i) } @@ -316,164 +321,164 @@ func TestShouldPerformParallelActions(t *testing.T) { var count int err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": failed to count users: %s", err) + t.Fatalf("failed to count users: %s", err) } if count != 7 { - t.Fatalf(driver+": expected 7 users to be in database, but got %d", count) + t.Fatalf("expected 7 users to be in database, but got %d", count) } }) } func TestShouldFailInvalidPrepareStatement(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "fail_prepare") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "fail_prepare") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() if _, err = db.Prepare("THIS SHOULD FAIL..."); err == nil { - t.Fatalf(driver + ": expected an error, since prepare should validate sql query, but got none") + t.Fatal("expected an error, since prepare should validate sql query, but got none") } }) } func TestShouldHandlePrepare(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "prepare") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "prepare") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() selectSQL := "SELECT email FROM users WHERE username = ?" - if strings.Index(driver, "psql_") == 0 { + if strings.Index(driver.name, "psql_") == 0 { selectSQL = "SELECT email FROM users WHERE username = $1" } stmt1, err := db.Prepare(selectSQL) if err != nil { - t.Fatalf(driver+": could not prepare - %s", err) + t.Fatalf("could not prepare - %s", err) } insertSQL := "INSERT INTO users (username, email) VALUES(?, ?)" - if strings.Index(driver, "psql_") == 0 { + if strings.Index(driver.name, "psql_") == 0 { insertSQL = "INSERT INTO users (username, email) VALUES($1, $2)" } stmt2, err := db.Prepare(insertSQL) if err != nil { - t.Fatalf(driver+": could not prepare - %s", err) + t.Fatalf("could not prepare - %s", err) } var email string if err = stmt1.QueryRow("jane").Scan(&email); err != nil { - t.Fatalf(driver+": could not scan email - %s", err) + t.Fatalf("could not scan email - %s", err) } _, err = stmt2.Exec("mark", "mark.spencer@gmail.com") if err != nil { - t.Fatalf(driver+": should have inserted user - %s", err) + t.Fatalf("should have inserted user - %s", err) } }) } func TestShouldCloseRootDB(t *testing.T) { - txDrivers.Run(t, func(t *testing.T, driver string) { - db1, err := sql.Open(driver, "first") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db1, err := sql.Open(driver.name, "first") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db1.Close() stmt, err := db1.Prepare("SELECT * FROM users") if err != nil { - t.Fatalf(driver+": could not prepare - %s", err) + t.Fatalf("could not prepare - %s", err) } defer stmt.Close() drv1 := db1.Driver().(*txdb.TxDriver) if drv1.DB() == nil { - t.Fatalf(driver+": expected database, drv1.db: %v", drv1.DB()) + t.Fatalf("expected database, drv1.db: %v", drv1.DB()) } - db2, err := sql.Open(driver, "second") + db2, err := sql.Open(driver.name, "second") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db2.Close() stmt, err = db2.Prepare("SELECT * FROM users") if err != nil { - t.Fatalf(driver+": could not prepare - %s", err) + t.Fatalf("could not prepare - %s", err) } defer stmt.Close() // Both drivers share the same database. drv2 := db2.Driver().(*txdb.TxDriver) if drv2.DB() != drv1.DB() { - t.Fatalf(driver+": drv1.db=%v != drv2.db=%v", drv1.DB(), drv2.DB()) + t.Fatalf("drv1.db=%v != drv2.db=%v", drv1.DB(), drv2.DB()) } // Database should remain open while a connection is open. if err := db1.Close(); err != nil { - t.Fatalf(driver+": could not close database - %s", err) + t.Fatalf("could not close database - %s", err) } if drv1.DB() == nil { - t.Fatal(driver + ": expected database, not nil") + t.Fatal("expected database, not nil") } if drv2.DB() == nil { - t.Fatal(driver + ": expected database ,not nil") + t.Fatal("expected database ,not nil") } // Database should close after last connection is closed. if err := db2.Close(); err != nil { - t.Fatalf(driver+": could not close database - %s", err) + t.Fatalf("could not close database - %s", err) } if drv1.DB() != nil { - t.Fatalf(driver+": expected closed database, not %v", drv1.DB()) + t.Fatalf("expected closed database, not %v", drv1.DB()) } if drv2.DB() != nil { - t.Fatalf(driver+": expected closed database, not %v", drv2.DB()) + t.Fatalf("expected closed database, not %v", drv2.DB()) } }) } func TestShouldReopenAfterClose(t *testing.T) { - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "first") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "first") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() stmt, err := db.Prepare("SELECT * FROM users") if err != nil { - t.Fatalf(driver+": could not prepare - %s", err) + t.Fatalf("could not prepare - %s", err) } defer stmt.Close() if err := db.Close(); err != nil { - t.Fatalf(driver+": could not close database - %s", err) + t.Fatalf("could not close database - %s", err) } if err := db.Ping(); err.Error() != "sql: database is closed" { - t.Fatalf(driver+": expected closed database - %s", err) + t.Fatalf("expected closed database - %s", err) } - db, err = sql.Open(driver, "second") + db, err = sql.Open(driver.name, "second") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() if err := db.Ping(); err != nil { - t.Fatalf(driver+": failed to ping, have you run 'make test'? err: %s", err) + t.Fatalf("failed to ping: %s", err) } }) } @@ -490,56 +495,54 @@ func (canceledContext) Err() error { return errors.New("c func (canceledContext) Value(key interface{}) interface{} { return nil } func TestShouldDiscardConnectionWhenClosedBecauseOfError(t *testing.T) { - txDrivers.Run(t, func(t *testing.T, driver string) { - t.Run(fmt.Sprintf("using driver %s", driver), func(t *testing.T) { - { - db, err := sql.Open(driver, "first") - if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) - } - defer db.Close() - - tx, err := db.Begin() - defer func() { - err = tx.Rollback() - if err != nil { - t.Fatalf(driver+": rollback err: %s", err) - } - }() + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + { + db, err := sql.Open(driver.name, "first") + if err != nil { + t.Fatalf("failed to open a connection: %s", err) + } + defer db.Close() + + tx, err := db.Begin() + defer func() { + err = tx.Rollback() if err != nil { - t.Fatalf(driver+": failed to begin transaction err: %s", err) + t.Fatalf("rollback err: %s", err) } + }() + if err != nil { + t.Fatalf("failed to begin transaction err: %s", err) + } - // TODO: we somehow need to poison the DB connection here so that Rollback fails + // TODO: we somehow need to poison the DB connection here so that Rollback fails - _, err = tx.PrepareContext(canceledContext{}, "SELECT * FROM users") - if err == nil { - t.Fatalf(driver + ": should have returned error for prepare") - } + _, err = tx.PrepareContext(canceledContext{}, "SELECT * FROM users") + if err == nil { + t.Fatal("should have returned error for prepare") } + } - fmt.Println("Opening db...") + fmt.Println("Opening db...") - { - db, err := sql.Open(driver, "second") - if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) - } - defer db.Close() + { + db, err := sql.Open(driver.name, "second") + if err != nil { + t.Fatalf("failed to open a connection: %s", err) + } + defer db.Close() - if err := db.Ping(); err != nil { - t.Fatalf(driver+": failed to ping, have you run 'make test'? err: %s", err) - } + if err := db.Ping(); err != nil { + t.Fatalf("failed to ping: %s", err) } - }) + } }) } func TestPostgresRowsScanTypeTables(t *testing.T) { - txDrivers.drivers("postgres").Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "scantype") + txDrivers.drivers("postgres").Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "scantype") if err != nil { - t.Fatalf("psql: failed to open a postgres connection, have you run 'make test'? err: %s", err) + t.Fatalf("psql: failed to open a postgres connection: %s", err) } defer db.Close() @@ -561,10 +564,10 @@ func TestPostgresRowsScanTypeTables(t *testing.T) { } func TestMysqlShouldBeAbleToLockTables(t *testing.T) { - txDrivers.drivers("mysql").Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "locks") + txDrivers.drivers("mysql").Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "locks") if err != nil { - t.Fatalf("mysql: failed to open a mysql connection, have you run 'make test'? err: %s", err) + t.Fatalf("mysql: failed to open a mysql connection: %s", err) } defer db.Close() @@ -591,16 +594,16 @@ func TestMysqlShouldBeAbleToLockTables(t *testing.T) { func TestShouldGetMultiRowSet(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "multiRows") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "multiRows") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() rows, err := db.QueryContext(context.Background(), "SELECT username FROM users; SELECT COUNT(*) FROM users;") if err != nil { - t.Fatalf(driver+": failed to query users: %s", err) + t.Fatalf("failed to query users: %s", err) } defer rows.Close() @@ -608,143 +611,141 @@ func TestShouldGetMultiRowSet(t *testing.T) { for rows.Next() { var name string if err := rows.Scan(&name); err != nil { - t.Fatalf(driver+": unexpected row scan err: %v", err) + t.Fatalf("unexpected row scan err: %v", err) } users = append(users, name) } if !rows.NextResultSet() { - t.Fatal(driver + ": expected next result set") + t.Fatal("expected next result set") } if !rows.Next() { - t.Fatal(driver + ": expected next result set - row") + t.Fatal("expected next result set - row") } var count int if err := rows.Scan(&count); err != nil { - t.Fatalf(driver+": unexpected row scan err: %v", err) + t.Fatalf("unexpected row scan err: %v", err) } if count != len(users) { - t.Fatal(driver + ": unexpected number of users") + t.Fatal("unexpected number of users") } }) } func TestShouldBeAbleToPingWithContext(t *testing.T) { - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "ping") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "ping") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() if err := db.PingContext(context.Background()); err != nil { - t.Fatalf(driver+": %v", err) + t.Fatalf("%v", err) } }) } func TestShouldHandleStmtsWithoutContextPollution(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { - t.Run(driver, func(t *testing.T) { - db, err := sql.Open(driver, "contextpollution") - if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) - } - defer db.Close() + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "contextpollution") + if err != nil { + t.Fatalf("failed to open a connection: %s", err) + } + defer db.Close() - insertSQL := "INSERT INTO users (username, email) VALUES(?, ?)" - if strings.Index(driver, "psql_") == 0 { - insertSQL = "INSERT INTO users (username, email) VALUES($1, $2)" - } + insertSQL := "INSERT INTO users (username, email) VALUES(?, ?)" + if strings.Index(driver.name, "psql_") == 0 { + insertSQL = "INSERT INTO users (username, email) VALUES($1, $2)" + } - ctx1, cancel1 := context.WithCancel(context.Background()) - defer cancel1() + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() - _, err = db.ExecContext(ctx1, insertSQL, "first", "first@foo.com") - if err != nil { - t.Fatalf("unexpected error inserting user 1: %s", err) - } - cancel1() + _, err = db.ExecContext(ctx1, insertSQL, "first", "first@foo.com") + if err != nil { + t.Fatalf("unexpected error inserting user 1: %s", err) + } + cancel1() - ctx2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() - _, err = db.ExecContext(ctx2, insertSQL, "second", "second@foo.com") - if err != nil { - t.Fatalf("unexpected error inserting user 2: %s", err) - } - cancel2() + _, err = db.ExecContext(ctx2, insertSQL, "second", "second@foo.com") + if err != nil { + t.Fatalf("unexpected error inserting user 2: %s", err) + } + cancel2() - const selectQuery = ` + const selectQuery = ` select username from users where username = 'first' OR username = 'second'` - rows, err := db.QueryContext(context.Background(), selectQuery) - if err != nil { - t.Fatalf("unexpected error querying users: %s", err) - } - defer rows.Close() - - assertRows := func(t *testing.T, rows *sql.Rows) { - t.Helper() - - var users []string - for rows.Next() { - var user string - err := rows.Scan(&user) - if err != nil { - t.Errorf("unexpected scan failure: %s", err) - continue - } - users = append(users, user) - } - sort.Strings(users) + rows, err := db.QueryContext(context.Background(), selectQuery) + if err != nil { + t.Fatalf("unexpected error querying users: %s", err) + } + defer rows.Close() - wanted := []string{"first", "second"} + assertRows := func(t *testing.T, rows *sql.Rows) { + t.Helper() - if len(users) != 2 { - t.Fatalf("invalid users received; want=%v\tgot=%v", wanted, users) + var users []string + for rows.Next() { + var user string + err := rows.Scan(&user) + if err != nil { + t.Errorf("unexpected scan failure: %s", err) + continue } - for i, want := range wanted { - if got := users[i]; want != got { - t.Errorf("invalid user; want=%s\tgot=%s", want, got) - } + users = append(users, user) + } + sort.Strings(users) + + wanted := []string{"first", "second"} + + if len(users) != 2 { + t.Fatalf("invalid users received; want=%v\tgot=%v", wanted, users) + } + for i, want := range wanted { + if got := users[i]; want != got { + t.Errorf("invalid user; want=%s\tgot=%s", want, got) } } + } - assertRows(t, rows) + assertRows(t, rows) - ctx3, cancel3 := context.WithCancel(context.Background()) - defer cancel3() + ctx3, cancel3 := context.WithCancel(context.Background()) + defer cancel3() - stmt, err := db.PrepareContext(ctx3, selectQuery) - if err != nil { - t.Fatalf("unexpected error preparing stmt: %s", err) - } + stmt, err := db.PrepareContext(ctx3, selectQuery) + if err != nil { + t.Fatalf("unexpected error preparing stmt: %s", err) + } - rows, err = stmt.QueryContext(context.TODO()) - if err != nil { - t.Fatalf("unexpected error in stmt querying users: %s", err) - } - defer rows.Close() + rows, err = stmt.QueryContext(context.TODO()) + if err != nil { + t.Fatalf("unexpected error in stmt querying users: %s", err) + } + defer rows.Close() - assertRows(t, rows) - }) + assertRows(t, rows) }) } // https://github.com/DATA-DOG/go-txdb/issues/49 func TestIssue49(t *testing.T) { t.Parallel() - txDrivers.Run(t, func(t *testing.T, driver string) { - db, err := sql.Open(driver, "rollback") + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "rollback") if err != nil { - t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + t.Fatalf("failed to open a connection: %s", err) } defer db.Close() @@ -753,7 +754,7 @@ func TestIssue49(t *testing.T) { var count int err = db.QueryRow("SELECT COUNT(id) FROM users").Scan(&count) if err != nil { - t.Fatalf(driver+": prepared statement count err %v", err) + t.Fatalf("prepared statement count err %v", err) } if count != 3 { t.Logf("Count not 3: %d", count) @@ -763,32 +764,32 @@ func TestIssue49(t *testing.T) { // start a nested transaction tx, err := db.Begin() if err != nil { - t.Fatalf(driver+": failed to start transaction: %s", err) + t.Fatalf("failed to start transaction: %s", err) } // need a prepared statement to reproduce the error insertSQL := "INSERT INTO users (username, email) VALUES(?, ?)" - if strings.Index(driver, "psql_") == 0 { + if strings.Index(driver.name, "psql_") == 0 { insertSQL = "INSERT INTO users (username, email) VALUES($1, $2)" } stmt, err := tx.Prepare(insertSQL) if err != nil { - t.Fatalf(driver+": failed to prepare named statement: %s", err) + t.Fatalf("failed to prepare named statement: %s", err) } // try to insert already existing username/email _, err = stmt.Exec("gopher", "gopher@go.com") if err == nil { - t.Fatalf(driver + ": double insert?") + t.Fatal("double insert?") } // The insert failed, so we need to close the prepared statement err = stmt.Close() if err != nil { - t.Fatalf(driver+": error closing prepared statement: %s", err) + t.Fatalf("error closing prepared statement: %s", err) } // rollback the transaction now that it has failed err = tx.Rollback() if err != nil { - t.Logf(driver+": failed rollback of failed transaction: %s", err) + t.Logf("failed rollback of failed transaction: %s", err) t.FailNow() } })