diff --git a/integration/commands_administration_test.go b/integration/commands_administration_test.go index db4b9245e849..b0ed592e1be5 100644 --- a/integration/commands_administration_test.go +++ b/integration/commands_administration_test.go @@ -252,7 +252,6 @@ func TestCommandsAdministrationListDatabases(t *testing.T) { } for name, tc := range testCases { - tc, name := tc, name t.Run(name, func(t *testing.T) { t.Parallel() @@ -858,7 +857,7 @@ func TestGetParameterCommandAuthenticationMechanisms(t *testing.T) { require.NoError(t, err) expected := bson.D{ - {"authenticationMechanisms", bson.A{"PLAIN"}}, + {"authenticationMechanisms", bson.A{"SCRAM-SHA-1", "SCRAM-SHA-256", "PLAIN"}}, {"ok", float64(1)}, } require.Equal(t, expected, res) diff --git a/integration/create_test.go b/integration/create_test.go index 2b2cc836ff96..47b2f70bda79 100644 --- a/integration/create_test.go +++ b/integration/create_test.go @@ -333,7 +333,6 @@ func TestCreateCappedCommandInvalidSpec(t *testing.T) { }, }, } { - tc, name := tc, name t.Run(name, func(t *testing.T) { t.Parallel() diff --git a/integration/hello_command_test.go b/integration/hello_command_test.go new file mode 100644 index 000000000000..081e55d73c69 --- /dev/null +++ b/integration/hello_command_test.go @@ -0,0 +1,190 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + + "github.com/FerretDB/FerretDB/integration/setup" + "github.com/FerretDB/FerretDB/integration/shareddata" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/must" + "github.com/FerretDB/FerretDB/internal/util/testutil/testtb" +) + +func TestHello(t *testing.T) { + t.Parallel() + + ctx, collection := setup.Setup(t, shareddata.Scalars, shareddata.Composites) + db := collection.Database() + + var res bson.D + + require.NoError(t, db.RunCommand(ctx, bson.D{ + {"hello", "1"}, + }).Decode(&res)) + + actual := ConvertDocument(t, res) + + assert.Equal(t, must.NotFail(actual.Get("isWritablePrimary")), true) + assert.Equal(t, must.NotFail(actual.Get("maxBsonObjectSize")), int32(16777216)) + assert.Equal(t, must.NotFail(actual.Get("maxMessageSizeBytes")), int32(48000000)) + assert.Equal(t, must.NotFail(actual.Get("maxMessageSizeBytes")), int32(48000000)) + assert.Equal(t, must.NotFail(actual.Get("maxWriteBatchSize")), int32(100000)) + assert.IsType(t, must.NotFail(actual.Get("localTime")), time.Time{}) + assert.IsType(t, must.NotFail(actual.Get("connectionId")), int32(1)) + assert.Equal(t, must.NotFail(actual.Get("minWireVersion")), int32(0)) + assert.Equal(t, must.NotFail(actual.Get("maxWireVersion")), int32(21)) + assert.Equal(t, must.NotFail(actual.Get("readOnly")), false) + assert.Equal(t, must.NotFail(actual.Get("ok")), float64(1)) +} + +func TestHelloWithSupportedMechs(t *testing.T) { + t.Parallel() + + ctx, collection := setup.Setup(t, shareddata.Scalars, shareddata.Composites) + db := collection.Database() + + usersPayload := []bson.D{ + { + {"createUser", "hello_user"}, + {"roles", bson.A{}}, + {"pwd", "hello_password"}, + }, + { + {"createUser", "hello_user_scram1"}, + {"roles", bson.A{}}, + {"pwd", "hello_password"}, + {"mechanisms", bson.A{"SCRAM-SHA-1"}}, + }, + { + {"createUser", "hello_user_scram256"}, + {"roles", bson.A{}}, + {"pwd", "hello_password"}, + {"mechanisms", bson.A{"SCRAM-SHA-256"}}, + }, + } + + if !setup.IsMongoDB(t) { + usersPayload = append(usersPayload, primitive.D{ + {"createUser", "hello_user_plain"}, + {"roles", bson.A{}}, + {"pwd", "hello_password"}, + {"mechanisms", bson.A{"PLAIN"}}, + }) + } + + for _, u := range usersPayload { + require.NoError(t, db.RunCommand(ctx, u).Err()) + } + + testCases := map[string]struct { //nolint:vet // used for test only + user string + mechs *types.Array + + err *mongo.CommandError + failsForMongoDB string + }{ + "NotFound": { + user: db.Name() + ".not_found", + }, + "AnotherDB": { + user: db.Name() + "_not_found.another_db", + }, + "HelloUser": { + user: db.Name() + ".hello_user", + mechs: must.NotFail(types.NewArray("SCRAM-SHA-1", "SCRAM-SHA-256")), + }, + "HelloUserPlain": { + user: db.Name() + ".hello_user_plain", + mechs: must.NotFail(types.NewArray("PLAIN")), + failsForMongoDB: "PLAIN authentication mechanism is not support by MongoDB", + }, + "HelloUserSCRAM1": { + user: db.Name() + ".hello_user_scram1", + mechs: must.NotFail(types.NewArray("SCRAM-SHA-1")), + }, + "HelloUserSCRAM256": { + user: db.Name() + ".hello_user_scram256", + mechs: must.NotFail(types.NewArray("SCRAM-SHA-256")), + }, + "EmptyUsername": { + user: db.Name() + ".", + mechs: nil, + }, + "MissingSeparator": { + user: db.Name(), + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "UserName must contain a '.' separated database.user pair", + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + tt.Parallel() + + var t testtb.TB = tt + + if tc.failsForMongoDB != "" { + t = setup.FailsForMongoDB(t, tc.failsForMongoDB) + } + + var res bson.D + + err := db.RunCommand(ctx, bson.D{ + {"hello", "1"}, + {"saslSupportedMechs", tc.user}, + }).Decode(&res) + + if tc.err != nil { + AssertEqualCommandError(t, *tc.err, err) + return + } + + actual := ConvertDocument(t, res) + + assert.Equal(t, must.NotFail(actual.Get("isWritablePrimary")), true) + assert.Equal(t, must.NotFail(actual.Get("maxBsonObjectSize")), int32(16777216)) + assert.Equal(t, must.NotFail(actual.Get("maxMessageSizeBytes")), int32(48000000)) + assert.Equal(t, must.NotFail(actual.Get("maxMessageSizeBytes")), int32(48000000)) + assert.Equal(t, must.NotFail(actual.Get("maxWriteBatchSize")), int32(100000)) + assert.IsType(t, must.NotFail(actual.Get("localTime")), time.Time{}) + assert.IsType(t, must.NotFail(actual.Get("connectionId")), int32(1)) + assert.Equal(t, must.NotFail(actual.Get("minWireVersion")), int32(0)) + assert.Equal(t, must.NotFail(actual.Get("maxWireVersion")), int32(21)) + assert.Equal(t, must.NotFail(actual.Get("readOnly")), false) + assert.Equal(t, must.NotFail(actual.Get("ok")), float64(1)) + + if tc.mechs == nil { + assert.False(t, actual.Has("saslSupportedMechs")) + return + } + + mechanisms, err := actual.Get("saslSupportedMechs") + require.NoError(t, err) + assert.True(t, mechanisms.(*types.Array).ContainsAll(tc.mechs)) + }) + } +} diff --git a/integration/query_test.go b/integration/query_test.go index dcb83fc7d1fb..2c5c3e5487ce 100644 --- a/integration/query_test.go +++ b/integration/query_test.go @@ -728,7 +728,6 @@ func TestQueryCommandLimitPushDown(t *testing.T) { limitPushdown: noPushdown, }, } { - tc, name := tc, name t.Run(name, func(t *testing.T) { t.Parallel() diff --git a/integration/setup/helpers.go b/integration/setup/helpers.go index 07abd7ed8fc0..f6d3dc6bf3be 100644 --- a/integration/setup/helpers.go +++ b/integration/setup/helpers.go @@ -89,7 +89,7 @@ func FailsForMongoDB(tb testtb.TB, reason string) testtb.TB { // SkipForMongoDB skips the current test for MongoDB. // -// Use [FailsForMongoDB] in new code. +// Deprecated: Use [FailsForMongoDB] in new code. func SkipForMongoDB(tb testtb.TB, reason string) { tb.Helper() diff --git a/integration/users/usersinfo_test.go b/integration/users/usersinfo_test.go index e41b1f8c8ce0..6bfdb77483ac 100644 --- a/integration/users/usersinfo_test.go +++ b/integration/users/usersinfo_test.go @@ -27,6 +27,7 @@ import ( "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/must" "github.com/FerretDB/FerretDB/internal/util/testutil" + "github.com/FerretDB/FerretDB/internal/util/testutil/testtb" ) // createUser creates a bson.D command payload to create an user with the given username and password. @@ -87,6 +88,28 @@ func TestUsersinfo(t *testing.T) { }, }, }, + { + dbSuffix: "allbackends", + payloads: []bson.D{ + { + {"createUser", "WithSCRAMSHA1"}, + {"roles", bson.A{}}, + {"pwd", "pwd1"}, + {"mechanisms", bson.A{"SCRAM-SHA-1"}}, + }, + }, + }, + { + dbSuffix: "allbackends", + payloads: []bson.D{ + { + {"createUser", "WithSCRAMSHA256"}, + {"roles", bson.A{}}, + {"pwd", "pwd1"}, + {"mechanisms", bson.A{"SCRAM-SHA-256"}}, + }, + }, + }, } dbPrefix := testutil.DatabaseName(t) @@ -117,13 +140,14 @@ func TestUsersinfo(t *testing.T) { } testCases := map[string]struct { //nolint:vet // for readability - dbSuffix string - payload bson.D - err *mongo.CommandError - altMessage string - expected bson.D - hasUser map[string]struct{} - skipForMongoDB string // optional, skip test for MongoDB backend with a specific reason + dbSuffix string + payload bson.D + err *mongo.CommandError + altMessage string + expected bson.D + hasUser map[string]struct{} + showCredentials []string // showCredentials list the credentials types expected to be returned + failsForMongoDB string }{ "NoUserFound": { dbSuffix: "no_users", @@ -182,6 +206,7 @@ func TestUsersinfo(t *testing.T) { {"usersInfo", "WithPLAIN"}, {"showCredentials", true}, }, + showCredentials: []string{"PLAIN"}, expected: bson.D{ {"users", bson.A{ bson.D{ @@ -193,7 +218,45 @@ func TestUsersinfo(t *testing.T) { }}, {"ok", float64(1)}, }, - skipForMongoDB: "Only MongoDB Enterprise offers PLAIN", + failsForMongoDB: "Only MongoDB Enterprise offers PLAIN", + }, + "WithSCRAMSHA1": { + dbSuffix: "allbackends", + payload: bson.D{ + {"usersInfo", "WithSCRAMSHA1"}, + {"showCredentials", true}, + }, + showCredentials: []string{"SCRAM-SHA-1"}, + expected: bson.D{ + {"users", bson.A{ + bson.D{ + {"_id", "TestUsersinfo.WithSCRAMSHA1"}, + {"user", "scramsha1"}, + {"db", "TestUsersinfo"}, + {"roles", bson.A{}}, + }, + }}, + {"ok", float64(1)}, + }, + }, + "WithSCRAMSHA256": { + dbSuffix: "allbackends", + payload: bson.D{ + {"usersInfo", "WithSCRAMSHA256"}, + {"showCredentials", true}, + }, + showCredentials: []string{"SCRAM-SHA-256"}, + expected: bson.D{ + {"users", bson.A{ + bson.D{ + {"_id", "TestUsersinfo.WithSCRAMSHA256"}, + {"user", "scramsha256"}, + {"db", "TestUsersinfo"}, + {"roles", bson.A{}}, + }, + }}, + {"ok", float64(1)}, + }, }, "FromSameDatabase": { dbSuffix: "_example", @@ -469,12 +532,13 @@ func TestUsersinfo(t *testing.T) { } for name, tc := range testCases { - name, tc := name, tc - t.Run(name, func(t *testing.T) { - t.Parallel() + t.Run(name, func(tt *testing.T) { + tt.Parallel() + + var t testtb.TB = tt - if tc.skipForMongoDB != "" { - setup.SkipForMongoDB(t, tc.skipForMongoDB) + if tc.failsForMongoDB != "" { + t = setup.FailsForMongoDB(t, tc.failsForMongoDB) } var res bson.D @@ -512,43 +576,41 @@ func TestUsersinfo(t *testing.T) { require.True(t, (tc.hasUser == nil) != (tc.expected == nil)) - payload := integration.ConvertDocument(t, tc.payload) - var showCredentials bool + id, err := actualUser.Get("_id") + require.NoError(t, err) - if payload.Has("showCredentials") { - showCredentials = must.NotFail(payload.Get("showCredentials")).(bool) - } - if showCredentials { - if !setup.IsMongoDB(t) { - cred, ok := actualUser.Get("credentials") - assert.Nil(t, ok, "credentials not found") - assertPlainCredentials(t, "PLAIN", cred.(*types.Document)) - } - } else { - assert.False(t, actualUser.Has("credentials")) - } + // when `forAllDBs` is set true, it may contain more users from other databases, + // so we check expected users were found rather than exact match + foundUsers[id.(string)] = struct{}{} - if payload.Has("mechanisms") { - payloadMechanisms := must.NotFail(payload.Get("mechanisms")).(*types.Array) + userIDV, err := actualUser.Get("userId") + require.NoError(t, err) - if payloadMechanisms.Contains("PLAIN") { - assertPlainCredentials(t, "PLAIN", must.NotFail(actualUser.Get("credentials")).(*types.Document)) - } + userID := userIDV.(types.Binary) + assert.Equal(t, userID.Subtype.String(), types.BinaryUUID.String(), "uuid subtype") + assert.Equal(t, 16, len(userID.B), "UUID length") - if payloadMechanisms.Contains("SCRAM-SHA-1") { - assertSCRAMSHA1Credentials(t, "SCRAM-SHA-1", must.NotFail(actualUser.Get("credentials")).(*types.Document)) - } + if tc.showCredentials == nil { + assert.False(t, actualUser.Has("credentials")) - if payloadMechanisms.Contains("SCRAM-SHA-256") { - assertSCRAMSHA256Credentials(t, "SCRAM-SHA-256", must.NotFail(actualUser.Get("credentials")).(*types.Document)) - } + continue } - foundUsers[must.NotFail(actualUser.Get("_id")).(string)] = struct{}{} + credV, err := actualUser.Get("credentials") + require.NoError(t, err) - uuid := must.NotFail(actualUser.Get("userId")).(types.Binary) - assert.Equal(t, uuid.Subtype.String(), types.BinaryUUID.String(), "uuid subtype") - assert.Equal(t, 16, len(uuid.B), "UUID length") + cred := credV.(*types.Document) + + for _, typ := range tc.showCredentials { + switch typ { + case "PLAIN": + assertPlainCredentials(t, "PLAIN", cred) + case "SCRAM-SHA-1": + assertSCRAMSHA1Credentials(t, "SCRAM-SHA-1", cred) + case "SCRAM-SHA-256": + assertSCRAMSHA256Credentials(t, "SCRAM-SHA-256", cred) + } + } } if tc.hasUser != nil { diff --git a/internal/handler/handlerparams/typecode_test.go b/internal/handler/handlerparams/typecode_test.go index b57a60487cca..5fe5ec084223 100644 --- a/internal/handler/handlerparams/typecode_test.go +++ b/internal/handler/handlerparams/typecode_test.go @@ -51,7 +51,6 @@ func TestHasSameTypeElements(t *testing.T) { same: true, }, } { - tc, name := tc, name t.Run(name, func(t *testing.T) { t.Parallel() diff --git a/internal/handler/msg_getparameter.go b/internal/handler/msg_getparameter.go index 47180050a34e..1194d7eca763 100644 --- a/internal/handler/msg_getparameter.go +++ b/internal/handler/msg_getparameter.go @@ -52,7 +52,7 @@ func (h *Handler) MsgGetParameter(ctx context.Context, msg *wire.OpMsg) (*wire.O // "settableAtStartup", , //)), "authenticationMechanisms", must.NotFail(types.NewDocument( - "value", must.NotFail(types.NewArray("PLAIN")), + "value", must.NotFail(types.NewArray("SCRAM-SHA-1", "SCRAM-SHA-256", "PLAIN")), "settableAtRuntime", false, "settableAtStartup", true, )), diff --git a/internal/handler/msg_hello.go b/internal/handler/msg_hello.go index 9845961f024c..bc7fe2de60ca 100644 --- a/internal/handler/msg_hello.go +++ b/internal/handler/msg_hello.go @@ -16,10 +16,14 @@ package handler import ( "context" + "errors" + "strings" "time" "github.com/FerretDB/FerretDB/internal/handler/common" + "github.com/FerretDB/FerretDB/internal/handler/handlererrors" "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/iterator" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" "github.com/FerretDB/FerretDB/internal/wire" @@ -36,21 +40,110 @@ func (h *Handler) MsgHello(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, e return nil, lazyerrors.Error(err) } + saslSupportedMechs, err := common.GetOptionalParam(doc, "saslSupportedMechs", "") + if err != nil { + return nil, lazyerrors.Error(err) + } + var reply wire.OpMsg - must.NoError(reply.SetSections(wire.MakeOpMsgSection( - must.NotFail(types.NewDocument( - "isWritablePrimary", true, - "maxBsonObjectSize", int32(types.MaxDocumentLen), - "maxMessageSizeBytes", int32(wire.MaxMsgLen), - "maxWriteBatchSize", int32(100000), - "localTime", time.Now(), - "connectionId", int32(42), - "minWireVersion", common.MinWireVersion, - "maxWireVersion", common.MaxWireVersion, - "readOnly", false, - "ok", float64(1), - )), - ))) + resp := must.NotFail(types.NewDocument( + "isWritablePrimary", true, + "maxBsonObjectSize", int32(types.MaxDocumentLen), + "maxMessageSizeBytes", int32(wire.MaxMsgLen), + "maxWriteBatchSize", int32(100000), + "localTime", time.Now(), + "connectionId", int32(42), + "minWireVersion", common.MinWireVersion, + "maxWireVersion", common.MaxWireVersion, + "readOnly", false, + )) + + if saslSupportedMechs == "" { + resp.Set("ok", float64(1)) + must.NoError(reply.SetSections(wire.MakeOpMsgSection(resp))) + + return &reply, nil + } + + db, username, ok := strings.Cut(saslSupportedMechs, ".") + if !ok { + return nil, handlererrors.NewCommandErrorMsg( + handlererrors.ErrBadValue, + "UserName must contain a '.' separated database.user pair", + ) + } + + mechs, err := h.getUserSupportedMechs(ctx, db, username) + if err != nil { + return nil, lazyerrors.Error(err) + } + + saslSupportedMechsResp := must.NotFail(types.NewArray()) + for _, k := range mechs { + saslSupportedMechsResp.Append(k) + } + + if saslSupportedMechsResp.Len() != 0 { + resp.Set("saslSupportedMechs", saslSupportedMechsResp) + } + + resp.Set("ok", float64(1)) + must.NoError(reply.SetSections(wire.MakeOpMsgSection(resp))) return &reply, nil } + +// getUserSupportedMechs for a given user. +func (h *Handler) getUserSupportedMechs(ctx context.Context, db, username string) ([]string, error) { + adminDB, err := h.b.Database("admin") + if err != nil { + return nil, lazyerrors.Error(err) + } + + usersCol, err := adminDB.Collection("system.users") + if err != nil { + return nil, lazyerrors.Error(err) + } + + filter, err := usersInfoFilter(false, false, db, []usersInfoPair{ + {username: username, db: db}, + }) + if err != nil { + return nil, lazyerrors.Error(err) + } + + qr, err := usersCol.Query(ctx, nil) + if err != nil { + return nil, lazyerrors.Error(err) + } + + defer qr.Iter.Close() + + for { + _, v, err := qr.Iter.Next() + + if errors.Is(err, iterator.ErrIteratorDone) { + break + } + + if err != nil { + return nil, lazyerrors.Error(err) + } + + matches, err := common.FilterDocument(v, filter) + if err != nil { + return nil, lazyerrors.Error(err) + } + + if !matches { + continue + } + + if v.Has("credentials") { + credentials := must.NotFail(v.Get("credentials")).(*types.Document) + return credentials.Keys(), nil + } + } + + return nil, nil +} diff --git a/internal/handler/msg_usersinfo.go b/internal/handler/msg_usersinfo.go index fc5624279ff9..74f635d0eef0 100644 --- a/internal/handler/msg_usersinfo.go +++ b/internal/handler/msg_usersinfo.go @@ -59,7 +59,7 @@ func (h *Handler) MsgUsersInfo(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs common.Ignored( document, h.L, - "showCredentials", "showCustomData", "showPrivileges", + "showCustomData", "showPrivileges", "showAuthenticationRestrictions", "comment", "filter", ) diff --git a/website/docs/reference/supported-commands.md b/website/docs/reference/supported-commands.md index 2659a02a1d75..a5e580c0ff4e 100644 --- a/website/docs/reference/supported-commands.md +++ b/website/docs/reference/supported-commands.md @@ -205,7 +205,7 @@ Related [issue](https://github.com/FerretDB/FerretDB/issues/78). | | `digestPassword` | ⚠️ | | | | `comment` | ⚠️ | | | `usersInfo` | | ✅ | | -| | `showCredentials` | ⚠️ | | +| | `showCredentials` | ✅ | | | | `showCustomData` | ⚠️ | | | | `showPrivileges` | ⚠️ | | | | `showAuthenticationRestrictions` | ⚠️ | |