From 84ae2bc4a87a432b1400474f3d48dd08961ecf40 Mon Sep 17 00:00:00 2001 From: Lann Date: Wed, 13 Oct 2021 14:18:41 -0400 Subject: [PATCH] Fix Select subquery with DollarPlaceholder (#298) Fixes #286 --- case.go | 2 +- part.go | 10 +++++++++- select.go | 16 ++++++---------- select_test.go | 17 +++++++++++++++++ 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/case.go b/case.go index e3b099b..299e14b 100644 --- a/case.go +++ b/case.go @@ -27,7 +27,7 @@ func (b *sqlizerBuffer) WriteSql(item Sqlizer) { var str string var args []interface{} - str, args, b.err = item.ToSql() + str, args, b.err = nestedToSql(item) if b.err != nil { return diff --git a/part.go b/part.go index 2926d03..f3a7b15 100644 --- a/part.go +++ b/part.go @@ -29,9 +29,17 @@ func (p part) ToSql() (sql string, args []interface{}, err error) { return } +func nestedToSql(s Sqlizer) (string, []interface{}, error) { + if raw, ok := s.(rawSqlizer); ok { + return raw.toSqlRaw() + } else { + return s.ToSql() + } +} + func appendToSql(parts []Sqlizer, w io.Writer, sep string, args []interface{}) ([]interface{}, error) { for i, p := range parts { - partSql, partArgs, err := p.ToSql() + partSql, partArgs, err := nestedToSql(p) if err != nil { return nil, err } else if len(partSql) == 0 { diff --git a/select.go b/select.go index 48f4f73..b585344 100644 --- a/select.go +++ b/select.go @@ -52,7 +52,7 @@ func (d *selectData) QueryRow() RowScanner { } func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) { - sqlStr, args, err = d.toSql() + sqlStr, args, err = d.toSqlRaw() if err != nil { return } @@ -62,10 +62,6 @@ func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) { } func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) { - return d.toSql() -} - -func (d *selectData) toSql() (sqlStr string, args []interface{}, err error) { if len(d.Columns) == 0 { err = fmt.Errorf("select statements must have at least one result column") return @@ -222,6 +218,11 @@ func (b SelectBuilder) ToSql() (string, []interface{}, error) { return data.ToSql() } +func (b SelectBuilder) toSqlRaw() (string, []interface{}, error) { + data := builder.GetStruct(b).(selectData) + return data.toSqlRaw() +} + // MustSql builds the query into a SQL string and bound args. // It panics if there are any errors. func (b SelectBuilder) MustSql() (string, []interface{}) { @@ -232,11 +233,6 @@ func (b SelectBuilder) MustSql() (string, []interface{}) { return sql, args } -func (b SelectBuilder) toSqlRaw() (string, []interface{}, error) { - data := builder.GetStruct(b).(selectData) - return data.toSqlRaw() -} - // Prefix adds an expression to the beginning of the query func (b SelectBuilder) Prefix(sql string, args ...interface{}) SelectBuilder { return b.PrefixExpr(Expr(sql, args...)) diff --git a/select_test.go b/select_test.go index ce7a069..08cfedb 100644 --- a/select_test.go +++ b/select_test.go @@ -247,6 +247,23 @@ func TestSelectWithEmptyStringWhereClause(t *testing.T) { assert.Equal(t, "SELECT * FROM users", sql) } +func TestSelectSubqueryPlaceholderNumbering(t *testing.T) { + subquery := Select("a").Where("b = ?", 1).PlaceholderFormat(Dollar) + with := subquery.Prefix("WITH a AS (").Suffix(")") + + sql, args, err := Select("*"). + PrefixExpr(with). + FromSelect(subquery, "q"). + Where("c = ?", 2). + PlaceholderFormat(Dollar). + ToSql() + assert.NoError(t, err) + + expectedSql := "WITH a AS ( SELECT a WHERE b = $1 ) SELECT * FROM (SELECT a WHERE b = $2) AS q WHERE c = $3" + assert.Equal(t, expectedSql, sql) + assert.Equal(t, []interface{}{1, 1, 2}, args) +} + func ExampleSelect() { Select("id", "created", "first_name").From("users") // ... continue building up your query