diff --git a/integration/users/connection_test.go b/integration/users/connection_test.go index 6f288b7e2a4d..ee09f0a9289f 100644 --- a/integration/users/connection_test.go +++ b/integration/users/connection_test.go @@ -392,12 +392,20 @@ func TestAuthenticationEnableNewAuthPLAIN(tt *testing.T) { }).Err() require.NoError(t, err, "cannot create user") + err = db.RunCommand(ctx, bson.D{ + {"createUser", "scram-user"}, + {"roles", bson.A{}}, + {"pwd", "correct"}, + {"mechanisms", bson.A{"SCRAM-SHA-1", "SCRAM-SHA-256"}}, + }).Err() + require.NoError(t, err, "cannot create user") + testCases := map[string]struct { username string password string mechanism string - err string + err *mongo.CommandError }{ "Success": { username: "plain-user", @@ -408,13 +416,31 @@ func TestAuthenticationEnableNewAuthPLAIN(tt *testing.T) { username: "plain-user", password: "wrong", mechanism: "PLAIN", - err: "AuthenticationFailed", + err: &mongo.CommandError{ + Code: 18, + Name: "AuthenticationFailed", + Message: "Authentication failed", + }, }, "NonExistentUser": { username: "not-found-user", password: "something", mechanism: "PLAIN", - err: "AuthenticationFailed", + err: &mongo.CommandError{ + Code: 18, + Name: "AuthenticationFailed", + Message: "Authentication failed", + }, + }, + "NonPLAINUser": { + username: "scram-user", + password: "correct", + mechanism: "PLAIN", + err: &mongo.CommandError{ + Code: 334, + Name: "ErrMechanismUnavailable", + Message: "Unable to use PLAIN based authentication for user without any PLAIN credentials registered", + }, }, } @@ -442,8 +468,8 @@ func TestAuthenticationEnableNewAuthPLAIN(tt *testing.T) { c := client.Database(db.Name()).Collection(cName) _, err = c.InsertOne(ctx, bson.D{{"ping", "pong"}}) - if tc.err != "" { - require.ErrorContains(t, err, tc.err) + if tc.err != nil { + integration.AssertEqualCommandError(t, *tc.err, err) return } diff --git a/internal/backends/mysql/metadata/registry.go b/internal/backends/mysql/metadata/registry.go index 47c49e88eca4..0bb30a3ed4be 100644 --- a/internal/backends/mysql/metadata/registry.go +++ b/internal/backends/mysql/metadata/registry.go @@ -124,7 +124,7 @@ func (r *Registry) getPool(ctx context.Context) (*fsql.DB, error) { return nil, lazyerrors.New("no connection pool") } } else { - username, password := connInfo.Auth() + username, password, _ := connInfo.Auth() var err error if p, err = r.p.Get(username, password); err != nil { diff --git a/internal/backends/postgresql/metadata/registry.go b/internal/backends/postgresql/metadata/registry.go index a1bf7f704d97..c920a439a9a8 100644 --- a/internal/backends/postgresql/metadata/registry.go +++ b/internal/backends/postgresql/metadata/registry.go @@ -129,7 +129,7 @@ func (r *Registry) getPool(ctx context.Context) (*pgxpool.Pool, error) { } } } else { - username, password := connInfo.Auth() + username, password, _ := connInfo.Auth() var err error if p, err = r.p.Get(username, password); err != nil { diff --git a/internal/clientconn/conninfo/conn_info.go b/internal/clientconn/conninfo/conn_info.go index afdfa5985043..52676c92919f 100644 --- a/internal/clientconn/conninfo/conn_info.go +++ b/internal/clientconn/conninfo/conn_info.go @@ -37,8 +37,9 @@ type ConnInfo struct { Peer netip.AddrPort // invalid for Unix domain sockets - username string // protected by rw - password string // protected by rw + username string // protected by rw + password string // protected by rw + mechanism string // protected by rw rw sync.RWMutex @@ -70,21 +71,22 @@ func (connInfo *ConnInfo) Username() string { return connInfo.username } -// Auth returns stored username and password. -func (connInfo *ConnInfo) Auth() (username, password string) { +// Auth returns stored username, password and mechanism. +func (connInfo *ConnInfo) Auth() (username, password, mechanism string) { connInfo.rw.RLock() defer connInfo.rw.RUnlock() - return connInfo.username, connInfo.password + return connInfo.username, connInfo.password, connInfo.mechanism } -// SetAuth stores username and password. -func (connInfo *ConnInfo) SetAuth(username, password string) { +// SetAuth stores username, password. +func (connInfo *ConnInfo) SetAuth(username, password, mechanism string) { connInfo.rw.Lock() defer connInfo.rw.Unlock() connInfo.username = username connInfo.password = password + connInfo.mechanism = mechanism } // Conv returns stored SCRAM server conversation. diff --git a/internal/handler/authenticate.go b/internal/handler/authenticate.go index cf3c934a5b3f..9fcf8ccf9077 100644 --- a/internal/handler/authenticate.go +++ b/internal/handler/authenticate.go @@ -51,7 +51,18 @@ func (h *Handler) authenticate(ctx context.Context) error { return lazyerrors.Error(err) } - username, userPassword := conninfo.Get(ctx).Auth() + username, userPassword, mechanism := conninfo.Get(ctx).Auth() + + switch mechanism { + case "SCRAM-SHA-256", "SCRAM-SHA-1": + // SCRAM calls back scramCredentialLookup each time Step is called, + // and that checks the authentication. + return nil + case "PLAIN": + break + default: + return lazyerrors.Errorf("Unsupported authentication mechanism %q", mechanism) + } // For `PLAIN` mechanism $db field is always `$external` upon saslStart. // For `SCRAM-SHA-1` and `SCRAM-SHA-256` mechanisms $db field contains @@ -111,19 +122,15 @@ func (h *Handler) authenticate(ctx context.Context) error { credentials := must.NotFail(storedUser.Get("credentials")).(*types.Document) - switch { - case credentials.Has("SCRAM-SHA-256"), credentials.Has("SCRAM-SHA-1"): - // SCRAM calls back scramCredentialLookup each time Step is called, - // and that checks the authentication. - return nil - case credentials.Has("PLAIN"): - break - default: - panic("credentials does not contain a known mechanism") + v, _ := credentials.Get("PLAIN") + if v == nil { + return handlererrors.NewCommandErrorMsgWithArgument( + handlererrors.ErrMechanismUnavailable, + "Unable to use PLAIN based authentication for user without any PLAIN credentials registered", + "authenticate", + ) } - v := must.NotFail(credentials.Get("PLAIN")) - doc, ok := v.(*types.Document) if !ok { return lazyerrors.Errorf("field 'PLAIN' has type %T, expected Document", v) diff --git a/internal/handler/msg_logout.go b/internal/handler/msg_logout.go index 1e40e48a23c5..f079de6d8fde 100644 --- a/internal/handler/msg_logout.go +++ b/internal/handler/msg_logout.go @@ -25,7 +25,7 @@ import ( // MsgLogout implements `logout` command. func (h *Handler) MsgLogout(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, error) { - conninfo.Get(ctx).SetAuth("", "") + conninfo.Get(ctx).SetAuth("", "", "") var reply wire.OpMsg must.NoError(reply.SetSections(wire.MakeOpMsgSection( diff --git a/internal/handler/msg_saslstart.go b/internal/handler/msg_saslstart.go index 7e1208f4a790..fb4be24aa72d 100644 --- a/internal/handler/msg_saslstart.go +++ b/internal/handler/msg_saslstart.go @@ -78,7 +78,7 @@ func (h *Handler) saslStart(ctx context.Context, dbName string, document *types. conninfo.Get(ctx).SetBypassBackendAuth() } - conninfo.Get(ctx).SetAuth(username, password) + conninfo.Get(ctx).SetAuth(username, password, mechanism) var emptyPayload types.Binary @@ -95,6 +95,8 @@ func (h *Handler) saslStart(ctx context.Context, dbName string, document *types. ) } + conninfo.Get(ctx).SetAuth("", "", mechanism) + response, err := h.saslStartSCRAM(ctx, dbName, mechanism, document) if err != nil { return nil, err