Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce ALTER TABLE query #726

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion bun.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type AfterDropTableHook interface {
AfterDropTable(ctx context.Context, query *DropTableQuery) error
}

// SetLogger overwriters default Bun logger.
// SetLogger overwrites default Bun logger.
func SetLogger(logger internal.Logging) {
internal.Logger = logger
}
Expand Down
18 changes: 11 additions & 7 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ func (db *DB) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(db)
}

func (db *DB) NewAlterTable() *AlterTableQuery {
return NewAlterTableQuery(db)
}

func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error {
for _, model := range models {
if _, err := db.NewDropTable().Model(model).IfExists().Cascade().Exec(ctx); err != nil {
Expand Down Expand Up @@ -195,7 +199,7 @@ func (db *DB) Table(typ reflect.Type) *schema.Table {
return db.dialect.Tables().Get(typ)
}

// RegisterModel registers models by name so they can be referenced in table relations
// RegisterModel registers models by name, so they can be referenced in table relations
// and fixtures.
func (db *DB) RegisterModel(models ...interface{}) {
db.dialect.Tables().Register(models...)
Expand Down Expand Up @@ -234,7 +238,7 @@ func (db *DB) HasFeature(feat feature.Feature) bool {
return db.fmter.HasFeature(feat)
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
return db.ExecContext(context.Background(), query, args...)
Expand Down Expand Up @@ -280,7 +284,7 @@ func (db *DB) format(query string, args []interface{}) string {
return db.fmter.FormatQuery(query, args...)
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

type Conn struct {
db *DB
Expand Down Expand Up @@ -426,7 +430,7 @@ func (c Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
}, nil
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

type Stmt struct {
*sql.Stmt
Expand All @@ -444,7 +448,7 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
return Stmt{Stmt: stmt}, nil
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

type Tx struct {
ctx context.Context
Expand Down Expand Up @@ -584,7 +588,7 @@ func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interfac
return row
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (tx Tx) Begin() (Tx, error) {
return tx.BeginTx(tx.ctx, nil)
Expand Down Expand Up @@ -700,7 +704,7 @@ func (tx Tx) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(tx.db).Conn(tx)
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (db *DB) makeQueryBytes() []byte {
// TODO: make this configurable?
Expand Down
3 changes: 2 additions & 1 deletion dialect/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ const (
UpdateFromTable
MSSavepoint
GeneratedIdentity
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
AlterTableQuery // ALTER TABLE (temporary, all dialects must support it)
)
3 changes: 2 additions & 1 deletion dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ func New() *Dialect {
feature.InsertOnConflict |
feature.SelectExists |
feature.GeneratedIdentity |
feature.CompositeIn
feature.CompositeIn |
feature.AlterTableQuery
return d
}

Expand Down
45 changes: 17 additions & 28 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,13 @@ func funcName(x interface{}) string {
return s
}

func skipIfNotHasFeature(tb testing.TB, db *bun.DB, feat feature.Feature, featName string) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: skipIfNoFeature

tb.Helper()
if !db.HasFeature(feat) {
tb.Skipf("%q dialect does not support %q", db.Dialect().Name(), featName)
}
}

func TestDB(t *testing.T) {
type Test struct {
run func(t *testing.T, db *bun.DB)
Expand Down Expand Up @@ -320,10 +327,7 @@ func testSelectScan(t *testing.T, db *bun.DB) {
}

func testSelectCount(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
return
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

values := db.NewValues(&[]map[string]interface{}{
{"num": 1},
Expand Down Expand Up @@ -365,9 +369,7 @@ func testSelectMap(t *testing.T, db *bun.DB) {
}

func testSelectMapSlice(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

values := db.NewValues(&[]map[string]interface{}{
{"column1": 1},
Expand Down Expand Up @@ -464,9 +466,7 @@ func testSelectNestedStructPtr(t *testing.T, db *bun.DB) {
}

func testSelectStructSlice(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

type Model struct {
Num int `bun:"column1"`
Expand All @@ -491,9 +491,7 @@ func testSelectStructSlice(t *testing.T, db *bun.DB) {
}

func testSelectSingleSlice(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

values := db.NewValues(&[]map[string]interface{}{
{"column1": 1},
Expand All @@ -511,9 +509,7 @@ func testSelectSingleSlice(t *testing.T, db *bun.DB) {
}

func testSelectMultiSlice(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

values := db.NewValues(&[]map[string]interface{}{
{"a": 1, "b": "foo"},
Expand Down Expand Up @@ -616,9 +612,7 @@ func testScanSingleRow(t *testing.T, db *bun.DB) {
}

func testScanSingleRowByRow(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

values := db.NewValues(&[]map[string]interface{}{
{"num": 1},
Expand Down Expand Up @@ -650,9 +644,7 @@ func testScanSingleRowByRow(t *testing.T, db *bun.DB) {
}

func testScanRows(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

values := db.NewValues(&[]map[string]interface{}{
{"num": 1},
Expand Down Expand Up @@ -940,7 +932,7 @@ func testWithForeignKeysAndRules(t *testing.T, db *bun.DB) {
_, err = db.NewInsert().Model(new(Deck)).Exec(ctx)
require.Error(t, err)

// Create a deck that violates the user_id FK contraint
// Create a deck that violates the user_id FK constraint
deck := &Deck{UserID: 42}

_, err = db.NewInsert().Model(deck).Exec(ctx)
Expand Down Expand Up @@ -1028,7 +1020,7 @@ func testWithForeignKeys(t *testing.T, db *bun.DB) {
_, err = db.NewInsert().Model(new(Deck)).Exec(ctx)
require.Error(t, err)

// Create a deck that violates the user_id FK contraint
// Create a deck that violates the user_id FK constraint
deck := &Deck{UserID: 42}

_, err = db.NewInsert().Model(deck).Exec(ctx)
Expand Down Expand Up @@ -1245,10 +1237,7 @@ func testUpsert(t *testing.T, db *bun.DB) {
}

func testMultiUpdate(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
return
}
skipIfNotHasFeature(t, db, feature.CTE, "CTE")

type Model struct {
ID int64 `bun:",pk,autoincrement"`
Expand Down
131 changes: 129 additions & 2 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package dbtest_test
import (
"encoding/json"
"fmt"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
"path/filepath"
"regexp"
"testing"
Expand All @@ -19,6 +21,8 @@ func init() {
cupaloy.Global = cupaloy.Global.WithOptions(cupaloy.SnapshotSubdirectory(snapshotsDir))
}

var timeRE = regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)

func TestQuery(t *testing.T) {
type Model struct {
ID int64 `bun:",pk,autoincrement"`
Expand Down Expand Up @@ -1003,8 +1007,6 @@ func TestQuery(t *testing.T) {
},
}

timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
for i, fn := range queries {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
Expand All @@ -1021,3 +1023,128 @@ func TestQuery(t *testing.T) {
}
})
}

func TestAlterTable(t *testing.T) {
needsAlterTable := needs(feature.AlterTableQuery, "ALTER TABLE")
type Model struct {
ID int64
Old string
Active bool
}

cases := []TestCase{
{
"rename column", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model((*Model)(nil)).
RenameColumn().Column("old").To("new")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reasons why not simply RenameColumn("old_name", "new_name")?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

},
},
{
"invalid model", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model(1)
},
},
{
"rename table", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model((*Model)(nil)).
Rename().To("new_models")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not RenameTable("new_name")?

Copy link
Contributor Author

@bevzzz bevzzz Dec 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, that's actually what I started with, but I wasn't sure if it wasn't more consistent with the rest of the API to mirror actual the SQL keywords.
I agree that that what you've suggested is much more readable and explicit though.

},
},
{
"rename table if exists", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model((*Model)(nil)).IfExists().
Rename().To("new_models")
},
},
{
"change column type", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
// Change column type with common SQL data type
return db.NewAlterTable().Model((*Model)(nil)).
AlterColumn().Column("old").Type(sqltype.Blob)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same suggestion AlterColumnType(colName, colType)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

},
},
{
"change column type chained", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
// Alter 2 columns in a chained query
return db.NewAlterTable().Model((*Model)(nil)).
AlterColumn().Column("old").Type(sqltype.Blob).
AlterColumn().Column("active").Type(sqltype.SmallInt)
},
},
{
"add column", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model((*Model)(nil)).AddColumn().
ColumnExpr("deadline ?", bun.Safe(sqltype.Timestamp))
},
},
{
"add several columns", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model((*Model)(nil)).
AddColumn().ColumnExpr("one INTEGER").
AddColumn().ColumnExpr("two BIGINT")
},
},
{
"drop column if exists", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model((*Model)(nil)).DropColumn().IfExists().Column("old")
},
},
{
"drop column expression", needsAlterTable,
func(db *bun.DB) schema.QueryAppender {
return db.NewAlterTable().Model((*Model)(nil)).DropColumn().ColumnExpr("old CASCADE")
},
},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
for i, tt := range cases {
skipIfNotHasFeature(t, db, tt.req.Feature, tt.req.Name)

t.Run(fmt.Sprintf("%d-%s", i, tt.name), func(t *testing.T) {
checkQuerySnapshot(t, db, tt.fn(db))
})
}
})
}

func checkQuerySnapshot(t *testing.T, db *bun.DB, q schema.QueryAppender) {
t.Helper()
query, err := q.AppendQuery(db.Formatter(), nil)
if err != nil {
cupaloy.SnapshotT(t, err.Error())
return
}

query = timeRE.ReplaceAll(query, []byte("[TIME]"))
cupaloy.SnapshotT(t, string(query))
}

// Example:
// tt := TestCase{
// "common table expressions", needs(feature.CTE, "CTE"),
// func(db *bun.DB) schema.QueryAppender {...},
// }
type TestCase struct {
name string
req *requiredFeature
fn func(db *bun.DB) schema.QueryAppender
}

type requiredFeature struct {
feature.Feature
Name string
}

func needs(feat feature.Feature, name string) *requiredFeature {
return &requiredFeature{feat, name}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" RENAME COLUMN "old" TO "new"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: Model(non-pointer int)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" RENAME TO "new_models"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE IF EXISTS "models" RENAME TO "new_models"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" ALTER COLUMN "old" SET DATA TYPE BLOB
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" ALTER COLUMN "old" SET DATA TYPE BLOB, ALTER COLUMN "active" SET DATA TYPE SMALLINT
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" ADD COLUMN deadline TIMESTAMP
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" ADD COLUMN one INTEGER, ADD COLUMN two BIGINT
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" DROP COLUMN IF EXISTS "old"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "models" DROP COLUMN old CASCADE