diff --git a/firestore/client.go b/firestore/client.go index 28b4035ae58..6b76e9cde48 100644 --- a/firestore/client.go +++ b/firestore/client.go @@ -36,6 +36,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) // resourcePrefixHeader is the name of the metadata header used to indicate @@ -53,9 +54,10 @@ const DetectProjectID = "*detect-project-id*" // A Client provides access to the Firestore service. type Client struct { - c *vkit.Client - projectID string - databaseID string // A client is tied to a single database. + c *vkit.Client + projectID string + databaseID string // A client is tied to a single database. + readSettings *readSettings // readSettings allows setting a snapshot time to read the database } // NewClient creates a new Firestore client that uses the given project. @@ -94,9 +96,10 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio } vc.SetGoogleClientInfo("gccl", internal.Version) c := &Client{ - c: vc, - projectID: projectID, - databaseID: "(default)", // always "(default)", for now + c: vc, + projectID: projectID, + databaseID: "(default)", // always "(default)", for now + readSettings: &readSettings{}, } return c, nil } @@ -199,10 +202,10 @@ func (c *Client) GetAll(ctx context.Context, docRefs []*DocumentRef) (_ []*Docum ctx = trace.StartSpan(ctx, "cloud.google.com/go/firestore.GetAll") defer func() { trace.EndSpan(ctx, err) }() - return c.getAll(ctx, docRefs, nil) + return c.getAll(ctx, docRefs, nil, nil) } -func (c *Client) getAll(ctx context.Context, docRefs []*DocumentRef, tid []byte) (_ []*DocumentSnapshot, err error) { +func (c *Client) getAll(ctx context.Context, docRefs []*DocumentRef, tid []byte, rs *readSettings) (_ []*DocumentSnapshot, err error) { ctx = trace.StartSpan(ctx, "cloud.google.com/go/firestore.Client.BatchGetDocuments") defer func() { trace.EndSpan(ctx, err) }() @@ -219,9 +222,18 @@ func (c *Client) getAll(ctx context.Context, docRefs []*DocumentRef, tid []byte) Database: c.path(), Documents: docNames, } + + // Note that transaction ID and other consistency selectors are mutually exclusive. + // We respect the transaction first, any read options passed by the caller second, + // and any read options stored in the client third. + if rt, hasOpts := parseReadTime(c, rs); hasOpts { + req.ConsistencySelector = &pb.BatchGetDocumentsRequest_ReadTime{ReadTime: rt} + } + if tid != nil { - req.ConsistencySelector = &pb.BatchGetDocumentsRequest_Transaction{tid} + req.ConsistencySelector = &pb.BatchGetDocumentsRequest_Transaction{Transaction: tid} } + streamClient, err := c.c.BatchGetDocuments(withResourceHeader(ctx, req.Database), req) if err != nil { return nil, err @@ -306,6 +318,15 @@ func (c *Client) BulkWriter(ctx context.Context) *BulkWriter { return bw } +// WithReadOptions specifies constraints for accessing documents from the database, +// e.g. at what time snapshot to read the documents. +func (c *Client) WithReadOptions(opts ...ReadOption) *Client { + for _, ro := range opts { + ro.apply(c.readSettings) + } + return c +} + // commit calls the Commit RPC outside of a transaction. func (c *Client) commit(ctx context.Context, ws []*pb.Write) (_ []*WriteResult, err error) { ctx = trace.StartSpan(ctx, "cloud.google.com/go/firestore.Client.commit") @@ -381,3 +402,35 @@ func (ec emulatorCreds) GetRequestMetadata(ctx context.Context, uri ...string) ( func (ec emulatorCreds) RequireTransportSecurity() bool { return false } + +// ReadTime specifies a time-specific snapshot of the database to read. +func ReadTime(t time.Time) ReadOption { + return readTime(t) +} + +type readTime time.Time + +func (rt readTime) apply(rs *readSettings) { + rs.readTime = time.Time(rt) +} + +// ReadOption interface allows for abstraction of computing read time settings. +type ReadOption interface { + apply(*readSettings) +} + +// readSettings contains the ReadOptions for a read operation +type readSettings struct { + readTime time.Time +} + +// parseReadTime ensures that fallback order of read options is respected. +func parseReadTime(c *Client, rs *readSettings) (*timestamppb.Timestamp, bool) { + if rs != nil && !rs.readTime.IsZero() { + return ×tamppb.Timestamp{Seconds: int64(rs.readTime.Unix())}, true + } + if c.readSettings != nil && !c.readSettings.readTime.IsZero() { + return ×tamppb.Timestamp{Seconds: int64(c.readSettings.readTime.Unix())}, true + } + return nil, false +} diff --git a/firestore/client_test.go b/firestore/client_test.go index 9c028c87d76..42f7a42765c 100644 --- a/firestore/client_test.go +++ b/firestore/client_test.go @@ -17,16 +17,19 @@ package firestore import ( "context" "testing" + "time" tspb "github.com/golang/protobuf/ptypes/timestamp" pb "google.golang.org/genproto/googleapis/firestore/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) var testClient = &Client{ - projectID: "projectID", - databaseID: "(default)", + projectID: "projectID", + databaseID: "(default)", + readSettings: &readSettings{}, } func TestClientCollectionAndDoc(t *testing.T) { @@ -45,16 +48,18 @@ func TestClientCollectionAndDoc(t *testing.T) { path: "projects/projectID/databases/(default)/documents/X", parentPath: db + "/documents", }, + readSettings: &readSettings{}, } if !testEqual(coll1, wantc1) { t.Fatalf("got\n%+v\nwant\n%+v", coll1, wantc1) } doc1 := testClient.Doc("X/a") wantd1 := &DocumentRef{ - Parent: coll1, - ID: "a", - Path: "projects/projectID/databases/(default)/documents/X/a", - shortPath: "X/a", + Parent: coll1, + ID: "a", + Path: "projects/projectID/databases/(default)/documents/X/a", + shortPath: "X/a", + readSettings: &readSettings{}, } if !testEqual(doc1, wantd1) { @@ -309,3 +314,44 @@ func TestGetAllErrors(t *testing.T) { t.Error("got nil, want error") } } + +func TestClient_WithReadOptions(t *testing.T) { + ctx := context.Background() + c, srv, cleanup := newMock(t) + defer cleanup() + + const dbPath = "projects/projectID/databases/(default)" + const docPath = dbPath + "/documents/C/a" + tm := time.Date(2021, time.February, 20, 0, 0, 0, 0, time.UTC) + + dr := &DocumentRef{ + Parent: &CollectionRef{ + c: c, + }, + ID: "123", + Path: docPath, + } + + srv.addRPC(&pb.BatchGetDocumentsRequest{ + Database: dbPath, + Documents: []string{docPath}, + ConsistencySelector: &pb.BatchGetDocumentsRequest_ReadTime{ + ReadTime: ×tamppb.Timestamp{Seconds: tm.Unix()}, + }, + }, []interface{}{ + &pb.BatchGetDocumentsResponse{ + ReadTime: ×tamppb.Timestamp{Seconds: tm.Unix()}, + Result: &pb.BatchGetDocumentsResponse_Found{ + Found: &pb.Document{}, + }, + }, + }) + + _, err := c.WithReadOptions(ReadTime(tm)).GetAll(ctx, []*DocumentRef{ + dr, + }) + + if err != nil { + t.Fatal(err) + } +} diff --git a/firestore/collref.go b/firestore/collref.go index 73cecb66e20..0366f1e00e0 100644 --- a/firestore/collref.go +++ b/firestore/collref.go @@ -49,6 +49,10 @@ type CollectionRef struct { // Use the methods of Query on a CollectionRef to create and run queries. Query + + // readSettings specifies constraints for reading documents in the collection + // e.g. read time + readSettings *readSettings } func newTopLevelCollRef(c *Client, dbPath, id string) *CollectionRef { @@ -64,6 +68,7 @@ func newTopLevelCollRef(c *Client, dbPath, id string) *CollectionRef { path: dbPath + "/documents/" + id, parentPath: dbPath + "/documents", }, + readSettings: &readSettings{}, } } @@ -82,6 +87,7 @@ func newCollRefWithParent(c *Client, parent *DocumentRef, id string) *Collection path: parent.Path + "/" + id, parentPath: parent.Path, }, + readSettings: &readSettings{}, } } @@ -121,7 +127,7 @@ func (c *CollectionRef) Add(ctx context.Context, data interface{}) (*DocumentRef // missing documents. A missing document is a document that does not exist but has // sub-documents. func (c *CollectionRef) DocumentRefs(ctx context.Context) *DocumentRefIterator { - return newDocumentRefIterator(ctx, c, nil) + return newDocumentRefIterator(ctx, c, nil, c.readSettings) } const alphanum = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -136,3 +142,12 @@ func uniqueID() string { } return string(b) } + +// WithReadOptions specifies constraints for accessing documents from the database, +// e.g. at what time snapshot to read the documents. +func (c *CollectionRef) WithReadOptions(opts ...ReadOption) *CollectionRef { + for _, ro := range opts { + ro.apply(c.readSettings) + } + return c +} diff --git a/firestore/collref_test.go b/firestore/collref_test.go index 410bf4c6db1..26af6fb3218 100644 --- a/firestore/collref_test.go +++ b/firestore/collref_test.go @@ -17,19 +17,22 @@ package firestore import ( "context" "testing" + "time" "github.com/golang/protobuf/proto" pb "google.golang.org/genproto/googleapis/firestore/v1" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestDoc(t *testing.T) { coll := testClient.Collection("C") got := coll.Doc("d") want := &DocumentRef{ - Parent: coll, - ID: "d", - Path: "projects/projectID/databases/(default)/documents/C/d", - shortPath: "C/d", + Parent: coll, + ID: "d", + Path: "projects/projectID/databases/(default)/documents/C/d", + shortPath: "C/d", + readSettings: &readSettings{}, } if !testEqual(got, want) { t.Errorf("got %+v, want %+v", got, want) @@ -98,3 +101,35 @@ func TestNilErrors(t *testing.T) { t.Fatalf("got <%v>, want <%v>", err, errNilDocRef) } } + +func TestCollRef_WithReadOptions(t *testing.T) { + ctx := context.Background() + c, srv, cleanup := newMock(t) + defer cleanup() + + const dbPath = "projects/projectID/databases/(default)" + const docPath = dbPath + "/documents/C/a" + tm := time.Date(2021, time.February, 20, 0, 0, 0, 0, time.UTC) + + srv.addRPC(&pb.ListDocumentsRequest{ + Parent: dbPath, + CollectionId: "myCollection", + ShowMissing: true, + ConsistencySelector: &pb.ListDocumentsRequest_ReadTime{ + ReadTime: ×tamppb.Timestamp{Seconds: tm.Unix()}, + }, + }, []interface{}{ + &pb.ListDocumentsResponse{ + Documents: []*pb.Document{ + { + Name: docPath, + }, + }, + }, + }) + + _, err := c.Collection("myCollection").WithReadOptions(ReadTime(tm)).DocumentRefs(ctx).GetAll() + if err == nil { + t.Fatal(err) + } +} diff --git a/firestore/docref.go b/firestore/docref.go index 55d518d2e03..6dba46f9fca 100644 --- a/firestore/docref.go +++ b/firestore/docref.go @@ -47,14 +47,18 @@ type DocumentRef struct { // The ID of the document: the last component of the resource path. ID string + + // The options (only read time currently supported) for reading this document + readSettings *readSettings } func newDocRef(parent *CollectionRef, id string) *DocumentRef { return &DocumentRef{ - Parent: parent, - ID: id, - Path: parent.Path + "/" + id, - shortPath: parent.selfPath + "/" + id, + Parent: parent, + ID: id, + Path: parent.Path + "/" + id, + shortPath: parent.selfPath + "/" + id, + readSettings: &readSettings{}, } } @@ -77,7 +81,8 @@ func (d *DocumentRef) Get(ctx context.Context) (_ *DocumentSnapshot, err error) if d == nil { return nil, errNilDocRef } - docsnaps, err := d.Parent.c.getAll(ctx, []*DocumentRef{d}, nil) + + docsnaps, err := d.Parent.c.getAll(ctx, []*DocumentRef{d}, nil, d.readSettings) if err != nil { return nil, err } @@ -803,7 +808,7 @@ type DocumentSnapshotIterator struct { // Next is not expected to return iterator.Done unless it is called after Stop. // Rarely, networking issues may also cause iterator.Done to be returned. func (it *DocumentSnapshotIterator) Next() (*DocumentSnapshot, error) { - btree, _, readTime, err := it.ws.nextSnapshot() + btree, _, rt, err := it.ws.nextSnapshot() if err != nil { if err == io.EOF { err = iterator.Done @@ -812,7 +817,7 @@ func (it *DocumentSnapshotIterator) Next() (*DocumentSnapshot, error) { return nil, err } if btree.Len() == 0 { // document deleted - return &DocumentSnapshot{Ref: it.docref, ReadTime: readTime}, nil + return &DocumentSnapshot{Ref: it.docref, ReadTime: rt}, nil } snap, _ := btree.At(0) return snap.(*DocumentSnapshot), nil @@ -824,3 +829,12 @@ func (it *DocumentSnapshotIterator) Next() (*DocumentSnapshot, error) { func (it *DocumentSnapshotIterator) Stop() { it.ws.stop() } + +// WithReadOptions specifies constraints for accessing documents from the database, +// e.g. at what time snapshot to read the documents. +func (d *DocumentRef) WithReadOptions(opts ...ReadOption) *DocumentRef { + for _, ro := range opts { + ro.apply(d.readSettings) + } + return d +} diff --git a/firestore/docref_test.go b/firestore/docref_test.go index 6525ba9aee7..29c97f91678 100644 --- a/firestore/docref_test.go +++ b/firestore/docref_test.go @@ -21,10 +21,12 @@ import ( "testing" "time" + "google.golang.org/api/iterator" pb "google.golang.org/genproto/googleapis/firestore/v1" "google.golang.org/genproto/googleapis/type/latlng" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) var ( @@ -315,3 +317,67 @@ func TestUpdateProcess(t *testing.T) { } } } + +func TestDocRef_WithReadOptions(t *testing.T) { + ctx := context.Background() + c, srv, cleanup := newMock(t) + defer cleanup() + + const dbPath = "projects/projectID/databases/(default)" + const docPath = dbPath + "/documents/C/a" + tm := time.Date(2021, time.February, 20, 0, 0, 0, 0, time.UTC) + + srv.addRPC(&pb.ListDocumentsRequest{ + Parent: dbPath + "/documents", + CollectionId: "myCollection", + Mask: &pb.DocumentMask{}, + ShowMissing: true, + }, []interface{}{ + &pb.ListDocumentsResponse{ + Documents: []*pb.Document{ + { + Name: dbPath + "/documents/C/a", + CreateTime: ×tamppb.Timestamp{Seconds: 10}, + UpdateTime: ×tamppb.Timestamp{Seconds: 20}, + Fields: map[string]*pb.Value{"a": intval(1)}, + }, + }, + }, + }) + srv.addRPC(&pb.BatchGetDocumentsRequest{ + Database: dbPath, + Documents: []string{docPath}, + ConsistencySelector: &pb.BatchGetDocumentsRequest_ReadTime{ + ReadTime: ×tamppb.Timestamp{Seconds: tm.Unix()}, + }, + }, []interface{}{ + &pb.BatchGetDocumentsResponse{ + ReadTime: ×tamppb.Timestamp{Seconds: tm.Unix()}, + Result: &pb.BatchGetDocumentsResponse_Found{ + Found: &pb.Document{ + Name: dbPath + "/documents/C/a", + CreateTime: ×tamppb.Timestamp{Seconds: 10}, + UpdateTime: ×tamppb.Timestamp{Seconds: 20}, + Fields: map[string]*pb.Value{"a": intval(1)}, + }, + }, + }, + }) + + it := c.Collection("myCollection").DocumentRefs(ctx) + + for { + dr, err := it.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + _, err = dr.WithReadOptions(ReadTime(tm)).Get(ctx) + if err != nil { + t.Fatal(err) + } + } + +} diff --git a/firestore/from_value_test.go b/firestore/from_value_test.go index 3c6cd0ed8c8..99bd931edb9 100644 --- a/firestore/from_value_test.go +++ b/firestore/from_value_test.go @@ -86,7 +86,9 @@ func TestCreateFromProtoValue(t *testing.T) { parentPath: "projects/P/databases/D/documents", path: "projects/P/databases/D/documents/c", }, + readSettings: &readSettings{}, }, + readSettings: &readSettings{}, }, }, } { @@ -517,6 +519,7 @@ func TestPathToDoc(t *testing.T) { collectionID: "c2", parentPath: "projects/P/databases/D/documents/c1/d1", path: "projects/P/databases/D/documents/c1/d1/c2", + readSettings: &readSettings{}, }, Parent: &DocumentRef{ ID: "d1", @@ -535,9 +538,13 @@ func TestPathToDoc(t *testing.T) { parentPath: "projects/P/databases/D/documents", path: "projects/P/databases/D/documents/c1", }, + readSettings: &readSettings{}, }, + readSettings: &readSettings{}, }, + readSettings: &readSettings{}, }, + readSettings: &readSettings{}, } if !testEqual(got, want) { t.Errorf("\ngot %+v\nwant %+v", got, want) diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 2fe22b0f1ac..9ba1c8b0d65 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -1823,3 +1823,48 @@ func TestIntegration_CountAggregationQuery(t *testing.T) { t.Errorf("COUNT aggregation query mismatch;\ngot: %d, want: %d", cv.GetIntegerValue(), 2) } } + +func TestIntegration_ClientReadTime(t *testing.T) { + docs := []*DocumentRef{ + iColl.NewDoc(), + iColl.NewDoc(), + } + + c := integrationClient(t) + ctx := context.Background() + bw := c.BulkWriter(ctx) + jobs := make([]*BulkWriterJob, 0) + + // Populate the collection + f := integrationTestMap + for _, d := range docs { + j, err := bw.Create(d, f) + jobs = append(jobs, j) + if err != nil { + t.Fatal(err) + } + } + bw.End() + + for _, j := range jobs { + _, err := j.Results() + if err != nil { + t.Fatal(err) + } + } + + tm := time.Now() + c.WithReadOptions(ReadTime(tm)) + + ds, err := c.GetAll(ctx, docs) + if err != nil { + t.Fatal(err) + } + + for _, d := range ds { + if !tm.Equal(d.ReadTime) { + t.Errorf("wanted read time: %v; got: %v", + tm.UnixNano(), d.ReadTime.UnixNano()) + } + } +} diff --git a/firestore/list_documents.go b/firestore/list_documents.go index 712697f3d8b..58b64073822 100644 --- a/firestore/list_documents.go +++ b/firestore/list_documents.go @@ -33,7 +33,7 @@ type DocumentRefIterator struct { err error } -func newDocumentRefIterator(ctx context.Context, cr *CollectionRef, tid []byte) *DocumentRefIterator { +func newDocumentRefIterator(ctx context.Context, cr *CollectionRef, tid []byte, rs *readSettings) *DocumentRefIterator { ctx = trace.StartSpan(ctx, "cloud.google.com/go/firestore.ListDocuments") defer func() { trace.EndSpan(ctx, nil) }() @@ -44,8 +44,14 @@ func newDocumentRefIterator(ctx context.Context, cr *CollectionRef, tid []byte) ShowMissing: true, Mask: &pb.DocumentMask{}, // empty mask: we want only the ref } + + // Transactions and ReadTime are mutually exclusive; Transactions should be + // respected before read time. + if rt, hasOpts := parseReadTime(client, rs); hasOpts { + req.ConsistencySelector = &pb.ListDocumentsRequest_ReadTime{ReadTime: rt} + } if tid != nil { - req.ConsistencySelector = &pb.ListDocumentsRequest_Transaction{tid} + req.ConsistencySelector = &pb.ListDocumentsRequest_Transaction{Transaction: tid} } it := &DocumentRefIterator{ client: client, diff --git a/firestore/mock_test.go b/firestore/mock_test.go index 34fa3582b04..ee90fc7e8a8 100644 --- a/firestore/mock_test.go +++ b/firestore/mock_test.go @@ -166,6 +166,25 @@ func (s *mockServer) BatchGetDocuments(req *pb.BatchGetDocumentsRequest, bs pb.F return nil } +func (s *mockServer) ListDocuments(ctx context.Context, req *pb.ListDocumentsRequest) (*pb.ListDocumentsResponse, error) { + res, err := s.popRPC(req) + if err != nil { + return nil, err + } + responses := res.([]interface{}) + for _, res := range responses { + switch res := res.(type) { + case *pb.ListDocumentsResponse: + return res, nil + case error: + return nil, res + default: + panic(fmt.Sprintf("bad response type in ListDocuments: %+v", res)) + } + } + return nil, nil +} + func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryServer) error { res, err := s.popRPC(req) if err != nil { diff --git a/firestore/query.go b/firestore/query.go index 218738b7a35..ccbdff4e65a 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -54,6 +54,10 @@ type Query struct { // allDescendants indicates whether this query is for all collections // that match the ID under the specified parentPath. allDescendants bool + + // readOptions specifies constraints for reading results from the query + // e.g. read time + readSettings *readSettings } // DocumentID is the special field name representing the ID of a document @@ -787,7 +791,7 @@ func trunc32(i int) int32 { // Documents returns an iterator over the query's resulting documents. func (q Query) Documents(ctx context.Context) *DocumentIterator { return &DocumentIterator{ - iter: newQueryDocumentIterator(withResourceHeader(ctx, q.c.path()), &q, nil), q: &q, + iter: newQueryDocumentIterator(withResourceHeader(ctx, q.c.path()), &q, nil, q.readSettings), q: &q, } } @@ -892,15 +896,17 @@ type queryDocumentIterator struct { q *Query tid []byte // transaction ID, if any streamClient pb.Firestore_RunQueryClient + readSettings *readSettings // readOptions, if any } -func newQueryDocumentIterator(ctx context.Context, q *Query, tid []byte) *queryDocumentIterator { +func newQueryDocumentIterator(ctx context.Context, q *Query, tid []byte, rs *readSettings) *queryDocumentIterator { ctx, cancel := context.WithCancel(ctx) return &queryDocumentIterator{ - ctx: ctx, - cancel: cancel, - q: q, - tid: tid, + ctx: ctx, + cancel: cancel, + q: q, + tid: tid, + readSettings: rs, } } @@ -916,10 +922,15 @@ func (it *queryDocumentIterator) next() (_ *DocumentSnapshot, err error) { } req := &pb.RunQueryRequest{ Parent: it.q.parentPath, - QueryType: &pb.RunQueryRequest_StructuredQuery{sq}, + QueryType: &pb.RunQueryRequest_StructuredQuery{StructuredQuery: sq}, + } + + // Respect transactions first and read options (read time) second + if rt, hasOpts := parseReadTime(client, it.readSettings); hasOpts { + req.ConsistencySelector = &pb.RunQueryRequest_ReadTime{ReadTime: rt} } if it.tid != nil { - req.ConsistencySelector = &pb.RunQueryRequest_Transaction{it.tid} + req.ConsistencySelector = &pb.RunQueryRequest_Transaction{Transaction: it.tid} } it.streamClient, err = client.c.RunQuery(it.ctx, req) if err != nil { @@ -1045,6 +1056,15 @@ func (it *btreeDocumentIterator) next() (*DocumentSnapshot, error) { func (*btreeDocumentIterator) stop() {} +// WithReadOptions specifies constraints for accessing documents from the database, +// e.g. at what time snapshot to read the documents. +func (q *Query) WithReadOptions(opts ...ReadOption) *Query { + for _, ro := range opts { + ro.apply(q.readSettings) + } + return q +} + // AggregationQuery allows for generating aggregation results of an underlying // basic query. A single AggregationQuery can contain multiple aggregations. type AggregationQuery struct { diff --git a/firestore/transaction.go b/firestore/transaction.go index 4cda4704625..8f766b621af 100644 --- a/firestore/transaction.go +++ b/firestore/transaction.go @@ -34,6 +34,7 @@ type Transaction struct { maxAttempts int readOnly bool readAfterWrite bool + readSettings *readSettings } // A TransactionOption is an option passed to Client.Transaction. @@ -98,9 +99,10 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr } db := c.path() t := &Transaction{ - c: c, - ctx: withResourceHeader(ctx, db), - maxAttempts: DefaultTransactionMaxAttempts, + c: c, + ctx: withResourceHeader(ctx, db), + maxAttempts: DefaultTransactionMaxAttempts, + readSettings: &readSettings{}, } for _, opt := range opts { opt.config(t) @@ -220,7 +222,7 @@ func (t *Transaction) GetAll(drs []*DocumentRef) ([]*DocumentSnapshot, error) { t.readAfterWrite = true return nil, errReadAfterWrite } - return t.c.getAll(t.ctx, drs, t.id) + return t.c.getAll(t.ctx, drs, t.id, t.readSettings) } // A Queryer is a Query or a CollectionRef. CollectionRefs act as queries whose @@ -238,7 +240,7 @@ func (t *Transaction) Documents(q Queryer) *DocumentIterator { } query := q.query() return &DocumentIterator{ - iter: newQueryDocumentIterator(t.ctx, query, t.id), q: query, + iter: newQueryDocumentIterator(t.ctx, query, t.id, t.readSettings), q: query, } } @@ -250,7 +252,7 @@ func (t *Transaction) DocumentRefs(cr *CollectionRef) *DocumentRefIterator { t.readAfterWrite = true return &DocumentRefIterator{err: errReadAfterWrite} } - return newDocumentRefIterator(t.ctx, cr, t.id) + return newDocumentRefIterator(t.ctx, cr, t.id, t.readSettings) } // Create adds a Create operation to the Transaction. @@ -287,3 +289,12 @@ func (t *Transaction) addWrites(ws []*pb.Write, err error) error { t.writes = append(t.writes, ws...) return nil } + +// WithReadOptions specifies constraints for accessing documents from the database, +// e.g. at what time snapshot to read the documents. +func (t *Transaction) WithReadOptions(opts ...ReadOption) *Transaction { + for _, ro := range opts { + ro.apply(t.readSettings) + } + return t +} diff --git a/firestore/transaction_test.go b/firestore/transaction_test.go index 11b62be2864..2aeb5794ba6 100644 --- a/firestore/transaction_test.go +++ b/firestore/transaction_test.go @@ -17,12 +17,14 @@ package firestore import ( "context" "testing" + "time" "github.com/golang/protobuf/ptypes/empty" "google.golang.org/api/iterator" pb "google.golang.org/genproto/googleapis/firestore/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestRunTransaction(t *testing.T) { @@ -500,3 +502,44 @@ func TestRunTransaction_NonTransactionalOp(t *testing.T) { t.Fatal(err) } } + +func TestTransaction_WithReadOptions(t *testing.T) { + ctx := context.Background() + c, srv, cleanup := newMock(t) + defer cleanup() + + const db = "projects/projectID/databases/(default)" + tm := time.Date(2021, time.February, 20, 0, 0, 0, 0, time.UTC) + ts := ×tamppb.Timestamp{Nanos: int32(tm.UnixNano())} + tid := []byte{1} + + beginReq := &pb.BeginTransactionRequest{Database: db} + beginRes := &pb.BeginTransactionResponse{Transaction: tid} + + srv.reset() + srv.addRPC(beginReq, beginRes) + + srv.addRPC( + &pb.CommitRequest{ + Database: db, + Transaction: tid, + }, + &pb.CommitResponse{CommitTime: ts}, + ) + + srv.addRPC( + &pb.CommitRequest{ + Database: db, + Transaction: tid, + }, + &pb.CommitResponse{CommitTime: ts}, + ) + + if err := c.RunTransaction(ctx, func(ctx2 context.Context, tx *Transaction) error { + docref := c.Collection("C").Doc("a") + tx.WithReadOptions(ReadTime(tm)).Get(docref) + return nil + }); err != nil { + t.Fatal(err) + } +} diff --git a/firestore/util_test.go b/firestore/util_test.go index 2e22aaa694a..b7ff4fb8d78 100644 --- a/firestore/util_test.go +++ b/firestore/util_test.go @@ -49,9 +49,12 @@ func mustTimestampProto(t time.Time) *tspb.Timestamp { } var cmpOpts = []cmp.Option{ - cmp.AllowUnexported(DocumentRef{}, CollectionRef{}, DocumentSnapshot{}, - Query{}, filter{}, order{}, fpv{}), + cmp.AllowUnexported(DocumentSnapshot{}, + Query{}, filter{}, order{}, fpv{}, DocumentRef{}, CollectionRef{}, Query{}), cmpopts.IgnoreTypes(Client{}, &Client{}), + cmp.Comparer(func(*readSettings, *readSettings) bool { + return true // Don't try to compare two readSettings pointer types + }), } // testEqual implements equality for Firestore tests.