From 9424cb5150815e22e0db95fbb02f1b4ef9969c88 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 4 Apr 2024 13:56:44 -0600 Subject: [PATCH 1/9] GODRIVER-2800 Remove the Session interface --- internal/docexamples/examples.go | 26 +-- internal/integration/client_test.go | 3 +- internal/integration/crud_helpers_test.go | 60 +++---- internal/integration/mongos_pinning_test.go | 27 ++-- internal/integration/sessions_test.go | 13 +- internal/integration/unified/entity.go | 6 +- .../unified/testrunner_operation.go | 4 +- internal/integration/unified_spec_test.go | 22 ++- mongo/client.go | 12 +- mongo/client_test.go | 2 +- mongo/crud_examples_test.go | 9 +- mongo/session.go | 151 +++++------------- mongo/with_transactions_test.go | 2 +- 13 files changed, 133 insertions(+), 204 deletions(-) diff --git a/internal/docexamples/examples.go b/internal/docexamples/examples.go index b08447c15c..7e43919cb6 100644 --- a/internal/docexamples/examples.go +++ b/internal/docexamples/examples.go @@ -1760,7 +1760,9 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error { events := client.Database("reporting").Collection("events") return client.UseSession(ctx, func(sctx mongo.SessionContext) error { - err := sctx.StartTransaction(options.Transaction(). + sess := mongo.SessionFromContext(sctx) + + err := sess.StartTransaction(options.Transaction(). SetReadConcern(readconcern.Snapshot()). SetWriteConcern(writeconcern.Majority()), ) @@ -1770,19 +1772,19 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error { _, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } _, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } for { - err = sctx.CommitTransaction(sctx) + err = sess.CommitTransaction(sctx) switch e := err.(type) { case nil: return nil @@ -1830,8 +1832,10 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session // CommitWithRetry is an example function demonstrating transaction commit with retry logic. func CommitWithRetry(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) + for { - err := sctx.CommitTransaction(sctx) + err := sess.CommitTransaction(sctx) switch e := err.(type) { case nil: log.Println("Transaction committed.") @@ -1892,8 +1896,10 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { } commitWithRetry := func(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) + for { - err := sctx.CommitTransaction(sctx) + err := sess.CommitTransaction(sctx) switch e := err.(type) { case nil: log.Println("Transaction committed.") @@ -1918,7 +1924,9 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { employees := client.Database("hr").Collection("employees") events := client.Database("reporting").Collection("events") - err := sctx.StartTransaction(options.Transaction(). + sess := mongo.SessionFromContext(sctx) + + err := sess.StartTransaction(options.Transaction(). SetReadConcern(readconcern.Snapshot()). SetWriteConcern(writeconcern.Majority()), ) @@ -1928,13 +1936,13 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { _, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } _, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 8350db58e0..1aa50a5bf5 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -374,8 +374,7 @@ func TestClient(t *testing.T) { sess, err := mt.Client.StartSession(tc.opts) assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - xs := sess.(mongo.XSession) - consistent := xs.ClientSession().Consistent + consistent := sess.ClientSession().Consistent assert.Equal(mt, tc.consistent, consistent, "expected consistent to be %v, got %v", tc.consistent, consistent) }) } diff --git a/internal/integration/crud_helpers_test.go b/internal/integration/crud_helpers_test.go index 80abf29231..0677e6489d 100644 --- a/internal/integration/crud_helpers_test.go +++ b/internal/integration/crud_helpers_test.go @@ -158,7 +158,7 @@ type watcher interface { Watch(context.Context, interface{}, ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) } -func executeAggregate(mt *mtest.T, agg aggregator, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeAggregate(mt *mtest.T, agg aggregator, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() var pipeline []interface{} @@ -198,7 +198,7 @@ func executeAggregate(mt *mtest.T, agg aggregator, sess mongo.Session, args bson return agg.Aggregate(context.Background(), pipeline, opts) } -func executeWatch(mt *mtest.T, w watcher, sess mongo.Session, args bson.Raw) (*mongo.ChangeStream, error) { +func executeWatch(mt *mtest.T, w watcher, sess *mongo.Session, args bson.Raw) (*mongo.ChangeStream, error) { mt.Helper() pipeline := []interface{}{} @@ -227,7 +227,7 @@ func executeWatch(mt *mtest.T, w watcher, sess mongo.Session, args bson.Raw) (*m return w.Watch(context.Background(), pipeline) } -func executeCountDocuments(mt *mtest.T, sess mongo.Session, args bson.Raw) (int64, error) { +func executeCountDocuments(mt *mtest.T, sess *mongo.Session, args bson.Raw) (int64, error) { mt.Helper() filter := emptyDoc @@ -265,7 +265,7 @@ func executeCountDocuments(mt *mtest.T, sess mongo.Session, args bson.Raw) (int6 return mt.Coll.CountDocuments(context.Background(), filter, opts) } -func executeInsertOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.InsertOneResult, error) { +func executeInsertOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.InsertOneResult, error) { mt.Helper() doc := emptyDoc @@ -299,7 +299,7 @@ func executeInsertOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.In return mt.Coll.InsertOne(context.Background(), doc, opts) } -func executeInsertMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.InsertManyResult, error) { +func executeInsertMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.InsertManyResult, error) { mt.Helper() var docs []interface{} @@ -362,7 +362,7 @@ func setFindModifiers(modifiersDoc bson.Raw, opts *options.FindOptions) { } } -func executeFind(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeFind(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() filter := emptyDoc @@ -410,7 +410,7 @@ func executeFind(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, return mt.Coll.Find(context.Background(), filter, opts) } -func executeRunCommand(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeRunCommand(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() cmd := emptyDoc @@ -443,7 +443,7 @@ func executeRunCommand(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.Si return mt.DB.RunCommand(context.Background(), cmd, opts) } -func executeListCollections(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeListCollections(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() filter := emptyDoc @@ -472,7 +472,7 @@ func executeListCollections(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mo return mt.DB.ListCollections(context.Background(), filter) } -func executeListCollectionNames(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]string, error) { +func executeListCollectionNames(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]string, error) { mt.Helper() filter := emptyDoc @@ -501,7 +501,7 @@ func executeListCollectionNames(mt *mtest.T, sess mongo.Session, args bson.Raw) return mt.DB.ListCollectionNames(context.Background(), filter) } -func executeListDatabaseNames(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]string, error) { +func executeListDatabaseNames(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]string, error) { mt.Helper() filter := emptyDoc @@ -530,7 +530,7 @@ func executeListDatabaseNames(mt *mtest.T, sess mongo.Session, args bson.Raw) ([ return mt.Client.ListDatabaseNames(context.Background(), filter) } -func executeListDatabases(mt *mtest.T, sess mongo.Session, args bson.Raw) (mongo.ListDatabasesResult, error) { +func executeListDatabases(mt *mtest.T, sess *mongo.Session, args bson.Raw) (mongo.ListDatabasesResult, error) { mt.Helper() filter := emptyDoc @@ -559,7 +559,7 @@ func executeListDatabases(mt *mtest.T, sess mongo.Session, args bson.Raw) (mongo return mt.Client.ListDatabases(context.Background(), filter) } -func executeFindOne(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -587,7 +587,7 @@ func executeFindOne(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.Singl return mt.Coll.FindOne(context.Background(), filter) } -func executeListIndexes(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeListIndexes(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() // no arguments expected. add a Fatal in case arguments are added in the future @@ -604,7 +604,7 @@ func executeListIndexes(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo. return mt.Coll.Indexes().List(context.Background()) } -func executeDistinct(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]interface{}, error) { +func executeDistinct(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]interface{}, error) { mt.Helper() var fieldName string @@ -641,7 +641,7 @@ func executeDistinct(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]interfac return mt.Coll.Distinct(context.Background(), fieldName, filter, opts) } -func executeFindOneAndDelete(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOneAndDelete(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -680,7 +680,7 @@ func executeFindOneAndDelete(mt *mtest.T, sess mongo.Session, args bson.Raw) *mo return mt.Coll.FindOneAndDelete(context.Background(), filter, opts) } -func executeFindOneAndUpdate(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOneAndUpdate(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -737,7 +737,7 @@ func executeFindOneAndUpdate(mt *mtest.T, sess mongo.Session, args bson.Raw) *mo return mt.Coll.FindOneAndUpdate(context.Background(), filter, update, opts) } -func executeFindOneAndReplace(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOneAndReplace(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -790,7 +790,7 @@ func executeFindOneAndReplace(mt *mtest.T, sess mongo.Session, args bson.Raw) *m return mt.Coll.FindOneAndReplace(context.Background(), filter, replacement, opts) } -func executeDeleteOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { +func executeDeleteOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { mt.Helper() filter := emptyDoc @@ -826,7 +826,7 @@ func executeDeleteOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.De return mt.Coll.DeleteOne(context.Background(), filter, opts) } -func executeDeleteMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { +func executeDeleteMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { mt.Helper() filter := emptyDoc @@ -862,7 +862,7 @@ func executeDeleteMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.D return mt.Coll.DeleteMany(context.Background(), filter, opts) } -func executeUpdateOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { +func executeUpdateOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { mt.Helper() filter := emptyDoc @@ -910,7 +910,7 @@ func executeUpdateOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Up return mt.Coll.UpdateOne(context.Background(), filter, update, opts) } -func executeUpdateMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { +func executeUpdateMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { mt.Helper() filter := emptyDoc @@ -958,7 +958,7 @@ func executeUpdateMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.U return mt.Coll.UpdateMany(context.Background(), filter, update, opts) } -func executeReplaceOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { +func executeReplaceOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { mt.Helper() filter := emptyDoc @@ -1009,7 +1009,7 @@ type withTransactionArgs struct { Options bson.Raw `bson:"options"` } -func runWithTransactionOperations(mt *mtest.T, operations []*operation, sess mongo.Session) error { +func runWithTransactionOperations(mt *mtest.T, operations []*operation, sess *mongo.Session) error { mt.Helper() for _, op := range operations { @@ -1037,7 +1037,7 @@ func runWithTransactionOperations(mt *mtest.T, operations []*operation, sess mon return nil } -func executeWithTransaction(mt *mtest.T, sess mongo.Session, args bson.Raw) error { +func executeWithTransaction(mt *mtest.T, sess *mongo.Session, args bson.Raw) error { mt.Helper() var testArgs withTransactionArgs @@ -1052,7 +1052,7 @@ func executeWithTransaction(mt *mtest.T, sess mongo.Session, args bson.Raw) erro return err } -func executeBulkWrite(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.BulkWriteResult, error) { +func executeBulkWrite(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.BulkWriteResult, error) { mt.Helper() models := createBulkWriteModels(mt, args.Lookup("requests").Array()) @@ -1196,7 +1196,7 @@ func createBulkWriteModel(mt *mtest.T, rawModel bson.Raw) mongo.WriteModel { return nil } -func executeEstimatedDocumentCount(mt *mtest.T, sess mongo.Session, args bson.Raw) (int64, error) { +func executeEstimatedDocumentCount(mt *mtest.T, sess *mongo.Session, args bson.Raw) (int64, error) { mt.Helper() // no arguments expected. add a Fatal in case arguments are added in the future @@ -1255,7 +1255,7 @@ func executeGridFSDownloadByName(mt *mtest.T, bucket *mongo.GridFSBucket, args b return bucket.DownloadToStreamByName(context.Background(), file, new(bytes.Buffer)) } -func executeCreateIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (string, error) { +func executeCreateIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (string, error) { mt.Helper() model := mongo.IndexModel{ @@ -1289,7 +1289,7 @@ func executeCreateIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (string, return mt.Coll.Indexes().CreateOne(context.Background(), model) } -func executeDropIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (bson.Raw, error) { +func executeDropIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (bson.Raw, error) { mt.Helper() var name string @@ -1318,7 +1318,7 @@ func executeDropIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (bson.Raw, return mt.Coll.Indexes().DropOne(context.Background(), name) } -func executeDropCollection(mt *mtest.T, sess mongo.Session, args bson.Raw) error { +func executeDropCollection(mt *mtest.T, sess *mongo.Session, args bson.Raw) error { mt.Helper() var collName string @@ -1348,7 +1348,7 @@ func executeDropCollection(mt *mtest.T, sess mongo.Session, args bson.Raw) error return coll.Drop(context.Background(), dco) } -func executeCreateCollection(mt *mtest.T, sess mongo.Session, args bson.Raw) error { +func executeCreateCollection(mt *mtest.T, sess *mongo.Session, args bson.Raw) error { mt.Helper() cco := options.CreateCollection() diff --git a/internal/integration/mongos_pinning_test.go b/internal/integration/mongos_pinning_test.go index 06b31762c9..f91f16018e 100644 --- a/internal/integration/mongos_pinning_test.go +++ b/internal/integration/mongos_pinning_test.go @@ -31,21 +31,22 @@ func TestMongosPinning(t *testing.T) { mt.Run("unpin for next transaction", func(mt *mtest.T) { addresses := map[string]struct{}{} - _ = mt.Client.UseSession(context.Background(), func(sc mongo.SessionContext) error { + _ = mt.Client.UseSession(context.Background(), func(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) // Insert a document in a transaction to pin session to a mongos - err := sc.StartTransaction() + err := sess.StartTransaction() assert.Nil(mt, err, "StartTransaction error: %v", err) - _, err = mt.Coll.InsertOne(sc, bson.D{{"x", 1}}) + _, err = mt.Coll.InsertOne(sctx, bson.D{{"x", 1}}) assert.Nil(mt, err, "InsertOne error: %v", err) - err = sc.CommitTransaction(sc) + err = sess.CommitTransaction(sctx) assert.Nil(mt, err, "CommitTransaction error: %v", err) for i := 0; i < 50; i++ { // Call Find in a new transaction to unpin from the old mongos and select a new one - err = sc.StartTransaction() + err = sess.StartTransaction() assert.Nil(mt, err, iterationErrmsg("StartTransaction", i, err)) - cursor, err := mt.Coll.Find(sc, bson.D{}) + cursor, err := mt.Coll.Find(sctx, bson.D{}) assert.Nil(mt, err, iterationErrmsg("Find", i, err)) assert.True(mt, cursor.Next(context.Background()), "Next returned false on iteration %v", i) @@ -55,7 +56,7 @@ func TestMongosPinning(t *testing.T) { err = descConn.Close() assert.Nil(mt, err, iterationErrmsg("connection Close", i, err)) - err = sc.CommitTransaction(sc) + err = sess.CommitTransaction(sctx) assert.Nil(mt, err, iterationErrmsg("CommitTransaction", i, err)) } return nil @@ -64,18 +65,20 @@ func TestMongosPinning(t *testing.T) { }) mt.Run("unpin for non transaction operation", func(mt *mtest.T) { addresses := map[string]struct{}{} - _ = mt.Client.UseSession(context.Background(), func(sc mongo.SessionContext) error { + _ = mt.Client.UseSession(context.Background(), func(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) + // Insert a document in a transaction to pin session to a mongos - err := sc.StartTransaction() + err := sess.StartTransaction() assert.Nil(mt, err, "StartTransaction error: %v", err) - _, err = mt.Coll.InsertOne(sc, bson.D{{"x", 1}}) + _, err = mt.Coll.InsertOne(sctx, bson.D{{"x", 1}}) assert.Nil(mt, err, "InsertOne error: %v", err) - err = sc.CommitTransaction(sc) + err = sess.CommitTransaction(sctx) assert.Nil(mt, err, "CommitTransaction error: %v", err) for i := 0; i < 50; i++ { // Call Find with the session but outside of a transaction - cursor, err := mt.Coll.Find(sc, bson.D{}) + cursor, err := mt.Coll.Find(sctx, bson.D{}) assert.Nil(mt, err, iterationErrmsg("Find", i, err)) assert.True(mt, cursor.Next(context.Background()), "Next returned false on iteration %v", i) diff --git a/internal/integration/sessions_test.go b/internal/integration/sessions_test.go index a9ae56cba9..a95af5b15c 100644 --- a/internal/integration/sessions_test.go +++ b/internal/integration/sessions_test.go @@ -35,14 +35,14 @@ func TestSessionPool(t *testing.T) { sess, err := mt.Client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - initialLastUsedTime := getSessionLastUsedTime(mt, sess) + initialLastUsedTime := sess.ClientSession().LastUsed err = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { return mt.Client.Ping(sc, readpref.Primary()) }) assert.Nil(mt, err, "WithSession error: %v", err) - newLastUsedTime := getSessionLastUsedTime(mt, sess) + newLastUsedTime := sess.ClientSession().LastUsed assert.True(mt, newLastUsedTime.After(initialLastUsedTime), "last used time %s is not after the initial last used time %s", newLastUsedTime, initialLastUsedTime) }) @@ -63,7 +63,6 @@ func TestSessions(t *testing.T) { defer sess.EndSession(context.Background()) ctx := mongo.NewSessionContext(context.Background(), sess) - assert.Equal(mt, sess.ID(), ctx.ID(), "expected Session ID %v, got %v", sess.ID(), ctx.ID()) gotSess := mongo.SessionFromContext(ctx) assert.NotNil(mt, gotSess, "expected SessionFromContext to return non-nil value, got nil") @@ -513,7 +512,7 @@ type sessionFunction struct { params []interface{} // should not include context } -func (sf sessionFunction) execute(mt *mtest.T, sess mongo.Session) error { +func (sf sessionFunction) execute(mt *mtest.T, sess *mongo.Session) error { var target reflect.Value switch sf.target { case "client": @@ -639,9 +638,3 @@ func extractSentSessionID(mt *mtest.T) []byte { _, data := lsid.Document().Lookup("id").Binary() return data } - -func getSessionLastUsedTime(mt *mtest.T, sess mongo.Session) time.Time { - xsess, ok := sess.(mongo.XSession) - assert.True(mt, ok, "expected session to implement mongo.XSession, but got %T", sess) - return xsess.ClientSession().LastUsed -} diff --git a/internal/integration/unified/entity.go b/internal/integration/unified/entity.go index 75bbee6035..873a828da7 100644 --- a/internal/integration/unified/entity.go +++ b/internal/integration/unified/entity.go @@ -191,7 +191,7 @@ type EntityMap struct { clientEntities map[string]*clientEntity dbEntites map[string]*mongo.Database collEntities map[string]*mongo.Collection - sessions map[string]mongo.Session + sessions map[string]*mongo.Session gridfsBuckets map[string]*mongo.GridFSBucket bsonValues map[string]bson.RawValue eventListEntities map[string][]bson.Raw @@ -225,7 +225,7 @@ func newEntityMap() *EntityMap { clientEntities: make(map[string]*clientEntity), collEntities: make(map[string]*mongo.Collection), dbEntites: make(map[string]*mongo.Database), - sessions: make(map[string]mongo.Session), + sessions: make(map[string]*mongo.Session), eventListEntities: make(map[string][]bson.Raw), bsonArrayEntities: make(map[string][]bson.Raw), successValues: make(map[string]int32), @@ -422,7 +422,7 @@ func (em *EntityMap) database(id string) (*mongo.Database, error) { return db, nil } -func (em *EntityMap) session(id string) (mongo.Session, error) { +func (em *EntityMap) session(id string) (*mongo.Session, error) { sess, ok := em.sessions[id] if !ok { return nil, newEntityNotFoundError("session", id) diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index 38f81bfed3..1079f33840 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -443,8 +443,8 @@ func waitForEvent(ctx context.Context, args waitForEventArguments) error { } } -func extractClientSession(sess mongo.Session) *session.Client { - return sess.(mongo.XSession).ClientSession() +func extractClientSession(sess *mongo.Session) *session.Client { + return sess.ClientSession() } func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPinned bool) error { diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index c62de30698..97f854f7e7 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -366,12 +366,12 @@ func createBucket(mt *mtest.T, testFile testFile, testCase *testCase) { testCase.bucket = mt.DB.GridFSBucket(bucketOpts) } -func runOperation(mt *mtest.T, testCase *testCase, op *operation, sess0, sess1 mongo.Session) error { +func runOperation(mt *mtest.T, testCase *testCase, op *operation, sess0, sess1 *mongo.Session) error { if op.Name == "count" { mt.Skip("count has been deprecated") } - var sess mongo.Session + var sess *mongo.Session if sessVal, err := op.Arguments.LookupErr("session"); err == nil { sessStr := sessVal.StringValue() switch sessStr { @@ -442,14 +442,10 @@ func executeGridFSOperation(mt *mtest.T, bucket *mongo.GridFSBucket, op *operati return nil } -func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, sess mongo.Session) error { +func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, sess *mongo.Session) error { var clientSession *session.Client if sess != nil { - xsess, ok := sess.(mongo.XSession) - if !ok { - return fmt.Errorf("expected session type %T to implement mongo.XSession", sess) - } - clientSession = xsess.ClientSession() + clientSession = sess.ClientSession() } switch op.Name { @@ -635,7 +631,7 @@ func lastTwoIDs(mt *mtest.T) (bson.RawValue, bson.RawValue) { return first, second } -func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeSessionOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "startTransaction": var txnOpts *options.TransactionOptions @@ -654,7 +650,7 @@ func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) err } } -func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeCollectionOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "countDocuments": // no results to verify with count @@ -798,7 +794,7 @@ func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) return nil } -func executeDatabaseOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeDatabaseOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "runCommand": res := executeRunCommand(mt, sess, op.Arguments) @@ -853,7 +849,7 @@ func executeDatabaseOperation(mt *mtest.T, op *operation, sess mongo.Session) er return nil } -func executeClientOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeClientOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "listDatabaseNames": _, err := executeListDatabaseNames(mt, sess, op.Arguments) @@ -882,7 +878,7 @@ func executeClientOperation(mt *mtest.T, op *operation, sess mongo.Session) erro return nil } -func setupSessions(mt *mtest.T, test *testCase) (mongo.Session, mongo.Session) { +func setupSessions(mt *mtest.T, test *testCase) (*mongo.Session, *mongo.Session) { mt.Helper() var sess0Opts, sess1Opts *options.SessionOptions diff --git a/mongo/client.go b/mongo/client.go index 40c0b3c411..0bae485a22 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -374,7 +374,7 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { // // If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, the client's read // concern, write concern, or read preference will be used, respectively. -func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) { +func (c *Client) StartSession(opts ...*options.SessionOptions) (*Session, error) { if c.sessionPool == nil { return nil, ErrClientDisconnected } @@ -439,7 +439,7 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) sess.RetryWrite = false sess.RetryRead = c.retryReads - return &sessionImpl{ + return &Session{ clientSession: sess, client: c, deployment: c.deployment, @@ -786,7 +786,7 @@ func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts // If the ctx parameter already contains a Session, that Session will be replaced with the one provided. // // Any error returned by the fn callback will be returned without any modifications. -func WithSession(ctx context.Context, sess Session, fn func(SessionContext) error) error { +func WithSession(ctx context.Context, sess *Session, fn func(SessionContext) error) error { return fn(NewSessionContext(ctx, sess)) } @@ -809,7 +809,11 @@ func (c *Client) UseSession(ctx context.Context, fn func(SessionContext) error) // // UseSessionWithOptions is safe to call from multiple goroutines concurrently. However, the SessionContext passed to // the UseSessionWithOptions callback function is not safe for concurrent use by multiple goroutines. -func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.SessionOptions, fn func(SessionContext) error) error { +func (c *Client) UseSessionWithOptions( + ctx context.Context, + opts *options.SessionOptions, + fn func(SessionContext) error, +) error { defaultSess, err := c.StartSession(opts) if err != nil { return err diff --git a/mongo/client_test.go b/mongo/client_test.go index ddba3062fe..e5d08642b3 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -380,7 +380,7 @@ func TestClient(t *testing.T) { // Do an application operation and create the number of sessions specified by the test. _, err = coll.CountDocuments(bgCtx, bson.D{}) assert.Nil(t, err, "CountDocuments error: %v", err) - var sessions []Session + var sessions []*Session for i := 0; i < tc.numSessions; i++ { sess, err := client.StartSession() assert.Nil(t, err, "StartSession error at index %d: %v", i, err) diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index 47127006f3..a92b7509fe 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -684,11 +684,12 @@ func ExampleClient_UseSessionWithOptions() { context.TODO(), opts, func(ctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(ctx) // Use the mongo.SessionContext as the Context parameter for // InsertOne and FindOne so both operations are run under the new // Session. - if err := ctx.StartTransaction(); err != nil { + if err := sess.StartTransaction(); err != nil { return err } @@ -699,7 +700,7 @@ func ExampleClient_UseSessionWithOptions() { // context.Background() to ensure that the abort can complete // successfully even if the context passed to mongo.WithSession // is changed to have a timeout. - _ = ctx.AbortTransaction(context.Background()) + _ = sess.AbortTransaction(context.Background()) return err } @@ -713,7 +714,7 @@ func ExampleClient_UseSessionWithOptions() { // context.Background() to ensure that the abort can complete // successfully even if the context passed to mongo.WithSession // is changed to have a timeout. - _ = ctx.AbortTransaction(context.Background()) + _ = sess.AbortTransaction(context.Background()) return err } fmt.Println(result) @@ -721,7 +722,7 @@ func ExampleClient_UseSessionWithOptions() { // Use context.Background() to ensure that the commit can complete // successfully even if the context passed to mongo.WithSession is // changed to have a timeout. - return ctx.CommitTransaction(context.Background()) + return sess.CommitTransaction(context.Background()) }) if err != nil { log.Fatal(err) diff --git a/mongo/session.go b/mongo/session.go index dcd83f650c..45f224e5bd 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -27,6 +27,25 @@ var ErrWrongClient = errors.New("session was not created by this client") var withTransactionTimeout = 120 * time.Second +// Session is a MongoDB logical session. Sessions can be used to enable causal +// consistency for a group of operations or to execute operations in an ACID +// transaction. A new Session can be created from a Client instance. A Session +// created from a Client must only be used to execute operations using that +// Client or a Database or Collection created from that Client. For more +// information about sessions, and their use cases, see +// https://www.mongodb.com/docs/manual/reference/server-sessions/, +// https://www.mongodb.com/docs/manual/core/read-isolation-consistency-recency/#causal-consistency, and +// https://www.mongodb.com/docs/manual/core/transactions/. +// +// Implementations of Session are not safe for concurrent use by multiple +// goroutines. +type Session struct { + clientSession *session.Client + client *Client + deployment driver.Deployment + didCommitAfterStart bool // true if commit was called after start with no other operations +} + // SessionContext combines the context.Context and mongo.Session interfaces. It should be used as the Context arguments // to operations that should be executed in a session. // @@ -37,35 +56,31 @@ var withTransactionTimeout = 120 * time.Second // the provided callback. The other is to use NewSessionContext to explicitly create a SessionContext. type SessionContext interface { context.Context - Session } -type sessionContext struct { +type sessionCtx struct { context.Context - Session } -type sessionKey struct { -} +type sessionKey struct{} // NewSessionContext creates a new SessionContext associated with the given Context and Session parameters. -func NewSessionContext(ctx context.Context, sess Session) SessionContext { - return &sessionContext{ +func NewSessionContext(ctx context.Context, sess *Session) SessionContext { + return &sessionCtx{ Context: context.WithValue(ctx, sessionKey{}, sess), - Session: sess, } } // SessionFromContext extracts the mongo.Session object stored in a Context. This can be used on a SessionContext that // was created implicitly through one of the callback-based session APIs or explicitly by calling NewSessionContext. If // there is no Session stored in the provided Context, nil is returned. -func SessionFromContext(ctx context.Context) Session { +func SessionFromContext(ctx context.Context) *Session { val := ctx.Value(sessionKey{}) if val == nil { return nil } - sess, ok := val.(Session) + sess, ok := val.(*Session) if !ok { return nil } @@ -73,104 +88,18 @@ func SessionFromContext(ctx context.Context) Session { return sess } -// Session is an interface that represents a MongoDB logical session. Sessions can be used to enable causal consistency -// for a group of operations or to execute operations in an ACID transaction. A new Session can be created from a Client -// instance. A Session created from a Client must only be used to execute operations using that Client or a Database or -// Collection created from that Client. Custom implementations of this interface should not be used in production. For -// more information about sessions, and their use cases, see -// https://www.mongodb.com/docs/manual/reference/server-sessions/, -// https://www.mongodb.com/docs/manual/core/read-isolation-consistency-recency/#causal-consistency, and -// https://www.mongodb.com/docs/manual/core/transactions/. -// -// Implementations of Session are not safe for concurrent use by multiple goroutines. -type Session interface { - // StartTransaction starts a new transaction, configured with the given options, on this - // session. This method returns an error if there is already a transaction in-progress for this - // session. - StartTransaction(...*options.TransactionOptions) error - - // AbortTransaction aborts the active transaction for this session. This method returns an error - // if there is no active transaction for this session or if the transaction has been committed - // or aborted. - AbortTransaction(context.Context) error - - // CommitTransaction commits the active transaction for this session. This method returns an - // error if there is no active transaction for this session or if the transaction has been - // aborted. - CommitTransaction(context.Context) error - - // WithTransaction starts a transaction on this session and runs the fn callback. Errors with - // the TransientTransactionError and UnknownTransactionCommitResult labels are retried for up to - // 120 seconds. Inside the callback, the SessionContext must be used as the Context parameter - // for any operations that should be part of the transaction. If the ctx parameter already has a - // Session attached to it, it will be replaced by this session. The fn callback may be run - // multiple times during WithTransaction due to retry attempts, so it must be idempotent. - // Non-retryable operation errors or any operation errors that occur after the timeout expires - // will be returned without retrying. If the callback fails, the driver will call - // AbortTransaction. Because this method must succeed to ensure that server-side resources are - // properly cleaned up, context deadlines and cancellations will not be respected during this - // call. For a usage example, see the Client.StartSession method documentation. - WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), - opts ...*options.TransactionOptions) (interface{}, error) - - // EndSession aborts any existing transactions and close the session. - EndSession(context.Context) - - // ClusterTime returns the current cluster time document associated with the session. - ClusterTime() bson.Raw - - // OperationTime returns the current operation time document associated with the session. - OperationTime() *primitive.Timestamp - - // Client the Client associated with the session. - Client() *Client - - // ID returns the current ID document associated with the session. The ID document is in the - // form {"id": }. - ID() bson.Raw - - // AdvanceClusterTime advances the cluster time for a session. This method returns an error if - // the session has ended. - AdvanceClusterTime(bson.Raw) error - - // AdvanceOperationTime advances the operation time for a session. This method returns an error - // if the session has ended. - AdvanceOperationTime(*primitive.Timestamp) error - - session() -} - -// XSession is an unstable interface for internal use only. -// -// Deprecated: This interface is unstable because it provides access to a session.Client object, which exists in the -// "x" package. It should not be used by applications and may be changed or removed in any release. -type XSession interface { - ClientSession() *session.Client -} - -// sessionImpl represents a set of sequential operations executed by an application that are related in some way. -type sessionImpl struct { - clientSession *session.Client - client *Client - deployment driver.Deployment - didCommitAfterStart bool // true if commit was called after start with no other operations -} - -var _ Session = &sessionImpl{} -var _ XSession = &sessionImpl{} - // ClientSession implements the XSession interface. -func (s *sessionImpl) ClientSession() *session.Client { +func (s *Session) ClientSession() *session.Client { return s.clientSession } // ID implements the Session interface. -func (s *sessionImpl) ID() bson.Raw { +func (s *Session) ID() bson.Raw { return bson.Raw(s.clientSession.SessionID) } // EndSession implements the Session interface. -func (s *sessionImpl) EndSession(ctx context.Context) { +func (s *Session) EndSession(ctx context.Context) { if s.clientSession.TransactionInProgress() { // ignore all errors aborting during an end session _ = s.AbortTransaction(ctx) @@ -179,7 +108,7 @@ func (s *sessionImpl) EndSession(ctx context.Context) { } // WithTransaction implements the Session interface. -func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), +func (s *Session) WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { timeout := time.NewTimer(withTransactionTimeout) defer timeout.Stop() @@ -259,7 +188,7 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionCo } // StartTransaction implements the Session interface. -func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) error { +func (s *Session) StartTransaction(opts ...*options.TransactionOptions) error { err := s.clientSession.CheckStartTransaction() if err != nil { return err @@ -296,7 +225,7 @@ func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) erro } // AbortTransaction implements the Session interface. -func (s *sessionImpl) AbortTransaction(ctx context.Context) error { +func (s *Session) AbortTransaction(ctx context.Context) error { err := s.clientSession.CheckAbortTransaction() if err != nil { return err @@ -322,7 +251,7 @@ func (s *sessionImpl) AbortTransaction(ctx context.Context) error { } // CommitTransaction implements the Session interface. -func (s *sessionImpl) CommitTransaction(ctx context.Context) error { +func (s *Session) CommitTransaction(ctx context.Context) error { err := s.clientSession.CheckCommitTransaction() if err != nil { return err @@ -366,39 +295,35 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { } // ClusterTime implements the Session interface. -func (s *sessionImpl) ClusterTime() bson.Raw { +func (s *Session) ClusterTime() bson.Raw { return s.clientSession.ClusterTime } // AdvanceClusterTime implements the Session interface. -func (s *sessionImpl) AdvanceClusterTime(d bson.Raw) error { +func (s *Session) AdvanceClusterTime(d bson.Raw) error { return s.clientSession.AdvanceClusterTime(d) } // OperationTime implements the Session interface. -func (s *sessionImpl) OperationTime() *primitive.Timestamp { +func (s *Session) OperationTime() *primitive.Timestamp { return s.clientSession.OperationTime } // AdvanceOperationTime implements the Session interface. -func (s *sessionImpl) AdvanceOperationTime(ts *primitive.Timestamp) error { +func (s *Session) AdvanceOperationTime(ts *primitive.Timestamp) error { return s.clientSession.AdvanceOperationTime(ts) } // Client implements the Session interface. -func (s *sessionImpl) Client() *Client { +func (s *Session) Client() *Client { return s.client } -// session implements the Session interface. -func (*sessionImpl) session() { -} - // sessionFromContext checks for a sessionImpl in the argued context and returns the session if it // exists func sessionFromContext(ctx context.Context) *session.Client { s := ctx.Value(sessionKey{}) - if ses, ok := s.(*sessionImpl); ses != nil && ok { + if ses, ok := s.(*Session); ses != nil && ok { return ses.clientSession } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index f65ba7b4f1..eacc12d864 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -325,7 +325,7 @@ func TestConvenientTransactions(t *testing.T) { "expected timeout error error; got %v", commitErr) // Assert session state is not Committed. - clientSession := session.(XSession).ClientSession() + clientSession := session.ClientSession() assert.False(t, clientSession.TransactionCommitted(), "expected session state to not be Committed") // AbortTransaction without error. From 6759bf178618c2efaef405241df832da0859299c Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 26 Apr 2024 17:06:40 -0600 Subject: [PATCH 2/9] GODRIVER-2800 Remove SessionContext --- internal/docexamples/examples.go | 76 +++++++++---------- .../integration/causal_consistency_test.go | 42 +++++----- .../client_side_encryption_test.go | 4 +- internal/integration/crud_helpers_test.go | 58 +++++++------- .../integration/load_balancer_prose_test.go | 4 +- internal/integration/mongos_pinning_test.go | 4 +- .../retryable_writes_prose_test.go | 2 +- .../sessions_mongocryptd_prose_test.go | 2 +- internal/integration/sessions_test.go | 12 +-- internal/integration/unified/operation.go | 2 +- .../unified/session_operation_execution.go | 3 +- mongo/client.go | 53 +++++++------ mongo/crud_examples_test.go | 18 ++--- mongo/crypt_retrievers.go | 4 +- mongo/database_test.go | 2 +- mongo/mongocryptd.go | 2 +- mongo/session.go | 29 ++----- mongo/with_transactions_test.go | 24 +++--- 18 files changed, 165 insertions(+), 176 deletions(-) diff --git a/internal/docexamples/examples.go b/internal/docexamples/examples.go index 7e43919cb6..c575d7b617 100644 --- a/internal/docexamples/examples.go +++ b/internal/docexamples/examples.go @@ -1759,8 +1759,8 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error { employees := client.Database("hr").Collection("employees") events := client.Database("reporting").Collection("events") - return client.UseSession(ctx, func(sctx mongo.SessionContext) error { - sess := mongo.SessionFromContext(sctx) + return client.UseSession(ctx, func(ctx context.Context) error { + sess := mongo.SessionFromContext(ctx) err := sess.StartTransaction(options.Transaction(). SetReadConcern(readconcern.Snapshot()). @@ -1770,21 +1770,21 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error { return err } - _, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) + _, err = employees.UpdateOne(ctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) if err != nil { - sess.AbortTransaction(sctx) + sess.AbortTransaction(ctx) log.Println("caught exception during transaction, aborting.") return err } - _, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) + _, err = events.InsertOne(ctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) if err != nil { - sess.AbortTransaction(sctx) + sess.AbortTransaction(ctx) log.Println("caught exception during transaction, aborting.") return err } for { - err = sess.CommitTransaction(sctx) + err = sess.CommitTransaction(ctx) switch e := err.(type) { case nil: return nil @@ -1808,9 +1808,9 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error { // Start Transactions Retry Example 1 // RunTransactionWithRetry is an example function demonstrating transaction retry logic. -func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error { +func RunTransactionWithRetry(ctx context.Context, txnFn func(context.Context) error) error { for { - err := txnFn(sctx) // Performs transaction. + err := txnFn(ctx) // Performs transaction. if err == nil { return nil } @@ -1831,11 +1831,11 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session // Start Transactions Retry Example 2 // CommitWithRetry is an example function demonstrating transaction commit with retry logic. -func CommitWithRetry(sctx mongo.SessionContext) error { - sess := mongo.SessionFromContext(sctx) +func CommitWithRetry(ctx context.Context) error { + sess := mongo.SessionFromContext(ctx) for { - err := sess.CommitTransaction(sctx) + err := sess.CommitTransaction(ctx) switch e := err.(type) { case nil: log.Println("Transaction committed.") @@ -1877,9 +1877,9 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { } // Start Transactions Retry Example 3 - runTransactionWithRetry := func(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error { + runTransactionWithRetry := func(ctx context.Context, txnFn func(context.Context) error) error { for { - err := txnFn(sctx) // Performs transaction. + err := txnFn(ctx) // Performs transaction. if err == nil { return nil } @@ -1895,11 +1895,11 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { } } - commitWithRetry := func(sctx mongo.SessionContext) error { - sess := mongo.SessionFromContext(sctx) + commitWithRetry := func(ctx context.Context) error { + sess := mongo.SessionFromContext(ctx) for { - err := sess.CommitTransaction(sctx) + err := sess.CommitTransaction(ctx) switch e := err.(type) { case nil: log.Println("Transaction committed.") @@ -1920,11 +1920,11 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { } // Updates two collections in a transaction. - updateEmployeeInfo := func(sctx mongo.SessionContext) error { + updateEmployeeInfo := func(ctx context.Context) error { employees := client.Database("hr").Collection("employees") events := client.Database("reporting").Collection("events") - sess := mongo.SessionFromContext(sctx) + sess := mongo.SessionFromContext(ctx) err := sess.StartTransaction(options.Transaction(). SetReadConcern(readconcern.Snapshot()). @@ -1934,26 +1934,26 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { return err } - _, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) + _, err = employees.UpdateOne(ctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) if err != nil { - sess.AbortTransaction(sctx) + sess.AbortTransaction(ctx) log.Println("caught exception during transaction, aborting.") return err } - _, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) + _, err = events.InsertOne(ctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) if err != nil { - sess.AbortTransaction(sctx) + sess.AbortTransaction(ctx) log.Println("caught exception during transaction, aborting.") return err } - return commitWithRetry(sctx) + return commitWithRetry(ctx) } return client.UseSessionWithOptions( ctx, options.Session().SetDefaultReadPreference(readpref.Primary()), - func(sctx mongo.SessionContext) error { - return runTransactionWithRetry(sctx, updateEmployeeInfo) + func(ctx context.Context) error { + return runTransactionWithRetry(ctx, updateEmployeeInfo) }, ) } @@ -1985,13 +1985,13 @@ func WithTransactionExample(ctx context.Context) error { barColl := client.Database("mydb1").Collection("bar", wcMajorityCollectionOpts) // Step 1: Define the callback that specifies the sequence of operations to perform inside the transaction. - callback := func(sessCtx mongo.SessionContext) (interface{}, error) { - // Important: You must pass sessCtx as the Context parameter to the operations for them to be executed in the + callback := func(sesctx context.Context) (interface{}, error) { + // Important: You must pass sesctx as the Context parameter to the operations for them to be executed in the // transaction. - if _, err := fooColl.InsertOne(sessCtx, bson.D{{"abc", 1}}); err != nil { + if _, err := fooColl.InsertOne(sesctx, bson.D{{"abc", 1}}); err != nil { return nil, err } - if _, err := barColl.InsertOne(sessCtx, bson.D{{"xyz", 999}}); err != nil { + if _, err := barColl.InsertOne(sesctx, bson.D{{"xyz", 999}}); err != nil { return nil, err } @@ -2569,15 +2569,15 @@ func CausalConsistencyExamples(client *mongo.Client) error { } defer session1.EndSession(context.TODO()) - err = client.UseSessionWithOptions(context.TODO(), opts, func(sctx mongo.SessionContext) error { + err = client.UseSessionWithOptions(context.TODO(), opts, func(ctx context.Context) error { // Run an update with our causally-consistent session - _, err = coll.UpdateOne(sctx, bson.D{{"sku", 111}}, bson.D{{"$set", bson.D{{"end", currentDate}}}}) + _, err = coll.UpdateOne(ctx, bson.D{{"sku", 111}}, bson.D{{"$set", bson.D{{"end", currentDate}}}}) if err != nil { return err } // Run an insert with our causally-consistent session - _, err = coll.InsertOne(sctx, bson.D{{"sku", "nuts-111"}, {"name", "Pecans"}, {"start", currentDate}}) + _, err = coll.InsertOne(ctx, bson.D{{"sku", "nuts-111"}, {"name", "Pecans"}, {"start", currentDate}}) if err != nil { return err } @@ -2602,7 +2602,7 @@ func CausalConsistencyExamples(client *mongo.Client) error { } defer session2.EndSession(context.TODO()) - err = client.UseSessionWithOptions(context.TODO(), opts, func(sctx mongo.SessionContext) error { + err = client.UseSessionWithOptions(context.TODO(), opts, func(ctx context.Context) error { // Set cluster time of session2 to session1's cluster time clusterTime := session1.ClusterTime() session2.AdvanceClusterTime(clusterTime) @@ -2611,13 +2611,13 @@ func CausalConsistencyExamples(client *mongo.Client) error { operationTime := session1.OperationTime() session2.AdvanceOperationTime(operationTime) // Run a find on session2, which should find all the writes from session1 - cursor, err := coll.Find(sctx, bson.D{{"end", nil}}) + cursor, err := coll.Find(ctx, bson.D{{"end", nil}}) if err != nil { return err } - for cursor.Next(sctx) { + for cursor.Next(ctx) { doc := cursor.Current fmt.Printf("Document: %v\n", doc.String()) } @@ -2993,7 +2993,7 @@ func snapshotQueryPetExample(mt *mtest.T) error { defer sess.EndSession(ctx) var adoptablePetsCount int32 - err = mongo.WithSession(ctx, sess, func(ctx mongo.SessionContext) error { + err = mongo.WithSession(ctx, sess, func(ctx context.Context) error { // Count the adoptable cats const adoptableCatsOutput = "adoptableCatsCount" cursor, err := db.Collection("cats").Aggregate(ctx, mongo.Pipeline{ @@ -3057,7 +3057,7 @@ func snapshotQueryRetailExample(mt *mtest.T) error { defer sess.EndSession(ctx) var totalDailySales int32 - err = mongo.WithSession(ctx, sess, func(ctx mongo.SessionContext) error { + err = mongo.WithSession(ctx, sess, func(ctx context.Context) error { // Count the total daily sales const totalDailySalesOutput = "totalDailySales" cursor, err := db.Collection("sales").Aggregate(ctx, mongo.Pipeline{ diff --git a/internal/integration/causal_consistency_test.go b/internal/integration/causal_consistency_test.go index 33581c538a..1778ea926c 100644 --- a/internal/integration/causal_consistency_test.go +++ b/internal/integration/causal_consistency_test.go @@ -42,8 +42,8 @@ func TestCausalConsistency_Supported(t *testing.T) { // first read in a causally consistent session must not send afterClusterTime to the server ccOpts := options.Session().SetCausalConsistency(true) - _ = mt.Client.UseSessionWithOptions(context.Background(), ccOpts, func(sc mongo.SessionContext) error { - _, _ = mt.Coll.Find(sc, bson.D{}) + _ = mt.Client.UseSessionWithOptions(context.Background(), ccOpts, func(ctx context.Context) error { + _, _ = mt.Coll.Find(ctx, bson.D{}) return nil }) @@ -58,8 +58,8 @@ func TestCausalConsistency_Supported(t *testing.T) { assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { - _, _ = mt.Coll.Find(sc, bson.D{}) + _ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + _, _ = mt.Coll.Find(ctx, bson.D{}) return nil }) @@ -86,8 +86,8 @@ func TestCausalConsistency_Supported(t *testing.T) { assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { - _ = mt.Coll.FindOne(sc, bson.D{}) + _ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + _ = mt.Coll.FindOne(ctx, bson.D{}) return nil }) currOptime := sess.OperationTime() @@ -121,8 +121,8 @@ func TestCausalConsistency_Supported(t *testing.T) { assert.NotNil(mt, currOptime, "expected session operation time, got nil") mt.ClearEvents() - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { - _ = mt.Coll.FindOne(sc, bson.D{}) + _ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + _ = mt.Coll.FindOne(ctx, bson.D{}) return nil }) _, sentOptime := getReadConcernFields(mt, mt.GetStartedEvent().Command) @@ -135,10 +135,10 @@ func TestCausalConsistency_Supported(t *testing.T) { // a read operation in a non causally-consistent session should not include afterClusterTime sessOpts := options.Session().SetCausalConsistency(false) - _ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(sc mongo.SessionContext) error { - _, _ = mt.Coll.Find(sc, bson.D{}) + _ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(ctx context.Context) error { + _, _ = mt.Coll.Find(ctx, bson.D{}) mt.ClearEvents() - _, _ = mt.Coll.Find(sc, bson.D{}) + _, _ = mt.Coll.Find(ctx, bson.D{}) return nil }) evt := mt.GetStartedEvent() @@ -153,14 +153,14 @@ func TestCausalConsistency_Supported(t *testing.T) { assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { - _ = mt.Coll.FindOne(sc, bson.D{}) + _ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + _ = mt.Coll.FindOne(ctx, bson.D{}) return nil }) currOptime := sess.OperationTime() mt.ClearEvents() - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { - _ = mt.Coll.FindOne(sc, bson.D{}) + _ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + _ = mt.Coll.FindOne(ctx, bson.D{}) return nil }) @@ -175,14 +175,14 @@ func TestCausalConsistency_Supported(t *testing.T) { assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { - _ = mt.Coll.FindOne(sc, bson.D{}) + _ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + _ = mt.Coll.FindOne(ctx, bson.D{}) return nil }) currOptime := sess.OperationTime() mt.ClearEvents() - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { - _ = mt.Coll.FindOne(sc, bson.D{}) + _ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + _ = mt.Coll.FindOne(ctx, bson.D{}) return nil }) @@ -216,8 +216,8 @@ func TestCausalConsistency_NotSupported(t *testing.T) { // support cluster times sessOpts := options.Session().SetCausalConsistency(true) - _ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(sc mongo.SessionContext) error { - _, _ = mt.Coll.Find(sc, bson.D{}) + _ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(ctx context.Context) error { + _, _ = mt.Coll.Find(ctx, bson.D{}) return nil }) diff --git a/internal/integration/client_side_encryption_test.go b/internal/integration/client_side_encryption_test.go index 49af404e08..4276440708 100644 --- a/internal/integration/client_side_encryption_test.go +++ b/internal/integration/client_side_encryption_test.go @@ -145,7 +145,7 @@ func TestClientSideEncryptionWithExplicitSessions(t *testing.T) { session, err := client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) - sessionCtx := mongo.NewSessionContext(context.Background(), session) + sessionCtx := mongo.ContextWithSession(context.Background(), session) capturedEvents = make([]event.CommandStartedEvent, 0) _, err = coll.InsertOne(sessionCtx, bson.D{{"encryptMe", "test"}, {"keyName", "myKey"}}) @@ -209,7 +209,7 @@ func TestClientSideEncryptionWithExplicitSessions(t *testing.T) { session, err := client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) - sessionCtx := mongo.NewSessionContext(context.Background(), session) + sessionCtx := mongo.ContextWithSession(context.Background(), session) capturedEvents = make([]event.CommandStartedEvent, 0) res := coll.FindOne(sessionCtx, bson.D{{}}) diff --git a/internal/integration/crud_helpers_test.go b/internal/integration/crud_helpers_test.go index 0677e6489d..cca64159f0 100644 --- a/internal/integration/crud_helpers_test.go +++ b/internal/integration/crud_helpers_test.go @@ -188,7 +188,7 @@ func executeAggregate(mt *mtest.T, agg aggregator, sess *mongo.Session, args bso if sess != nil { var cur *mongo.Cursor - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var aerr error cur, aerr = agg.Aggregate(sc, pipeline, opts) return aerr @@ -217,7 +217,7 @@ func executeWatch(mt *mtest.T, w watcher, sess *mongo.Session, args bson.Raw) (* if sess != nil { var stream *mongo.ChangeStream - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var csErr error stream, csErr = w.Watch(sc, pipeline) return csErr @@ -255,7 +255,7 @@ func executeCountDocuments(mt *mtest.T, sess *mongo.Session, args bson.Raw) (int if sess != nil { var count int64 - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var countErr error count, countErr = mt.Coll.CountDocuments(sc, filter, opts) return countErr @@ -289,7 +289,7 @@ func executeInsertOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.I if sess != nil { var res *mongo.InsertOneResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var insertErr error res, insertErr = mt.Coll.InsertOne(sc, doc, opts) return insertErr @@ -327,7 +327,7 @@ func executeInsertMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo. if sess != nil { var res *mongo.InsertManyResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var insertErr error res, insertErr = mt.Coll.InsertMany(sc, docs, opts) return insertErr @@ -400,7 +400,7 @@ func executeFind(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.Cursor if sess != nil { var c *mongo.Cursor - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var findErr error c, findErr = mt.Coll.Find(sc, filter, opts) return findErr @@ -434,7 +434,7 @@ func executeRunCommand(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.S if sess != nil { var sr *mongo.SingleResult - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + _ = mongo.WithSession(context.Background(), sess, func(sc context.Context) error { sr = mt.DB.RunCommand(sc, cmd, opts) return nil }) @@ -462,7 +462,7 @@ func executeListCollections(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*m if sess != nil { var c *mongo.Cursor - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var lcErr error c, lcErr = mt.DB.ListCollections(sc, filter) return lcErr @@ -491,7 +491,7 @@ func executeListCollectionNames(mt *mtest.T, sess *mongo.Session, args bson.Raw) if sess != nil { var res []string - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var lcErr error res, lcErr = mt.DB.ListCollectionNames(sc, filter) return lcErr @@ -520,7 +520,7 @@ func executeListDatabaseNames(mt *mtest.T, sess *mongo.Session, args bson.Raw) ( if sess != nil { var res []string - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var ldErr error res, ldErr = mt.Client.ListDatabaseNames(sc, filter) return ldErr @@ -549,7 +549,7 @@ func executeListDatabases(mt *mtest.T, sess *mongo.Session, args bson.Raw) (mong if sess != nil { var res mongo.ListDatabasesResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var ldErr error res, ldErr = mt.Client.ListDatabases(sc, filter) return ldErr @@ -578,7 +578,7 @@ func executeFindOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.Sing if sess != nil { var res *mongo.SingleResult - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + _ = mongo.WithSession(context.Background(), sess, func(sc context.Context) error { res = mt.Coll.FindOne(sc, filter) return nil }) @@ -594,7 +594,7 @@ func executeListIndexes(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo assert.Equal(mt, 0, len(args), "unexpected listIndexes arguments: %v", args) if sess != nil { var cursor *mongo.Cursor - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var listErr error cursor, listErr = mt.Coll.Indexes().List(sc) return listErr @@ -631,7 +631,7 @@ func executeDistinct(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]interfa if sess != nil { var res []interface{} - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var derr error res, derr = mt.Coll.Distinct(sc, fieldName, filter, opts) return derr @@ -671,7 +671,7 @@ func executeFindOneAndDelete(mt *mtest.T, sess *mongo.Session, args bson.Raw) *m if sess != nil { var res *mongo.SingleResult - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + _ = mongo.WithSession(context.Background(), sess, func(sc context.Context) error { res = mt.Coll.FindOneAndDelete(sc, filter, opts) return nil }) @@ -728,7 +728,7 @@ func executeFindOneAndUpdate(mt *mtest.T, sess *mongo.Session, args bson.Raw) *m if sess != nil { var res *mongo.SingleResult - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + _ = mongo.WithSession(context.Background(), sess, func(sc context.Context) error { res = mt.Coll.FindOneAndUpdate(sc, filter, update, opts) return nil }) @@ -781,7 +781,7 @@ func executeFindOneAndReplace(mt *mtest.T, sess *mongo.Session, args bson.Raw) * if sess != nil { var res *mongo.SingleResult - _ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + _ = mongo.WithSession(context.Background(), sess, func(sc context.Context) error { res = mt.Coll.FindOneAndReplace(sc, filter, replacement, opts) return nil }) @@ -816,7 +816,7 @@ func executeDeleteOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.D if sess != nil { var res *mongo.DeleteResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var derr error res, derr = mt.Coll.DeleteOne(sc, filter, opts) return derr @@ -852,7 +852,7 @@ func executeDeleteMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo. if sess != nil { var res *mongo.DeleteResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var derr error res, derr = mt.Coll.DeleteMany(sc, filter, opts) return derr @@ -900,7 +900,7 @@ func executeUpdateOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.U if sess != nil { var res *mongo.UpdateResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var uerr error res, uerr = mt.Coll.UpdateOne(sc, filter, update, opts) return uerr @@ -948,7 +948,7 @@ func executeUpdateMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo. if sess != nil { var res *mongo.UpdateResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var uerr error res, uerr = mt.Coll.UpdateMany(sc, filter, update, opts) return uerr @@ -992,7 +992,7 @@ func executeReplaceOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo. if sess != nil { var res *mongo.UpdateResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var uerr error res, uerr = mt.Coll.ReplaceOne(sc, filter, replacement, opts) return uerr @@ -1045,7 +1045,7 @@ func executeWithTransaction(mt *mtest.T, sess *mongo.Session, args bson.Raw) err assert.Nil(mt, err, "error creating withTransactionArgs: %v", err) opts := createTransactionOptions(mt, testArgs.Options) - _, err = sess.WithTransaction(context.Background(), func(sc mongo.SessionContext) (interface{}, error) { + _, err = sess.WithTransaction(context.Background(), func(sc context.Context) (interface{}, error) { err := runWithTransactionOperations(mt, testArgs.Callback.Operations, sess) return nil, err }, opts) @@ -1076,7 +1076,7 @@ func executeBulkWrite(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.B if sess != nil { var res *mongo.BulkWriteResult - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var bwerr error res, bwerr = mt.Coll.BulkWrite(sc, models, opts) return bwerr @@ -1205,7 +1205,7 @@ func executeEstimatedDocumentCount(mt *mtest.T, sess *mongo.Session, args bson.R if sess != nil { var res int64 - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var countErr error res, countErr = mt.Coll.EstimatedDocumentCount(sc) return countErr @@ -1279,7 +1279,7 @@ func executeCreateIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (string if sess != nil { var indexName string - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var indexErr error indexName, indexErr = mt.Coll.Indexes().CreateOne(sc, model) return indexErr @@ -1308,7 +1308,7 @@ func executeDropIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (bson.Raw if sess != nil { var res bson.Raw - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { var indexErr error res, indexErr = mt.Coll.Indexes().DropOne(sc, name) return indexErr @@ -1340,7 +1340,7 @@ func executeDropCollection(mt *mtest.T, sess *mongo.Session, args bson.Raw) erro coll := mt.DB.Collection(collName) if sess != nil { - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { return coll.Drop(sc, dco) }) return err @@ -1373,7 +1373,7 @@ func executeCreateCollection(mt *mtest.T, sess *mongo.Session, args bson.Raw) er } if sess != nil { - err := mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { return mt.DB.CreateCollection(sc, collName, cco) }) return err diff --git a/internal/integration/load_balancer_prose_test.go b/internal/integration/load_balancer_prose_test.go index c40c052993..fbc7569898 100644 --- a/internal/integration/load_balancer_prose_test.go +++ b/internal/integration/load_balancer_prose_test.go @@ -85,7 +85,7 @@ func TestLoadBalancerSupport(t *testing.T) { sess, err := mt.Client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - ctx := mongo.NewSessionContext(context.Background(), sess) + ctx := mongo.ContextWithSession(context.Background(), sess) // Start a transaction and perform one transactional operation to pin a connection. err = sess.StartTransaction() @@ -116,7 +116,7 @@ func TestLoadBalancerSupport(t *testing.T) { err = sess.StartTransaction() require.NoError(mt, err, "StartTransaction error") - ctx := mongo.NewSessionContext(ctx, sess) + ctx := mongo.ContextWithSession(ctx, sess) _, err = mt.Coll.InsertOne(ctx, bson.M{"x": 1}) assert.NoError(mt, err, "InsertOne error") diff --git a/internal/integration/mongos_pinning_test.go b/internal/integration/mongos_pinning_test.go index f91f16018e..76477c02be 100644 --- a/internal/integration/mongos_pinning_test.go +++ b/internal/integration/mongos_pinning_test.go @@ -31,7 +31,7 @@ func TestMongosPinning(t *testing.T) { mt.Run("unpin for next transaction", func(mt *mtest.T) { addresses := map[string]struct{}{} - _ = mt.Client.UseSession(context.Background(), func(sctx mongo.SessionContext) error { + _ = mt.Client.UseSession(context.Background(), func(sctx context.Context) error { sess := mongo.SessionFromContext(sctx) // Insert a document in a transaction to pin session to a mongos err := sess.StartTransaction() @@ -65,7 +65,7 @@ func TestMongosPinning(t *testing.T) { }) mt.Run("unpin for non transaction operation", func(mt *mtest.T) { addresses := map[string]struct{}{} - _ = mt.Client.UseSession(context.Background(), func(sctx mongo.SessionContext) error { + _ = mt.Client.UseSession(context.Background(), func(sctx context.Context) error { sess := mongo.SessionFromContext(sctx) // Insert a document in a transaction to pin session to a mongos diff --git a/internal/integration/retryable_writes_prose_test.go b/internal/integration/retryable_writes_prose_test.go index f415a0daa6..e661fcab5a 100644 --- a/internal/integration/retryable_writes_prose_test.go +++ b/internal/integration/retryable_writes_prose_test.go @@ -109,7 +109,7 @@ func TestRetryableWritesProse(t *testing.T) { mt.ClearEvents() - err = mongo.WithSession(context.Background(), sess, func(ctx mongo.SessionContext) error { + err = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { doc := bson.D{{"foo", 1}} _, err := mt.Coll.InsertOne(ctx, doc) return err diff --git a/internal/integration/sessions_mongocryptd_prose_test.go b/internal/integration/sessions_mongocryptd_prose_test.go index 55fabc1061..ebb56168ed 100644 --- a/internal/integration/sessions_mongocryptd_prose_test.go +++ b/internal/integration/sessions_mongocryptd_prose_test.go @@ -158,7 +158,7 @@ func TestSessionsMongocryptdProse(t *testing.T) { defer session.EndSession(context.Background()) - sessionCtx := mongo.NewSessionContext(context.TODO(), session) + sessionCtx := mongo.ContextWithSession(context.TODO(), session) err = session.StartTransaction() require.NoError(mt, err, "expected error to be nil, got %v", err) diff --git a/internal/integration/sessions_test.go b/internal/integration/sessions_test.go index a95af5b15c..dc81f883c1 100644 --- a/internal/integration/sessions_test.go +++ b/internal/integration/sessions_test.go @@ -37,7 +37,7 @@ func TestSessionPool(t *testing.T) { defer sess.EndSession(context.Background()) initialLastUsedTime := sess.ClientSession().LastUsed - err = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err = mongo.WithSession(context.Background(), sess, func(sc context.Context) error { return mt.Client.Ping(sc, readpref.Primary()) }) assert.Nil(mt, err, "WithSession error: %v", err) @@ -62,7 +62,7 @@ func TestSessions(t *testing.T) { assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - ctx := mongo.NewSessionContext(context.Background(), sess) + ctx := mongo.ContextWithSession(context.Background(), sess) gotSess := mongo.SessionFromContext(ctx) assert.NotNil(mt, gotSess, "expected SessionFromContext to return non-nil value, got nil") @@ -76,11 +76,11 @@ func TestSessions(t *testing.T) { mt.RunOpts("run transaction", txnOpts, func(mt *mtest.T) { // Test that the imperative sessions API can be used to run a transaction. - createSessionContext := func(mt *mtest.T) mongo.SessionContext { + createSessionContext := func(mt *mtest.T) context.Context { sess, err := mt.Client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) - return mongo.NewSessionContext(context.Background(), sess) + return mongo.ContextWithSession(context.Background(), sess) } ctx := createSessionContext(mt) @@ -113,7 +113,7 @@ func TestSessions(t *testing.T) { assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - err = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + err = mongo.WithSession(context.Background(), sess, func(sc context.Context) error { _, err := mt.Coll.InsertOne(sc, bson.D{{"x", 1}}) return err }) @@ -537,7 +537,7 @@ func (sf sessionFunction) execute(mt *mtest.T, sess *mongo.Session) error { paramsValues := interfaceSliceToValueSlice(sf.params) if sess != nil { - return mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { + return mongo.WithSession(context.Background(), sess, func(sc context.Context) error { valueArgs := []reflect.Value{reflect.ValueOf(sc)} valueArgs = append(valueArgs, paramsValues...) returnValues := fn.Call(valueArgs) diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index 59aa36ae8c..b6701638e9 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -83,7 +83,7 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat if err != nil { return nil, err } - ctx = mongo.NewSessionContext(ctx, sess) + ctx = mongo.ContextWithSession(ctx, sess) // Set op.Arguments to a new document that has the "session" field removed so individual operations do // not have to account for it. diff --git a/internal/integration/unified/session_operation_execution.go b/internal/integration/unified/session_operation_execution.go index 98e576093d..f8c91eb1da 100644 --- a/internal/integration/unified/session_operation_execution.go +++ b/internal/integration/unified/session_operation_execution.go @@ -11,7 +11,6 @@ import ( "fmt" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -103,7 +102,7 @@ func executeWithTransaction(ctx context.Context, op *operation, loopDone <-chan return fmt.Errorf("error unmarshalling arguments to transactionOptions: %v", err) } - _, err = sess.WithTransaction(ctx, func(ctx mongo.SessionContext) (interface{}, error) { + _, err = sess.WithTransaction(ctx, func(ctx context.Context) (interface{}, error) { for idx, oper := range operations { if err := oper.execute(ctx, loopDone); err != nil { return nil, fmt.Errorf("error executing operation %q at index %d: %v", oper.Name, idx, err) diff --git a/mongo/client.go b/mongo/client.go index 0bae485a22..6ee2913a5b 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -776,43 +776,50 @@ func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts return names, nil } -// WithSession creates a new SessionContext from the ctx and sess parameters and uses it to call the fn callback. The -// SessionContext must be used as the Context parameter for any operations in the fn callback that should be executed -// under the session. +// WithSession creates a new session context from the ctx and sess parameters +// and uses it to call the fn callback. // -// WithSession is safe to call from multiple goroutines concurrently. However, the SessionContext passed to the -// WithSession callback function is not safe for concurrent use by multiple goroutines. +// WithSession is safe to call from multiple goroutines concurrently. However, +// the context passed to the WithSession callback function is not safe for +// concurrent use by multiple goroutines. // -// If the ctx parameter already contains a Session, that Session will be replaced with the one provided. +// If the ctx parameter already contains a Session, that Session will be +// replaced with the one provided. // -// Any error returned by the fn callback will be returned without any modifications. -func WithSession(ctx context.Context, sess *Session, fn func(SessionContext) error) error { - return fn(NewSessionContext(ctx, sess)) +// Any error returned by the fn callback will be returned without any +// modifications. +func WithSession(ctx context.Context, sess *Session, fn func(context.Context) error) error { + return fn(ContextWithSession(ctx, sess)) } -// UseSession creates a new Session and uses it to create a new SessionContext, which is used to call the fn callback. -// The SessionContext parameter must be used as the Context parameter for any operations in the fn callback that should -// be executed under a session. After the callback returns, the created Session is ended, meaning that any in-progress -// transactions started by fn will be aborted even if fn returns an error. +// UseSession creates a new Session and uses it to create a new session context, +// which is used to call the fn callback. After the callback returns, the +// created Session is ended, meaning that any in-progress transactions started +// by fn will be aborted even if fn returns an error. // -// UseSession is safe to call from multiple goroutines concurrently. However, the SessionContext passed to the -// UseSession callback function is not safe for concurrent use by multiple goroutines. +// UseSession is safe to call from multiple goroutines concurrently. However, +// the context passed to the UseSession callback function is not safe for +// concurrent use by multiple goroutines. // -// If the ctx parameter already contains a Session, that Session will be replaced with the newly created one. +// If the ctx parameter already contains a Session, that Session will be +// replaced with the newly created one. // -// Any error returned by the fn callback will be returned without any modifications. -func (c *Client) UseSession(ctx context.Context, fn func(SessionContext) error) error { +// Any error returned by the fn callback will be returned without any +// modifications. +func (c *Client) UseSession(ctx context.Context, fn func(context.Context) error) error { return c.UseSessionWithOptions(ctx, options.Session(), fn) } -// UseSessionWithOptions operates like UseSession but uses the given SessionOptions to create the Session. +// UseSessionWithOptions operates like UseSession but uses the given +// SessionOptions to create the Session. // -// UseSessionWithOptions is safe to call from multiple goroutines concurrently. However, the SessionContext passed to -// the UseSessionWithOptions callback function is not safe for concurrent use by multiple goroutines. +// UseSessionWithOptions is safe to call from multiple goroutines concurrently. +// However, the context passed to the UseSessionWithOptions callback function is +// not safe for concurrent use by multiple goroutines. func (c *Client) UseSessionWithOptions( ctx context.Context, opts *options.SessionOptions, - fn func(SessionContext) error, + fn func(context.Context) error, ) error { defaultSess, err := c.StartSession(opts) if err != nil { @@ -820,7 +827,7 @@ func (c *Client) UseSessionWithOptions( } defer defaultSess.EndSession(ctx) - return fn(NewSessionContext(ctx, defaultSess)) + return fn(ContextWithSession(ctx, defaultSess)) } // Watch returns a change stream for all changes on the deployment. See diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index a92b7509fe..8a771452f5 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -626,8 +626,8 @@ func ExampleWithSession() { err = mongo.WithSession( context.TODO(), sess, - func(ctx mongo.SessionContext) error { - // Use the mongo.SessionContext as the Context parameter for + func(ctx context.Context) error { + // Use the context.Context as the Context parameter for // InsertOne and FindOne so both operations are run under the new // Session. @@ -683,9 +683,9 @@ func ExampleClient_UseSessionWithOptions() { err := client.UseSessionWithOptions( context.TODO(), opts, - func(ctx mongo.SessionContext) error { + func(ctx context.Context) error { sess := mongo.SessionFromContext(ctx) - // Use the mongo.SessionContext as the Context parameter for + // Use the context.Context as the Context parameter for // InsertOne and FindOne so both operations are run under the new // Session. @@ -752,8 +752,8 @@ func ExampleClient_StartSession_withTransaction() { SetReadPreference(readpref.PrimaryPreferred()) result, err := sess.WithTransaction( context.TODO(), - func(ctx mongo.SessionContext) (interface{}, error) { - // Use the mongo.SessionContext as the Context parameter for + func(ctx context.Context) (interface{}, error) { + // Use the context.Context as the Context parameter for // InsertOne and FindOne so both operations are run in the same // transaction. @@ -780,7 +780,7 @@ func ExampleClient_StartSession_withTransaction() { fmt.Printf("result: %v\n", result) } -func ExampleNewSessionContext() { +func ExampleContextWithSession() { var client *mongo.Client // Create a new Session and SessionContext. @@ -789,9 +789,9 @@ func ExampleNewSessionContext() { panic(err) } defer sess.EndSession(context.TODO()) - ctx := mongo.NewSessionContext(context.TODO(), sess) + ctx := mongo.ContextWithSession(context.TODO(), sess) - // Start a transaction and use the mongo.SessionContext as the Context + // Start a transaction and use the context.Context as the Context // parameter for InsertOne and FindOne so both operations are run in the // transaction. if err = sess.StartTransaction(); err != nil { diff --git a/mongo/crypt_retrievers.go b/mongo/crypt_retrievers.go index 5e96da731a..a210a5c433 100644 --- a/mongo/crypt_retrievers.go +++ b/mongo/crypt_retrievers.go @@ -20,7 +20,7 @@ type keyRetriever struct { func (kr *keyRetriever) cryptKeys(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error) { // Remove the explicit session from the context if one is set. // The explicit session may be from a different client. - ctx = NewSessionContext(ctx, nil) + ctx = ContextWithSession(ctx, nil) cursor, err := kr.coll.Find(ctx, filter) if err != nil { return nil, EncryptionKeyVaultError{Wrapped: err} @@ -48,7 +48,7 @@ type collInfoRetriever struct { func (cir *collInfoRetriever) cryptCollInfo(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error) { // Remove the explicit session from the context if one is set. // The explicit session may be from a different client. - ctx = NewSessionContext(ctx, nil) + ctx = ContextWithSession(ctx, nil) cursor, err := cir.client.Database(db).ListCollections(ctx, filter) if err != nil { return nil, err diff --git a/mongo/database_test.go b/mongo/database_test.go index fb1ea3a426..2bfe9ff595 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -121,7 +121,7 @@ func TestDatabase(t *testing.T) { assert.Nil(t, err, "expected nil, got %v", err) defer sess.EndSession(bgCtx) - sessCtx := NewSessionContext(bgCtx, sess) + sessCtx := ContextWithSession(bgCtx, sess) err = sess.StartTransaction() assert.Nil(t, err, "expected nil, got %v", err) diff --git a/mongo/mongocryptd.go b/mongo/mongocryptd.go index efb283e208..9e20c5ab6d 100644 --- a/mongo/mongocryptd.go +++ b/mongo/mongocryptd.go @@ -88,7 +88,7 @@ func (mc *mongocryptdClient) markCommand(ctx context.Context, dbName string, cmd // Remove the explicit session from the context if one is set. // The explicit session will be from a different client. // If an explicit session is set, it is applied after automatic encryption. - ctx = NewSessionContext(ctx, nil) + ctx = ContextWithSession(ctx, nil) db := mc.client.Database(dbName, databaseOpts) res, err := db.RunCommand(ctx, cmd).Raw() diff --git a/mongo/session.go b/mongo/session.go index 45f224e5bd..ad6f7909a2 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -46,29 +46,12 @@ type Session struct { didCommitAfterStart bool // true if commit was called after start with no other operations } -// SessionContext combines the context.Context and mongo.Session interfaces. It should be used as the Context arguments -// to operations that should be executed in a session. -// -// Implementations of SessionContext are not safe for concurrent use by multiple goroutines. -// -// There are two ways to create a SessionContext and use it in a session/transaction. The first is to use one of the -// callback-based functions such as WithSession and UseSession. These functions create a SessionContext and pass it to -// the provided callback. The other is to use NewSessionContext to explicitly create a SessionContext. -type SessionContext interface { - context.Context -} - -type sessionCtx struct { - context.Context -} - type sessionKey struct{} -// NewSessionContext creates a new SessionContext associated with the given Context and Session parameters. -func NewSessionContext(ctx context.Context, sess *Session) SessionContext { - return &sessionCtx{ - Context: context.WithValue(ctx, sessionKey{}, sess), - } +// ContextWithSession creates a new SessionContext associated with the given +// Context and Session parameters. +func ContextWithSession(parent context.Context, sess *Session) context.Context { + return context.WithValue(parent, sessionKey{}, sess) } // SessionFromContext extracts the mongo.Session object stored in a Context. This can be used on a SessionContext that @@ -108,7 +91,7 @@ func (s *Session) EndSession(ctx context.Context) { } // WithTransaction implements the Session interface. -func (s *Session) WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), +func (s *Session) WithTransaction(ctx context.Context, fn func(ctx context.Context) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { timeout := time.NewTimer(withTransactionTimeout) defer timeout.Stop() @@ -119,7 +102,7 @@ func (s *Session) WithTransaction(ctx context.Context, fn func(ctx SessionContex return nil, err } - res, err := fn(NewSessionContext(ctx, s)) + res, err := fn(ContextWithSession(ctx, s)) if err != nil { if s.clientSession.TransactionRunning() { // Wrap the user-provided Context in a new one that behaves like context.Background() for deadlines and diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index eacc12d864..478a1c1785 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -76,7 +76,7 @@ func TestConvenientTransactions(t *testing.T) { defer sess.EndSession(context.Background()) testErr := errors.New("test error") - _, err = sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) { + _, err = sess.WithTransaction(context.Background(), func(context.Context) (interface{}, error) { return nil, testErr }) assert.Equal(t, testErr, err, "expected error %v, got %v", testErr, err) @@ -90,7 +90,7 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - res, err := sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) { + res, err := sess.WithTransaction(context.Background(), func(context.Context) (interface{}, error) { return false, nil }) assert.Nil(t, err, "WithTransaction error: %v", err) @@ -110,7 +110,7 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - _, err = sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) { + _, err = sess.WithTransaction(context.Background(), func(context.Context) (interface{}, error) { return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}} }) assert.NotNil(t, err, "expected WithTransaction error, got nil") @@ -142,7 +142,7 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - _, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) { + _, err = sess.WithTransaction(context.Background(), func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{"x", 1}}) return nil, err }) @@ -175,7 +175,7 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - _, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) { + _, err = sess.WithTransaction(context.Background(), func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{"x", 1}}) return nil, err }) @@ -240,7 +240,7 @@ func TestConvenientTransactions(t *testing.T) { // insert succeeds, it cancels the Context created above and returns a non-retryable error, which forces // WithTransaction to abort the txn. callbackErr := errors.New("error") - callback := func(sc SessionContext) (interface{}, error) { + callback := func(sc context.Context) (interface{}, error) { _, err = coll.InsertOne(sc, bson.D{{"x", 1}}) if err != nil { return nil, err @@ -306,7 +306,7 @@ func TestConvenientTransactions(t *testing.T) { defer session.EndSession(bgCtx) assert.Nil(t, err, "StartSession error: %v", err) - _ = WithSession(bgCtx, session, func(sessionContext SessionContext) error { + _ = WithSession(bgCtx, session, func(sessionContext context.Context) error { // Start transaction. err = session.StartTransaction() assert.Nil(t, err, "StartTransaction error: %v", err) @@ -402,7 +402,7 @@ func TestConvenientTransactions(t *testing.T) { callback := func(ctx context.Context) { transactionCtx, cancel := context.WithCancel(ctx) - _, _ = sess.WithTransaction(transactionCtx, func(ctx SessionContext) (interface{}, error) { + _, _ = sess.WithTransaction(transactionCtx, func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.M{"x": 1}) assert.Nil(t, err, "InsertOne error: %v", err) cancel() @@ -426,7 +426,7 @@ func TestConvenientTransactions(t *testing.T) { // returnError tracks whether or not the callback is being retried returnError := true - res, err := sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) { + res, err := sess.WithTransaction(context.Background(), func(context.Context) (interface{}, error) { if returnError { returnError = false return nil, fmt.Errorf("%w", @@ -464,7 +464,7 @@ func TestConvenientTransactions(t *testing.T) { withTransactionContext, cancel := context.WithTimeout(ctx, time.Nanosecond) defer cancel() - _, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) { + _, _ = sess.WithTransaction(withTransactionContext, func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{}}) return nil, err }) @@ -494,7 +494,7 @@ func TestConvenientTransactions(t *testing.T) { withTransactionContext, cancel := context.WithTimeout(ctx, 2*time.Second) cancel() - _, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) { + _, _ = sess.WithTransaction(withTransactionContext, func(ctx context.Context) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{}}) return nil, err }) @@ -541,7 +541,7 @@ func TestConvenientTransactions(t *testing.T) { defer sess.EndSession(context.Background()) callback := func(ctx context.Context) { - _, err = sess.WithTransaction(ctx, func(ctx SessionContext) (interface{}, error) { + _, err = sess.WithTransaction(ctx, func(ctx context.Context) (interface{}, error) { // Set a timeout of 300ms to cause a timeout on first insertOne // and force a retry. c, cancel := context.WithTimeout(ctx, 300*time.Millisecond) From a0131511f55328e58902b4e0e160a0402d2d2079 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 26 Apr 2024 17:56:22 -0600 Subject: [PATCH 3/9] GODRIVER-2800 Update ClientSession comment --- mongo/session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mongo/session.go b/mongo/session.go index ecda85e470..8730d4dd7d 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -70,7 +70,7 @@ func SessionFromContext(ctx context.Context) *Session { return sess } -// ClientSession implements the XSession interface. +// ClientSession returns the experimental client session. func (s *Session) ClientSession() *session.Client { return s.clientSession } From 18151e8d361dcfeca70a706c5396ba09f10db230 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Mon, 29 Apr 2024 10:25:40 -0600 Subject: [PATCH 4/9] GODRIVER-2800 Transfer session comments --- .../integration/load_balancer_prose_test.go | 4 +- .../sessions_mongocryptd_prose_test.go | 2 +- internal/integration/sessions_test.go | 4 +- internal/integration/unified/operation.go | 2 +- mongo/client.go | 4 +- mongo/crud_examples_test.go | 4 +- mongo/crypt_retrievers.go | 4 +- mongo/database_test.go | 2 +- mongo/mongocryptd.go | 2 +- mongo/session.go | 54 +++++++++++++------ mongo/with_transactions_test.go | 10 ++-- 11 files changed, 57 insertions(+), 35 deletions(-) diff --git a/internal/integration/load_balancer_prose_test.go b/internal/integration/load_balancer_prose_test.go index fbc7569898..c40c052993 100644 --- a/internal/integration/load_balancer_prose_test.go +++ b/internal/integration/load_balancer_prose_test.go @@ -85,7 +85,7 @@ func TestLoadBalancerSupport(t *testing.T) { sess, err := mt.Client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - ctx := mongo.ContextWithSession(context.Background(), sess) + ctx := mongo.NewSessionContext(context.Background(), sess) // Start a transaction and perform one transactional operation to pin a connection. err = sess.StartTransaction() @@ -116,7 +116,7 @@ func TestLoadBalancerSupport(t *testing.T) { err = sess.StartTransaction() require.NoError(mt, err, "StartTransaction error") - ctx := mongo.ContextWithSession(ctx, sess) + ctx := mongo.NewSessionContext(ctx, sess) _, err = mt.Coll.InsertOne(ctx, bson.M{"x": 1}) assert.NoError(mt, err, "InsertOne error") diff --git a/internal/integration/sessions_mongocryptd_prose_test.go b/internal/integration/sessions_mongocryptd_prose_test.go index ebb56168ed..55fabc1061 100644 --- a/internal/integration/sessions_mongocryptd_prose_test.go +++ b/internal/integration/sessions_mongocryptd_prose_test.go @@ -158,7 +158,7 @@ func TestSessionsMongocryptdProse(t *testing.T) { defer session.EndSession(context.Background()) - sessionCtx := mongo.ContextWithSession(context.TODO(), session) + sessionCtx := mongo.NewSessionContext(context.TODO(), session) err = session.StartTransaction() require.NoError(mt, err, "expected error to be nil, got %v", err) diff --git a/internal/integration/sessions_test.go b/internal/integration/sessions_test.go index dc81f883c1..d368e4a76a 100644 --- a/internal/integration/sessions_test.go +++ b/internal/integration/sessions_test.go @@ -62,7 +62,7 @@ func TestSessions(t *testing.T) { assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - ctx := mongo.ContextWithSession(context.Background(), sess) + ctx := mongo.NewSessionContext(context.Background(), sess) gotSess := mongo.SessionFromContext(ctx) assert.NotNil(mt, gotSess, "expected SessionFromContext to return non-nil value, got nil") @@ -80,7 +80,7 @@ func TestSessions(t *testing.T) { sess, err := mt.Client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) - return mongo.ContextWithSession(context.Background(), sess) + return mongo.NewSessionContext(context.Background(), sess) } ctx := createSessionContext(mt) diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index b6701638e9..59aa36ae8c 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -83,7 +83,7 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat if err != nil { return nil, err } - ctx = mongo.ContextWithSession(ctx, sess) + ctx = mongo.NewSessionContext(ctx, sess) // Set op.Arguments to a new document that has the "session" field removed so individual operations do // not have to account for it. diff --git a/mongo/client.go b/mongo/client.go index f4bfb9f991..36f6fbc35f 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -788,7 +788,7 @@ func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts // Any error returned by the fn callback will be returned without any // modifications. func WithSession(ctx context.Context, sess *Session, fn func(context.Context) error) error { - return fn(ContextWithSession(ctx, sess)) + return fn(NewSessionContext(ctx, sess)) } // UseSession creates a new Session and uses it to create a new session context, @@ -826,7 +826,7 @@ func (c *Client) UseSessionWithOptions( } defer defaultSess.EndSession(ctx) - return fn(ContextWithSession(ctx, defaultSess)) + return fn(NewSessionContext(ctx, defaultSess)) } // Watch returns a change stream for all changes on the deployment. See diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index 4581ab5c3a..e17be4bce4 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -779,7 +779,7 @@ func ExampleClient_StartSession_withTransaction() { fmt.Printf("result: %v\n", result) } -func ExampleContextWithSession() { +func ExampleNewSessionContext() { var client *mongo.Client // Create a new Session and SessionContext. @@ -788,7 +788,7 @@ func ExampleContextWithSession() { panic(err) } defer sess.EndSession(context.TODO()) - ctx := mongo.ContextWithSession(context.TODO(), sess) + ctx := mongo.NewSessionContext(context.TODO(), sess) // Start a transaction and use the context.Context as the Context // parameter for InsertOne and FindOne so both operations are run in the diff --git a/mongo/crypt_retrievers.go b/mongo/crypt_retrievers.go index a210a5c433..5e96da731a 100644 --- a/mongo/crypt_retrievers.go +++ b/mongo/crypt_retrievers.go @@ -20,7 +20,7 @@ type keyRetriever struct { func (kr *keyRetriever) cryptKeys(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error) { // Remove the explicit session from the context if one is set. // The explicit session may be from a different client. - ctx = ContextWithSession(ctx, nil) + ctx = NewSessionContext(ctx, nil) cursor, err := kr.coll.Find(ctx, filter) if err != nil { return nil, EncryptionKeyVaultError{Wrapped: err} @@ -48,7 +48,7 @@ type collInfoRetriever struct { func (cir *collInfoRetriever) cryptCollInfo(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error) { // Remove the explicit session from the context if one is set. // The explicit session may be from a different client. - ctx = ContextWithSession(ctx, nil) + ctx = NewSessionContext(ctx, nil) cursor, err := cir.client.Database(db).ListCollections(ctx, filter) if err != nil { return nil, err diff --git a/mongo/database_test.go b/mongo/database_test.go index c3896345d0..31bd900439 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -120,7 +120,7 @@ func TestDatabase(t *testing.T) { assert.Nil(t, err, "expected nil, got %v", err) defer sess.EndSession(bgCtx) - sessCtx := ContextWithSession(bgCtx, sess) + sessCtx := NewSessionContext(bgCtx, sess) err = sess.StartTransaction() assert.Nil(t, err, "expected nil, got %v", err) diff --git a/mongo/mongocryptd.go b/mongo/mongocryptd.go index 9e20c5ab6d..efb283e208 100644 --- a/mongo/mongocryptd.go +++ b/mongo/mongocryptd.go @@ -88,7 +88,7 @@ func (mc *mongocryptdClient) markCommand(ctx context.Context, dbName string, cmd // Remove the explicit session from the context if one is set. // The explicit session will be from a different client. // If an explicit session is set, it is applied after automatic encryption. - ctx = ContextWithSession(ctx, nil) + ctx = NewSessionContext(ctx, nil) db := mc.client.Database(dbName, databaseOpts) res, err := db.RunCommand(ctx, cmd).Raw() diff --git a/mongo/session.go b/mongo/session.go index 8730d4dd7d..dc46ef6734 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -47,9 +47,9 @@ type Session struct { type sessionKey struct{} -// ContextWithSession creates a new SessionContext associated with the given +// NewSessionContext creates a new SessionContext associated with the given // Context and Session parameters. -func ContextWithSession(parent context.Context, sess *Session) context.Context { +func NewSessionContext(parent context.Context, sess *Session) context.Context { return context.WithValue(parent, sessionKey{}, sess) } @@ -75,12 +75,13 @@ func (s *Session) ClientSession() *session.Client { return s.clientSession } -// ID implements the Session interface. +// ID returns the current ID document associated with the session. The ID +// 1document is in the form {"id": }. func (s *Session) ID() bson.Raw { return bson.Raw(s.clientSession.SessionID) } -// EndSession implements the Session interface. +// EndSession aborts any existing transactions and close the session. func (s *Session) EndSession(ctx context.Context) { if s.clientSession.TransactionInProgress() { // ignore all errors aborting during an end session @@ -89,7 +90,20 @@ func (s *Session) EndSession(ctx context.Context) { s.clientSession.EndSession() } -// WithTransaction implements the Session interface. +// WithTransaction starts a transaction on this session and runs the fn +// callback. Errors with the TransientTransactionError and +// UnknownTransactionCommitResult labels are retried for up to 120 seconds. +// Inside the callback, the SessionContext must be used as the Context parameter +// for any operations that should be part of the transaction. If the ctx +// parameter already has a Session attached to it, it will be replaced by this +// session. The fn callback may be run multiple times during WithTransaction due +// to retry attempts, so it must be idempotent. Non-retryable operation errors +// or any operation errors that occur after the timeout expires will be returned +// without retrying. If the callback fails, the driver will call +// AbortTransaction. Because this method must succeed to ensure that server-side +// resources are properly cleaned up, context deadlines and cancellations will +// not be respected during this call. For a usage example, see the +// Client.StartSession method documentation. func (s *Session) WithTransaction(ctx context.Context, fn func(ctx context.Context) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { timeout := time.NewTimer(withTransactionTimeout) @@ -101,7 +115,7 @@ func (s *Session) WithTransaction(ctx context.Context, fn func(ctx context.Conte return nil, err } - res, err := fn(ContextWithSession(ctx, s)) + res, err := fn(NewSessionContext(ctx, s)) if err != nil { if s.clientSession.TransactionRunning() { // Wrap the user-provided Context in a new one that behaves like context.Background() for deadlines and @@ -169,7 +183,8 @@ func (s *Session) WithTransaction(ctx context.Context, fn func(ctx context.Conte } } -// StartTransaction implements the Session interface. +// StartTransaction starts a new transaction. This method returns an error if +// there is already a transaction in-progress for this session. func (s *Session) StartTransaction(opts ...*options.TransactionOptions) error { err := s.clientSession.CheckStartTransaction() if err != nil { @@ -206,7 +221,9 @@ func (s *Session) StartTransaction(opts ...*options.TransactionOptions) error { return s.clientSession.StartTransaction(coreOpts) } -// AbortTransaction implements the Session interface. +// AbortTransaction aborts the active transaction for this session. This method +// returns an error if there is no active transaction for this session or if the +// transaction has been committed or aborted. func (s *Session) AbortTransaction(ctx context.Context) error { err := s.clientSession.CheckAbortTransaction() if err != nil { @@ -232,7 +249,9 @@ func (s *Session) AbortTransaction(ctx context.Context) error { return nil } -// CommitTransaction implements the Session interface. +// CommitTransaction commits the active transaction for this session. This +// method returns an error if there is no active transaction for this session or +// if the transaction has been aborted. func (s *Session) CommitTransaction(ctx context.Context) error { err := s.clientSession.CheckCommitTransaction() if err != nil { @@ -276,27 +295,31 @@ func (s *Session) CommitTransaction(ctx context.Context) error { return commitErr } -// ClusterTime implements the Session interface. +// ClusterTime returns the current cluster time document associated with the +// session. func (s *Session) ClusterTime() bson.Raw { return s.clientSession.ClusterTime } -// AdvanceClusterTime implements the Session interface. +// AdvanceClusterTime advances the cluster time for a session. This method +// returns an error if the session has ended. func (s *Session) AdvanceClusterTime(d bson.Raw) error { return s.clientSession.AdvanceClusterTime(d) } -// OperationTime implements the Session interface. +// OperationTime returns the current operation time document associated with the +// session. func (s *Session) OperationTime() *bson.Timestamp { return s.clientSession.OperationTime } -// AdvanceOperationTime implements the Session interface. +// AdvanceOperationTime advances the operation time for a session. This method +// returns an error if the session has ended. func (s *Session) AdvanceOperationTime(ts *bson.Timestamp) error { return s.clientSession.AdvanceOperationTime(ts) } -// Client implements the Session interface. +// Client the Client associated with the session. func (s *Session) Client() *Client { return s.client } @@ -304,8 +327,7 @@ func (s *Session) Client() *Client { // sessionFromContext checks for a sessionImpl in the argued context and returns the session if it // exists func sessionFromContext(ctx context.Context) *session.Client { - s := ctx.Value(sessionKey{}) - if ses, ok := s.(*Session); ses != nil && ok { + if ses := SessionFromContext(ctx); ses != nil { return ses.clientSession } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index 478a1c1785..30835fc5e9 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -240,8 +240,8 @@ func TestConvenientTransactions(t *testing.T) { // insert succeeds, it cancels the Context created above and returns a non-retryable error, which forces // WithTransaction to abort the txn. callbackErr := errors.New("error") - callback := func(sc context.Context) (interface{}, error) { - _, err = coll.InsertOne(sc, bson.D{{"x", 1}}) + callback := func(ctx context.Context) (interface{}, error) { + _, err = coll.InsertOne(ctx, bson.D{{"x", 1}}) if err != nil { return nil, err } @@ -306,17 +306,17 @@ func TestConvenientTransactions(t *testing.T) { defer session.EndSession(bgCtx) assert.Nil(t, err, "StartSession error: %v", err) - _ = WithSession(bgCtx, session, func(sessionContext context.Context) error { + _ = WithSession(bgCtx, session, func(ctx context.Context) error { // Start transaction. err = session.StartTransaction() assert.Nil(t, err, "StartTransaction error: %v", err) // Insert a document. - _, err := coll.InsertOne(sessionContext, bson.D{{"val", 17}}) + _, err := coll.InsertOne(ctx, bson.D{{"val", 17}}) assert.Nil(t, err, "InsertOne error: %v", err) // Set a timeout of 0 for commitTransaction. - commitTimeoutCtx, commitCancel := context.WithTimeout(sessionContext, 0) + commitTimeoutCtx, commitCancel := context.WithTimeout(ctx, 0) defer commitCancel() // CommitTransaction results in context.DeadlineExceeded. From 67fdcabeb1b3e201d1e407a7bdbb8317245c3edd Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Mon, 29 Apr 2024 10:30:04 -0600 Subject: [PATCH 5/9] GODRIVER-2800 Fix typo --- mongo/session.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mongo/session.go b/mongo/session.go index dc46ef6734..3e5243bb44 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -76,7 +76,7 @@ func (s *Session) ClientSession() *session.Client { } // ID returns the current ID document associated with the session. The ID -// 1document is in the form {"id": }. +// document is in the form {"id": }. func (s *Session) ID() bson.Raw { return bson.Raw(s.clientSession.SessionID) } @@ -319,7 +319,7 @@ func (s *Session) AdvanceOperationTime(ts *bson.Timestamp) error { return s.clientSession.AdvanceOperationTime(ts) } -// Client the Client associated with the session. +// Client is the Client associated with the session. func (s *Session) Client() *Client { return s.client } From 679484c5d83f6f3a4e1c4b164ecc2023a9562bfa Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Mon, 29 Apr 2024 10:31:53 -0600 Subject: [PATCH 6/9] GODRIVER-2800 Remove ClientSession --- mongo/session.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mongo/session.go b/mongo/session.go index 3e5243bb44..529c7c6d5f 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -70,11 +70,6 @@ func SessionFromContext(ctx context.Context) *Session { return sess } -// ClientSession returns the experimental client session. -func (s *Session) ClientSession() *session.Client { - return s.clientSession -} - // ID returns the current ID document associated with the session. The ID // document is in the form {"id": }. func (s *Session) ID() bson.Raw { From ed70f1f06bef4648ceda1295fe17217df25c7e9e Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Mon, 29 Apr 2024 11:20:55 -0600 Subject: [PATCH 7/9] GODRIVER-2800 Add deprecated warning to session.ClientSession --- mongo/session.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mongo/session.go b/mongo/session.go index 529c7c6d5f..e743de703c 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -70,6 +70,14 @@ func SessionFromContext(ctx context.Context) *Session { return sess } +// ClientSession returns the experimental client session. +// +// Deprecated: This method is for internal use only and should not be used (see +// GODRIVER-2700). It may be changed or removed in any release. +func (s *Session) ClientSession() *session.Client { + return s.clientSession +} + // ID returns the current ID document associated with the session. The ID // document is in the form {"id": }. func (s *Session) ID() bson.Raw { From 67a0db7bec85910cf1bbabf4b867c8fbdf3ec7c7 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Mon, 29 Apr 2024 13:52:45 -0600 Subject: [PATCH 8/9] Update mongo/session.go Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- mongo/session.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mongo/session.go b/mongo/session.go index e743de703c..bbcdf6a7f5 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -47,8 +47,12 @@ type Session struct { type sessionKey struct{} -// NewSessionContext creates a new SessionContext associated with the given -// Context and Session parameters. +// NewSessionContext returns a Context that holds the given Session. If the +// Context already contains a Session, that Session will be replaced with the +// one provided. +// +// The returned Context can be used with Collection methods like +// [Collection.InsertOne] or [Collection.Find] to run operations in a Session. func NewSessionContext(parent context.Context, sess *Session) context.Context { return context.WithValue(parent, sessionKey{}, sess) } From 2502638b991524098830192a9893cf1a51170665 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Mon, 29 Apr 2024 13:53:12 -0600 Subject: [PATCH 9/9] GODRIVER-2800 Fix compile errors in CSE tests --- internal/integration/client_side_encryption_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/integration/client_side_encryption_test.go b/internal/integration/client_side_encryption_test.go index 0b066ad774..10041a9bc6 100644 --- a/internal/integration/client_side_encryption_test.go +++ b/internal/integration/client_side_encryption_test.go @@ -143,7 +143,7 @@ func TestClientSideEncryptionWithExplicitSessions(t *testing.T) { session, err := client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) - sessionCtx := mongo.ContextWithSession(context.Background(), session) + sessionCtx := mongo.NewSessionContext(context.Background(), session) capturedEvents = make([]event.CommandStartedEvent, 0) _, err = coll.InsertOne(sessionCtx, bson.D{{"encryptMe", "test"}, {"keyName", "myKey"}}) @@ -207,7 +207,7 @@ func TestClientSideEncryptionWithExplicitSessions(t *testing.T) { session, err := client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) - sessionCtx := mongo.ContextWithSession(context.Background(), session) + sessionCtx := mongo.NewSessionContext(context.Background(), session) capturedEvents = make([]event.CommandStartedEvent, 0) res := coll.FindOne(sessionCtx, bson.D{{}})