Skip to content

Commit

Permalink
Query supports QueryExecMode
Browse files Browse the repository at this point in the history
Fixed QueryExecModeExec as it must only use text format without
specifying param OIDs.
  • Loading branch information
jackc committed Mar 12, 2022
1 parent 0c166c7 commit 1390a11
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 89 deletions.
213 changes: 135 additions & 78 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ var ErrNoRows = errors.New("no rows in result set")
// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
var ErrInvalidLogLevel = errors.New("invalid log level")

var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")

// Connect establishes a connection with a PostgreSQL server with a connection string. See
// pgconn.Connect for details.
func Connect(ctx context.Context, connString string) (*Conn, error) {
Expand Down Expand Up @@ -430,7 +433,7 @@ optionLoop:
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
return pgconn.CommandTag{}, errDisabledStatementCache
}
sd, err := c.statementCache.Get(ctx, sql)
if err != nil {
Expand All @@ -440,7 +443,7 @@ optionLoop:
return c.execPrepared(ctx, sd, arguments)
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
return pgconn.CommandTag{}, errDisabledDescriptionCache
}
sd, err := c.descriptionCache.Get(ctx, sql)
if err != nil {
Expand Down Expand Up @@ -536,24 +539,49 @@ func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{}
c.eqb.Reset()

anynil.NormalizeSlice(args)
err := c.appendParamsForQueryExecModeExec(args)
if err != nil {
return pgconn.CommandTag{}, err
}

paramOIDs := make([]uint32, len(args))
result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats).Read()
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
return result.CommandTag, result.Err
}

// appendParamsForQueryExecModeExec appends the args to c.eqb.
//
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
// type conversion it takes the date directly and ignores time zone (i.e. it works).
//
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
// no way to safely use binary or to specify the parameter OIDs.
func (c *Conn) appendParamsForQueryExecModeExec(args []interface{}) error {
for i := range args {
dt, ok := c.TypeMap().TypeForValue(args[i])
if !ok {
return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]}
}
err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i])
if err != nil {
return pgconn.CommandTag{}, err
if args[i] == nil {
err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, args[i])
if err != nil {
return err
}
} else {
dt, ok := c.TypeMap().TypeForValue(args[i])
if !ok {
return &unknownArgumentTypeQueryExecModeExecError{arg: args[i]}
}
err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, args[i])
if err != nil {
return err
}
}
paramOIDs[i] = dt.OID
}

result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
return result.CommandTag, result.Err
return nil
}

func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
Expand Down Expand Up @@ -589,14 +617,11 @@ const (
// when the the database schema is modified concurrently.
QueryExecModeDescribeExec

// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended
// protocol. Queries are executed in a single round trip. Type mappings can be registered with
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
// It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g.
// "SELECT $1::boolean".
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
// unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
QueryExecModeExec

// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
Expand All @@ -605,8 +630,13 @@ const (
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
// a map[string]string directly as an argument. This mode cannot.
//
// This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to
// specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean".
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor
// exceptions such as behavior when multiple result returning queries are erroneously sent in a single string.
//
// QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer
// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol
// should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does
// not support the extended protocol.
QueryExecModeSimpleProtocol
)

Expand Down Expand Up @@ -640,13 +670,13 @@ type QueryResultFormatsByOID map[uint32]int16
// Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully
// as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row.
//
// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and
// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
mode := c.config.DefaultQueryExecMode

optionLoop:
for len(args) > 0 {
Expand All @@ -658,91 +688,118 @@ optionLoop:
resultFormatsByOID = arg
args = args[1:]
case QueryExecMode:
simpleProtocol = arg == QueryExecModeSimpleProtocol
mode = arg
args = args[1:]
default:
break optionLoop
}
}

c.eqb.Reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)

var err error
sd, ok := c.preparedStatements[sql]

if simpleProtocol && !ok {
sql, err = c.sanitizeForSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
return rows, err
sd := c.preparedStatements[sql]
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
if sd == nil {
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
err = errDisabledStatementCache
rows.fatal(err)
return rows, err
}
sd, err = c.statementCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
err = errDisabledDescriptionCache
rows.fatal(err)
return rows, err
}
sd, err = c.descriptionCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
}
case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
}
}

mrr := c.pgConn.Exec(ctx, sql)
if mrr.NextResult() {
rows.resultReader = mrr.ResultReader()
rows.multiResultReader = mrr
} else {
err = mrr.Close()
rows.fatal(err)
return rows, err
if len(sd.ParamOIDs) != len(args) {
rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
return rows, rows.err
}

return rows, nil
}
rows.sql = sd.SQL

c.eqb.Reset()

if !ok {
if c.statementCache != nil {
sd, err = c.statementCache.Get(ctx, sql)
for i := range args {
err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
if err != nil {
rows.fatal(err)
return rows, rows.err
}
} else {
sd, err = c.pgConn.Prepare(ctx, "", sql, nil)
if err != nil {
rows.fatal(err)
return rows, rows.err
}

if resultFormatsByOID != nil {
resultFormats = make([]int16, len(sd.Fields))
for i := range resultFormats {
resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
}
}
}
if len(sd.ParamOIDs) != len(args) {
rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
return rows, rows.err
}

rows.sql = sd.SQL
if resultFormats == nil {
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}

anynil.NormalizeSlice(args)
resultFormats = c.eqb.resultFormats
}

for i := range args {
err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
if mode == QueryExecModeCacheDescribe {
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
} else {
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
}
} else if mode == QueryExecModeExec {
err := c.appendParamsForQueryExecModeExec(args)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
}

if resultFormatsByOID != nil {
resultFormats = make([]int16, len(sd.Fields))
for i := range resultFormats {
resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats)
} else if mode == QueryExecModeSimpleProtocol {
sql, err = c.sanitizeForSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
return rows, err
}
}

if resultFormats == nil {
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
mrr := c.pgConn.Exec(ctx, sql)
if mrr.NextResult() {
rows.resultReader = mrr.ResultReader()
rows.multiResultReader = mrr
} else {
err = mrr.Close()
rows.fatal(err)
return rows, err
}

resultFormats = c.eqb.resultFormats
}

if c.statementCache != nil && c.statementCache.Mode() == stmtcache.ModeDescribe {
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
return rows, nil
} else {
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
err = fmt.Errorf("unknown QueryExecMode: %v", mode)
rows.fatal(err)
return rows, rows.err
}

c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
Expand Down
10 changes: 1 addition & 9 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,7 @@ func TestExecFailureWithArguments(t *testing.T) {
assert.False(t, pgconn.SafeToRetry(err))

_, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2")
if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec {
// The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it
// locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check
// for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing
// the SQL.
require.NoError(t, err)
} else {
require.Error(t, err)
}
require.Error(t, err)
})
}

Expand Down
8 changes: 6 additions & 2 deletions extended_query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ type extendedQueryBuilder struct {

func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg interface{}) error {
f := eqb.chooseParameterFormatCode(m, oid, arg)
eqb.paramFormats = append(eqb.paramFormats, f)
return eqb.AppendParamFormat(m, oid, f, arg)
}

func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg interface{}) error {
eqb.paramFormats = append(eqb.paramFormats, format)

v, err := eqb.encodeExtendedParamValue(m, oid, f, arg)
v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
if err != nil {
return err
}
Expand Down
13 changes: 13 additions & 0 deletions values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,19 @@ func TestEncodeTypeRename(t *testing.T) {
inString := _string("foo")
var outString _string

// pgx.QueryExecModeExec requires all types to be registered.
conn.TypeMap().RegisterDefaultPgType(inInt, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt8, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt16, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt32, "int8")
conn.TypeMap().RegisterDefaultPgType(inInt64, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint8, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint16, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint32, "int8")
conn.TypeMap().RegisterDefaultPgType(inUint64, "int8")
conn.TypeMap().RegisterDefaultPgType(inString, "text")

err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text",
inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString,
).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString)
Expand Down

0 comments on commit 1390a11

Please sign in to comment.