Merge branch 'main' into godocfx
gcf-merge-on-green[bot] committed Oct 23, 2023
2 parents 1cb19d8 + c8e7692 commit fe61526
Showing 6 changed files with 241 additions and 71 deletions.
105 changes: 85 additions & 20 deletions bigquery/arrow.go
@@ -19,49 +19,114 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"math/big"

"cloud.google.com/go/civil"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/apache/arrow/go/v12/arrow/memory"
"google.golang.org/api/iterator"
)

type arrowDecoder struct {
tableSchema Schema
rawArrowSchema []byte
arrowSchema *arrow.Schema
// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID
type ArrowRecordBatch struct {
reader io.Reader
// Serialized Arrow Record Batch.
Data []byte
// Serialized Arrow Schema.
Schema []byte
// Source partition ID. In the Storage API world, it represents the ReadStream.
PartitionID string
}

// Read makes ArrowRecordBatch implement io.Reader
func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
if r.reader == nil {
buf := bytes.NewBuffer(r.Schema)
buf.Write(r.Data)
r.reader = buf
}
return r.reader.Read(p)
}

// ArrowIterator represents a way to iterate through a stream of arrow records.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
type ArrowIterator interface {
Next() (*ArrowRecordBatch, error)
Schema() Schema
SerializedArrowSchema() []byte
}

func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDecoder, error) {
bqSession := session.bqSession
if bqSession == nil {
return nil, errors.New("read session not initialized")
// NewArrowIteratorReader allows consuming an ArrowIterator as an io.Reader.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
func NewArrowIteratorReader(it ArrowIterator) io.Reader {
return &arrowIteratorReader{
it: it,
}
arrowSerializedSchema := bqSession.GetArrowSchema().GetSerializedSchema()
}

type arrowIteratorReader struct {
buf *bytes.Buffer
it ArrowIterator
}

// Read makes ArrowIteratorReader implement io.Reader
func (r *arrowIteratorReader) Read(p []byte) (int, error) {
if r.it == nil {
return -1, errors.New("bigquery: nil ArrowIterator")
}
if r.buf == nil { // init with schema
buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
r.buf = buf
}
n, err := r.buf.Read(p)
if err == io.EOF {
batch, err := r.it.Next()
if err == iterator.Done {
return 0, io.EOF
}
r.buf.Write(batch.Data)
return r.Read(p)
}
return n, err
}

type arrowDecoder struct {
allocator memory.Allocator
tableSchema Schema
arrowSchema *arrow.Schema
}

func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
buf := bytes.NewBuffer(arrowSerializedSchema)
r, err := ipc.NewReader(buf)
if err != nil {
return nil, err
}
defer r.Release()
p := &arrowDecoder{
tableSchema: schema,
rawArrowSchema: arrowSerializedSchema,
arrowSchema: r.Schema(),
tableSchema: schema,
arrowSchema: r.Schema(),
allocator: memory.DefaultAllocator,
}
return p, nil
}

func (ap *arrowDecoder) createIPCReaderForBatch(serializedArrowRecordBatch []byte) (*ipc.Reader, error) {
buf := bytes.NewBuffer(ap.rawArrowSchema)
buf.Write(serializedArrowRecordBatch)
return ipc.NewReader(buf, ipc.WithSchema(ap.arrowSchema))
func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
return ipc.NewReader(
arrowRecordBatch,
ipc.WithSchema(ap.arrowSchema),
ipc.WithAllocator(ap.allocator),
)
}

// decodeArrowRecords decodes BQ ArrowRecordBatch into rows of []Value.
func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([][]Value, error) {
r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
if err != nil {
return nil, err
}
@@ -79,8 +144,8 @@ func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([
}

// decodeRetainedArrowRecords decodes BQ ArrowRecordBatch into a list of retained arrow.Record.
func (ap *arrowDecoder) decodeRetainedArrowRecords(serializedArrowRecordBatch []byte) ([]arrow.Record, error) {
r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
if err != nil {
return nil, err
}
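
For orientation, here is a minimal consumption sketch of the new surface added in this file. It is not part of the commit: the readAsArrow helper and its sql parameter are illustrative, and it assumes a client created with the Storage read API enabled (like the storageOptimizedClient used by the integration tests below) plus a result set large enough to take the accelerated path. The RowIterator.ArrowIterator, NewArrowIteratorReader, and arrow/ipc calls mirror their use in TestIntegration_StorageReadArrow further down.

package example

import (
	"context"
	"fmt"

	"cloud.google.com/go/bigquery"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

// readAsArrow drains a query result as Arrow record batches instead of []Value rows.
func readAsArrow(ctx context.Context, client *bigquery.Client, sql string) error {
	it, err := client.Query(sql).Read(ctx)
	if err != nil {
		return err
	}
	// ArrowIterator reports an error if the Storage API fast path was not used.
	arrowIt, err := it.ArrowIterator()
	if err != nil {
		return err
	}
	// NewArrowIteratorReader emits the serialized schema followed by each record
	// batch, so it can feed an Arrow IPC stream reader directly.
	r, err := ipc.NewReader(bigquery.NewArrowIteratorReader(arrowIt))
	if err != nil {
		return err
	}
	defer r.Release()
	for r.Next() {
		rec := r.Record()
		fmt.Println(rec.NumRows()) // process rec; Retain() it to keep it past the next call to Next()
	}
	return nil
}
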
3 changes: 2 additions & 1 deletion bigquery/iterator.go
@@ -44,7 +44,8 @@ type RowIterator struct {
ctx context.Context
src *rowSource

arrowIterator *arrowIterator
arrowIterator ArrowIterator
arrowDecoder *arrowDecoder

pageInfo *iterator.PageInfo
nextFunc func() error
2 changes: 1 addition & 1 deletion bigquery/storage_bench_test.go
@@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
}
}
b.ReportMetric(float64(it.TotalRows), "rows")
bqSession := it.arrowIterator.session.bqSession
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
b.ReportMetric(float64(maxStreamCount), "max_streams")
}
93 changes: 88 additions & 5 deletions bigquery/storage_integration_test.go
@@ -22,6 +22,11 @@ import (
"time"

"cloud.google.com/go/internal/testutil"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/apache/arrow/go/v12/arrow/math"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/google/go-cmp/cmp"
"google.golang.org/api/iterator"
)
@@ -250,11 +255,12 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
}
total++ // as we read the first value separately

bqSession := it.arrowIterator.session.bqSession
session := it.arrowIterator.(*storageArrowIterator).session
bqSession := session.bqSession
if len(bqSession.Streams) == 0 {
t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
}
streamSettings := it.arrowIterator.session.settings.maxStreamCount
streamSettings := session.settings.maxStreamCount
if tc.maxExpectedStreams > 0 {
if streamSettings > tc.maxExpectedStreams {
t.Fatalf("%s: expected stream settings to be at most %d streams but found %d", tc.name, tc.maxExpectedStreams, streamSettings)
@@ -317,7 +323,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
total++
}

bqSession := it.arrowIterator.session.bqSession
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
if len(bqSession.Streams) == 0 {
t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
}
@@ -366,7 +372,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
}
total++ // as we read the first value separately

bqSession := it.arrowIterator.session.bqSession
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
if len(bqSession.Streams) == 0 {
t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
}
@@ -418,11 +424,88 @@ func TestIntegration_StorageReadCancel(t *testing.T) {
}
// resources are cleaned asynchronously
time.Sleep(time.Second)
if !it.arrowIterator.isDone() {
arrowIt := it.arrowIterator.(*storageArrowIterator)
if !arrowIt.isDone() {
t.Fatal("expected stream to be done")
}
}

func TestIntegration_StorageReadArrow(t *testing.T) {
if client == nil {
t.Skip("Integration tests skipped")
}
ctx := context.Background()
table := "`bigquery-public-data.usa_names.usa_1910_current`"
sql := fmt.Sprintf(`SELECT name, number, state FROM %s where state = "CA"`, table)

q := storageOptimizedClient.Query(sql)
job, err := q.Run(ctx) // force usage of Storage API by skipping fast paths
if err != nil {
t.Fatal(err)
}
it, err := job.Read(ctx)
if err != nil {
t.Fatal(err)
}

checkedAllocator := memory.NewCheckedAllocator(memory.DefaultAllocator)
it.arrowDecoder.allocator = checkedAllocator
defer checkedAllocator.AssertSize(t, 0)

arrowIt, err := it.ArrowIterator()
if err != nil {
t.Fatalf("expected iterator to be accelerated: %v", err)
}
arrowItReader := NewArrowIteratorReader(arrowIt)

records := []arrow.Record{}
r, err := ipc.NewReader(arrowItReader, ipc.WithAllocator(checkedAllocator))
numrec := 0
for r.Next() {
rec := r.Record()
rec.Retain()
defer rec.Release()
records = append(records, rec)
numrec += int(rec.NumRows())
}
r.Release()

arrowSchema := r.Schema()
arrowTable := array.NewTableFromRecords(arrowSchema, records)
defer arrowTable.Release()
if arrowTable.NumRows() != int64(it.TotalRows) {
t.Fatalf("should have a table with %d rows, but found %d", it.TotalRows, arrowTable.NumRows())
}
if arrowTable.NumCols() != 3 {
t.Fatalf("should have a table with 3 columns, but found %d", arrowTable.NumCols())
}

sumSQL := fmt.Sprintf(`SELECT sum(number) as total FROM %s where state = "CA"`, table)
sumQuery := client.Query(sumSQL)
sumIt, err := sumQuery.Read(ctx)
if err != nil {
t.Fatal(err)
}
sumValues := []Value{}
err = sumIt.Next(&sumValues)
if err != nil {
t.Fatal(err)
}
totalFromSQL := sumValues[0].(int64)

tr := array.NewTableReader(arrowTable, arrowTable.NumRows())
defer tr.Release()
var totalFromArrow int64
for tr.Next() {
rec := tr.Record()
vec := rec.Column(1).(*array.Int64)
totalFromArrow += math.Int64.Sum(vec)
}
if totalFromArrow != totalFromSQL {
t.Fatalf("expected total to be %d, but with arrow we got %d", totalFromSQL, totalFromArrow)
}
}

func countIteratorRows(it *RowIterator) (total uint64, err error) {
for {
var dst []Value
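
The new test also shows why arrowDecoder gained an allocator field: in-package tests can inject a leak-checking allocator and assert that every Arrow buffer is released. Below is a condensed sketch of that pattern, not part of the commit; the assertNoArrowLeaks helper is hypothetical and, because it touches the unexported arrowDecoder field, would only compile inside package bigquery alongside the test above.

package bigquery

import (
	"testing"

	"github.com/apache/arrow/go/v12/arrow/ipc"
	"github.com/apache/arrow/go/v12/arrow/memory"
)

// assertNoArrowLeaks drains it as Arrow records while verifying that all buffers
// allocated during decoding and IPC reading are released by the end of the test.
func assertNoArrowLeaks(t *testing.T, it *RowIterator) {
	checked := memory.NewCheckedAllocator(memory.DefaultAllocator)
	it.arrowDecoder.allocator = checked // route decoder allocations through the tracking allocator
	defer checked.AssertSize(t, 0)      // fails the test if any tracked bytes are still live

	arrowIt, err := it.ArrowIterator()
	if err != nil {
		t.Fatal(err)
	}
	r, err := ipc.NewReader(NewArrowIteratorReader(arrowIt), ipc.WithAllocator(checked))
	if err != nil {
		t.Fatal(err)
	}
	defer r.Release() // releasing the reader returns its buffers to the allocator
	for r.Next() {
		_ = r.Record() // records not explicitly retained are freed on the next Next/Release
	}
}
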
