Skip to content

Commit

Permalink
Cache reflection analysis in RowToStructBy...
Browse files Browse the repository at this point in the history
Modify the RowToStructByPos/Name functions to store the computed mapping
of columns to struct field locations in a cache to reuse between calls.
Because this computation can be expensive and the same few results will
frequently be reused, caching these results provides a significant
speedup.

For positional mappings, we can key the cache by just the struct-type.
However, for named mappings, the key must include a representation of
the columns, in order, since different columns produce different
mappings.
  • Loading branch information
zolstein authored and jackc committed Apr 16, 2024
1 parent 8db0f28 commit ec98406
Showing 1 changed file with 184 additions and 74 deletions.
258 changes: 184 additions & 74 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"

"github.com/jackc/pgx/v5/pgconn"
Expand Down Expand Up @@ -541,7 +542,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
// ignored.
func RowToStructByPos[T any](row CollectableRow) (T, error) {
var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err
}

Expand All @@ -550,62 +551,76 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
// the field will be ignored.
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err
}

type positionalStructRowScanner struct {
ptrToStruct any
}

func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
dst := rs.ptrToStruct
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() != reflect.Ptr {
return fmt.Errorf("dst not a pointer")
func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
fields := lookupStructFields(typ)
if len(rows.RawValues()) > len(fields) {
return fmt.Errorf(
"got %d values, but dst struct has only %d fields",
len(rows.RawValues()),
len(fields),
)
}

dstElemValue := dstValue.Elem()
scanTargets := rs.appendScanTargets(dstElemValue, nil)

if len(rows.RawValues()) > len(scanTargets) {
return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
}

scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
return rows.Scan(scanTargets...)
}

func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
dstElemType := dstElemValue.Type()
// Map from reflect.Type -> []structRowField
var positionalStructFieldMap sync.Map

if scanTargets == nil {
scanTargets = make([]any, 0, dstElemType.NumField())
func lookupStructFields(t reflect.Type) []structRowField {
if cached, ok := positionalStructFieldMap.Load(t); ok {
return cached.([]structRowField)
}

for i := 0; i < dstElemType.NumField(); i++ {
sf := dstElemType.Field(i)
fieldStack := make([]int, 0, 1)
fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
return fieldsIface.([]structRowField)
}

func computeStructFields(
t reflect.Type,
fields []structRowField,
fieldStack *[]int,
) []structRowField {
tail := len(*fieldStack)
*fieldStack = append(*fieldStack, 0)
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
(*fieldStack)[tail] = i
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
fields = computeStructFields(sf.Type, fields, fieldStack)
} else if sf.PkgPath == "" {
dbTag, _ := sf.Tag.Lookup(structTagKey)
if dbTag == "-" {
// Field is ignored, skip it.
continue
}
scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
fields = append(fields, structRowField{
path: append([]int(nil), *fieldStack...),
})
}
}

return scanTargets
*fieldStack = (*fieldStack)[:tail]
return fields
}

// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByName[T any](row CollectableRow) (T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err
}

Expand All @@ -615,7 +630,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) {
// then the field will be ignored.
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err
}

Expand All @@ -624,7 +639,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return value, err
}

Expand All @@ -634,7 +649,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
// then the field will be ignored.
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return &value, err
}

Expand All @@ -643,64 +658,123 @@ type namedStructRowScanner struct {
lax bool
}

func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
dst := rs.ptrToStruct
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() != reflect.Ptr {
return fmt.Errorf("dst not a pointer")
}

dstElemValue := dstValue.Elem()
scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
fldDescs := rows.FieldDescriptions()
namedStructFields, err := lookupNamedStructFields(typ, fldDescs)
if err != nil {
return err
}

for i, t := range scanTargets {
if t == nil {
return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
}
if rs.lax && namedStructFields.missingField != "" {
return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
}

fields := namedStructFields.fields
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
return rows.Scan(scanTargets...)
}

const structTagKey = "db"

func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
i = -1
for i, desc := range fldDescs {
// Map from namedStructFieldMap -> *namedStructFields
var namedStructFieldMap sync.Map

// Snake case support.
field = strings.ReplaceAll(field, "_", "")
descName := strings.ReplaceAll(desc.Name, "_", "")
type namedStructFieldsKey struct {
t reflect.Type
colNames string
}

if strings.EqualFold(descName, field) {
return i
}
}
return
type namedStructFields struct {
fields []structRowField
// missingField is the first field from the struct without a corresponding row field.
// This is used to construct the correct error message for non-lax queries.
missingField string
}

func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
var err error
dstElemType := dstElemValue.Type()
func lookupNamedStructFields(
t reflect.Type,
fldDescs []pgconn.FieldDescription,
) (*namedStructFields, error) {
key := namedStructFieldsKey{
t: t,
colNames: joinFieldNames(fldDescs),
}
if cached, ok := namedStructFieldMap.Load(key); ok {
return cached.(*namedStructFields), nil
}

if scanTargets == nil {
scanTargets = make([]any, len(fldDescs))
// We could probably do two-levels of caching, where we compute the key -> fields mapping
// for a type only once, cache it by type, then use that to compute the column -> fields
// mapping for a given set of columns.
fieldStack := make([]int, 0, 1)
fields, missingField := computeNamedStructFields(
fldDescs,
t,
make([]structRowField, len(fldDescs)),
&fieldStack,
)
for i, f := range fields {
if f.path == nil {
return nil, fmt.Errorf(
"struct doesn't have corresponding row field %s",
fldDescs[i].Name,
)
}
}

for i := 0; i < dstElemType.NumField(); i++ {
sf := dstElemType.Field(i)
fieldsIface, _ := namedStructFieldMap.LoadOrStore(
key,
&namedStructFields{fields: fields, missingField: missingField},
)
return fieldsIface.(*namedStructFields), nil
}

func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
switch len(fldDescs) {
case 0:
return ""
case 1:
return fldDescs[0].Name
}

totalSize := len(fldDescs) - 1 // Space for separator bytes.
for _, d := range fldDescs {
totalSize += len(d.Name)
}
var b strings.Builder
b.Grow(totalSize)
b.WriteString(fldDescs[0].Name)
for _, d := range fldDescs[1:] {
b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
b.WriteString(d.Name)
}
return b.String()
}

func computeNamedStructFields(
fldDescs []pgconn.FieldDescription,
t reflect.Type,
fields []structRowField,
fieldStack *[]int,
) ([]structRowField, string) {
var missingField string
tail := len(*fieldStack)
*fieldStack = append(*fieldStack, 0)
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
(*fieldStack)[tail] = i
if sf.PkgPath != "" && !sf.Anonymous {
// Field is unexported, skip it.
continue
}
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
if err != nil {
return nil, err
var missingSubField string
fields, missingSubField = computeNamedStructFields(
fldDescs,
sf.Type,
fields,
fieldStack,
)
if missingField == "" {
missingField = missingSubField
}
} else {
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
Expand All @@ -717,17 +791,53 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s
}
fpos := fieldPosByName(fldDescs, colName)
if fpos == -1 {
if rs.lax {
continue
if missingField == "" {
missingField = colName
}
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
continue
}
if fpos >= len(scanTargets) && !rs.lax {
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
fields[fpos] = structRowField{
path: append([]int(nil), *fieldStack...),
}
scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
}
}
*fieldStack = (*fieldStack)[:tail]

return fields, missingField
}

return scanTargets, err
const structTagKey = "db"

func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
i = -1
for i, desc := range fldDescs {

// Snake case support.
field = strings.ReplaceAll(field, "_", "")
descName := strings.ReplaceAll(desc.Name, "_", "")

if strings.EqualFold(descName, field) {
return i
}
}
return
}

// structRowField describes a field of a struct.
//
// TODO: It would be a bit more efficient to track the path using the pointer
// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
// construct references when scanning rows. However, it's not clear it's worth
// using unsafe for this.
type structRowField struct {
path []int
}

func setupStructScanTargets(receiver any, fields []structRowField) []any {
scanTargets := make([]any, len(fields))
v := reflect.ValueOf(receiver).Elem()
for i, f := range fields {
scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
}
return scanTargets
}

0 comments on commit ec98406

Please sign in to comment.