Skip to content

Commit

Permalink
Add Scan and Value methods to FlatArray and Array
Browse files Browse the repository at this point in the history
Support direct usage in database/sql.

#1458
#1779
#1956
#1662
  • Loading branch information
jackc committed May 19, 2024
1 parent 7328897 commit 06d85e0
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 23 deletions.
123 changes: 123 additions & 0 deletions pgtype/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgtype

import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -419,6 +420,53 @@ func (a Array[T]) ScanIndexType() any {
return new(T)
}

// Scan implements the database/sql Scanner interface.
//
// Array needs a *Map to decode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Scanner interface does not have an argument that can be used to pass extra data to the Scan
// method. To work around this limitation, Scan uses a default *Map. This means Scan may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (a *Array[T]) Scan(v any) error {
if v == nil {
*a = Array[T]{}
return nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

switch v := v.(type) {
case string:
return m.Scan(0, 0, []byte(v), a)
case []byte:
return m.Scan(0, 0, v, a)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}

// Value implements the database/sql/driver Valuer interface.
//
// Array needs a *Map to encode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Valuer interface does not have an argument that can be used to pass extra data to the Value
// method. To work around this limitation, Value uses a default *Map. This means Value may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (src Array[T]) Value() (driver.Value, error) {
if !src.Valid {
return nil, nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

buf, err := m.Encode(0, TextFormatCode, src, nil)
if err != nil {
return nil, err
}

return string(buf), nil
}

// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions
// and custom lower bounds. Use Array to preserve these.
type FlatArray[T any] []T
Expand Down Expand Up @@ -458,3 +506,78 @@ func (a FlatArray[T]) ScanIndex(i int) any {
func (a FlatArray[T]) ScanIndexType() any {
return new(T)
}

// Scan implements the database/sql Scanner interface.
//
// FlatArray needs a *Map to decode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Scanner interface does not have an argument that can be used to pass extra data to the Scan
// method. To work around this limitation, Scan uses a default *Map. This means Scan may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (a *FlatArray[T]) Scan(v any) error {
if v == nil {
*a = nil
return nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

switch v := v.(type) {
case string:
return m.Scan(0, 0, []byte(v), a)
case []byte:
return m.Scan(0, 0, v, a)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}

// Value implements the database/sql/driver Valuer interface.
//
// FlatArray needs a *Map to encode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Valuer interface does not have an argument that can be used to pass extra data to the Value
// method. To work around this limitation, Value uses a default *Map. This means Value may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (src FlatArray[T]) Value() (driver.Value, error) {
if src == nil {
return nil, nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

buf, err := m.Encode(0, TextFormatCode, src, nil)
if err != nil {
return nil, err
}

return string(buf), nil
}

// flatArrayForWrapper is FlatArray without the Scan and Value methods. The avoids a wrapped plan attempt always
// "succeeding".
type flatArrayForWrapper[T any] []T

func (a flatArrayForWrapper[T]) Dimensions() []ArrayDimension {
return FlatArray[T](a).Dimensions()
}

func (a flatArrayForWrapper[T]) Index(i int) any {
return FlatArray[T](a).Index(i)
}

func (a flatArrayForWrapper[T]) IndexType() any {
return FlatArray[T](a).IndexType()
}

func (a *flatArrayForWrapper[T]) SetDimensions(dimensions []ArrayDimension) error {
return (*FlatArray[T])(a).SetDimensions(dimensions)
}

func (a flatArrayForWrapper[T]) ScanIndex(i int) any {
return FlatArray[T](a).ScanIndex(i)
}

func (a flatArrayForWrapper[T]) ScanIndexType() any {
return FlatArray[T](a).ScanIndexType()
}
41 changes: 25 additions & 16 deletions pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net"
"net/netip"
"reflect"
"sync"
"time"
)

Expand Down Expand Up @@ -123,6 +124,14 @@ const (
Int8multirangeArrayOID = 6157
)

// databaseSQLMapPool is a sync.Pool that holds *Map instances used for implementing sql.Scanner and driver.Valuer on
// types that need a *Map to encode and decode such as FlatArray[T].
var databaseSQLMapPool = sync.Pool{
New: func() any {
return NewMap()
},
}

type InfinityModifier int8

const (
Expand Down Expand Up @@ -932,19 +941,19 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa
// Avoid using reflect path for common types.
switch target := target.(type) {
case *[]int16:
return &wrapPtrSliceScanPlan[int16]{}, (*FlatArray[int16])(target), true
return &wrapPtrSliceScanPlan[int16]{}, (*flatArrayForWrapper[int16])(target), true
case *[]int32:
return &wrapPtrSliceScanPlan[int32]{}, (*FlatArray[int32])(target), true
return &wrapPtrSliceScanPlan[int32]{}, (*flatArrayForWrapper[int32])(target), true
case *[]int64:
return &wrapPtrSliceScanPlan[int64]{}, (*FlatArray[int64])(target), true
return &wrapPtrSliceScanPlan[int64]{}, (*flatArrayForWrapper[int64])(target), true
case *[]float32:
return &wrapPtrSliceScanPlan[float32]{}, (*FlatArray[float32])(target), true
return &wrapPtrSliceScanPlan[float32]{}, (*flatArrayForWrapper[float32])(target), true
case *[]float64:
return &wrapPtrSliceScanPlan[float64]{}, (*FlatArray[float64])(target), true
return &wrapPtrSliceScanPlan[float64]{}, (*flatArrayForWrapper[float64])(target), true
case *[]string:
return &wrapPtrSliceScanPlan[string]{}, (*FlatArray[string])(target), true
return &wrapPtrSliceScanPlan[string]{}, (*flatArrayForWrapper[string])(target), true
case *[]time.Time:
return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true
return &wrapPtrSliceScanPlan[time.Time]{}, (*flatArrayForWrapper[time.Time])(target), true
}

targetType := reflect.TypeOf(target)
Expand All @@ -968,7 +977,7 @@ type wrapPtrSliceScanPlan[T any] struct {
func (plan *wrapPtrSliceScanPlan[T]) SetNext(next ScanPlan) { plan.next = next }

func (plan *wrapPtrSliceScanPlan[T]) Scan(src []byte, target any) error {
return plan.next.Scan(src, (*FlatArray[T])(target.(*[]T)))
return plan.next.Scan(src, (*flatArrayForWrapper[T])(target.(*[]T)))
}

type wrapPtrSliceReflectScanPlan struct {
Expand Down Expand Up @@ -1773,19 +1782,19 @@ func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextVa
// Avoid using reflect path for common types.
switch value := value.(type) {
case []int16:
return &wrapSliceEncodePlan[int16]{}, (FlatArray[int16])(value), true
return &wrapSliceEncodePlan[int16]{}, (flatArrayForWrapper[int16])(value), true
case []int32:
return &wrapSliceEncodePlan[int32]{}, (FlatArray[int32])(value), true
return &wrapSliceEncodePlan[int32]{}, (flatArrayForWrapper[int32])(value), true
case []int64:
return &wrapSliceEncodePlan[int64]{}, (FlatArray[int64])(value), true
return &wrapSliceEncodePlan[int64]{}, (flatArrayForWrapper[int64])(value), true
case []float32:
return &wrapSliceEncodePlan[float32]{}, (FlatArray[float32])(value), true
return &wrapSliceEncodePlan[float32]{}, (flatArrayForWrapper[float32])(value), true
case []float64:
return &wrapSliceEncodePlan[float64]{}, (FlatArray[float64])(value), true
return &wrapSliceEncodePlan[float64]{}, (flatArrayForWrapper[float64])(value), true
case []string:
return &wrapSliceEncodePlan[string]{}, (FlatArray[string])(value), true
return &wrapSliceEncodePlan[string]{}, (flatArrayForWrapper[string])(value), true
case []time.Time:
return &wrapSliceEncodePlan[time.Time]{}, (FlatArray[time.Time])(value), true
return &wrapSliceEncodePlan[time.Time]{}, (flatArrayForWrapper[time.Time])(value), true
}

if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Slice {
Expand All @@ -1805,7 +1814,7 @@ type wrapSliceEncodePlan[T any] struct {
func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next }

func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
return plan.next.Encode((FlatArray[T])(value.([]T)), buf)
return plan.next.Encode((flatArrayForWrapper[T])(value.([]T)), buf)
}

type wrapSliceEncodeReflectPlan struct {
Expand Down
51 changes: 51 additions & 0 deletions stdlib/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"strings"
"testing"
"time"

"github.com/jackc/pgx/v5/pgtype"
)

func getSelectRowsCounts(b *testing.B) []int64 {
Expand Down Expand Up @@ -107,3 +109,52 @@ func BenchmarkSelectRowsScanNull(b *testing.B) {
})
}
}

func BenchmarkFlatArrayEncodeArgument(b *testing.B) {
db := openDB(b)
defer closeDB(b, db)

input := make(pgtype.FlatArray[string], 10)
for i := range input {
input[i] = fmt.Sprintf("String %d", i)
}

b.ResetTimer()

for i := 0; i < b.N; i++ {
var n int64
err := db.QueryRow("select cardinality($1::text[])", input).Scan(&n)
if err != nil {
b.Fatal(err)
}
if n != int64(len(input)) {
b.Fatalf("Expected %d, got %d", len(input), n)
}
}
}

func BenchmarkFlatArrayScanResult(b *testing.B) {
db := openDB(b)
defer closeDB(b, db)

var input string
for i := 0; i < 10; i++ {
if i > 0 {
input += ","
}
input += fmt.Sprintf(`'String %d'`, i)
}

b.ResetTimer()

for i := 0; i < b.N; i++ {
var result pgtype.FlatArray[string]
err := db.QueryRow(fmt.Sprintf("select array[%s]::text[]", input)).Scan(&result)
if err != nil {
b.Fatal(err)
}
if len(result) != 10 {
b.Fatalf("Expected %d, got %d", len(result), 10)
}
}
}
45 changes: 38 additions & 7 deletions stdlib/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,18 +509,49 @@ func TestConnQueryScanGoArray(t *testing.T) {
})
}

func TestConnQueryScanArray(t *testing.T) {
func TestPGTypeFlatArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
m := pgtype.NewMap()
var names pgtype.FlatArray[string]

var a pgtype.Array[int64]
err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names)
require.NoError(t, err)
require.Equal(t, pgtype.FlatArray[string]{"John", "Jane"}, names)

var n int
err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)

err = db.QueryRow("select null::text[]").Scan(&names)
require.NoError(t, err)
require.Nil(t, names)
})
}

func TestPGTypeArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
var matrix pgtype.Array[int64]

err := db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[]").Scan(&matrix)
require.NoError(t, err)
require.Equal(t,
pgtype.Array[int64]{
Elements: []int64{1, 2, 3, 4, 5, 6},
Dims: []pgtype.ArrayDimension{
{Length: 2, LowerBound: 1},
{Length: 3, LowerBound: 1},
},
Valid: true},
matrix)

var equal bool
err = db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[] = $1::bigint[]", matrix).Scan(&equal)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
require.Equal(t, true, equal)

err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a))
err = db.QueryRow("select null::bigint[]").Scan(&matrix)
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, matrix)
})
}

Expand Down

0 comments on commit 06d85e0

Please sign in to comment.