From f460194f7d3dec20621eea7a23b49f85f82d6a72 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 12 Mar 2024 04:14:07 +0900 Subject: [PATCH] Fix PLAIN mechanism authentication incorrectly working (#4163) Closes #1877. --- integration/users/connection_test.go | 37 ++++++++++++++++--- internal/backends/mysql/metadata/registry.go | 2 +- .../backends/postgresql/metadata/registry.go | 2 +- internal/clientconn/conninfo/conn_info.go | 12 +++--- internal/handler/authenticate.go | 31 ++++++++++------ internal/handler/msg_logout.go | 2 +- internal/handler/msg_saslstart.go | 4 +- 7 files changed, 64 insertions(+), 26 deletions(-) diff --git a/integration/users/connection_test.go b/integration/users/connection_test.go index af03b512a1df..55dd04fe92d2 100644 --- a/integration/users/connection_test.go +++ b/integration/users/connection_test.go @@ -26,6 +26,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" + "github.com/FerretDB/FerretDB/integration" "github.com/FerretDB/FerretDB/integration/setup" "github.com/FerretDB/FerretDB/internal/util/testutil/testtb" ) @@ -353,12 +354,20 @@ func TestAuthenticationEnableNewAuthPLAIN(t *testing.T) { }).Err() require.NoError(t, err, "cannot create user") + err = db.RunCommand(ctx, bson.D{ + {"createUser", "scram-user"}, + {"roles", bson.A{}}, + {"pwd", "correct"}, + {"mechanisms", bson.A{"SCRAM-SHA-1", "SCRAM-SHA-256"}}, + }).Err() + require.NoError(t, err, "cannot create user") + testCases := map[string]struct { username string password string mechanism string - err string + err *mongo.CommandError }{ "Success": { username: "plain-user", @@ -369,13 +378,31 @@ func TestAuthenticationEnableNewAuthPLAIN(t *testing.T) { username: "plain-user", password: "wrong", mechanism: "PLAIN", - err: "AuthenticationFailed", + err: &mongo.CommandError{ + Code: 18, + Name: "AuthenticationFailed", + Message: "Authentication failed", + }, }, "NonExistentUser": { username: "not-found-user", password: "something", mechanism: "PLAIN", - err: "AuthenticationFailed", + err: &mongo.CommandError{ + Code: 18, + Name: "AuthenticationFailed", + Message: "Authentication failed", + }, + }, + "NonPLAINUser": { + username: "scram-user", + password: "correct", + mechanism: "PLAIN", + err: &mongo.CommandError{ + Code: 334, + Name: "ErrMechanismUnavailable", + Message: "Unable to use PLAIN based authentication for user without any PLAIN credentials registered", + }, }, } @@ -403,8 +430,8 @@ func TestAuthenticationEnableNewAuthPLAIN(t *testing.T) { c := client.Database(db.Name()).Collection(cName) _, err = c.InsertOne(ctx, bson.D{{"ping", "pong"}}) - if tc.err != "" { - require.ErrorContains(t, err, tc.err) + if tc.err != nil { + integration.AssertEqualCommandError(t, *tc.err, err) return } diff --git a/internal/backends/mysql/metadata/registry.go b/internal/backends/mysql/metadata/registry.go index 47c49e88eca4..0bb30a3ed4be 100644 --- a/internal/backends/mysql/metadata/registry.go +++ b/internal/backends/mysql/metadata/registry.go @@ -124,7 +124,7 @@ func (r *Registry) getPool(ctx context.Context) (*fsql.DB, error) { return nil, lazyerrors.New("no connection pool") } } else { - username, password := connInfo.Auth() + username, password, _ := connInfo.Auth() var err error if p, err = r.p.Get(username, password); err != nil { diff --git a/internal/backends/postgresql/metadata/registry.go b/internal/backends/postgresql/metadata/registry.go index a1bf7f704d97..c920a439a9a8 100644 --- a/internal/backends/postgresql/metadata/registry.go +++ b/internal/backends/postgresql/metadata/registry.go @@ -129,7 +129,7 @@ func (r *Registry) getPool(ctx context.Context) (*pgxpool.Pool, error) { } } } else { - username, password := connInfo.Auth() + username, password, _ := connInfo.Auth() var err error if p, err = r.p.Get(username, password); err != nil { diff --git a/internal/clientconn/conninfo/conn_info.go b/internal/clientconn/conninfo/conn_info.go index 88fe95b49eff..8d70b81d631f 100644 --- a/internal/clientconn/conninfo/conn_info.go +++ b/internal/clientconn/conninfo/conn_info.go @@ -35,6 +35,7 @@ type ConnInfo struct { PeerAddr string username string // protected by rw password string // protected by rw + mechanism string // protected by rw metadataRecv bool // protected by rw sc *scram.ServerConversation // protected by rw @@ -62,21 +63,22 @@ func (connInfo *ConnInfo) Username() string { return connInfo.username } -// Auth returns stored username and password. -func (connInfo *ConnInfo) Auth() (username, password string) { +// Auth returns stored username, password and mechanism. +func (connInfo *ConnInfo) Auth() (username, password, mechanism string) { connInfo.rw.RLock() defer connInfo.rw.RUnlock() - return connInfo.username, connInfo.password + return connInfo.username, connInfo.password, connInfo.mechanism } -// SetAuth stores username and password. -func (connInfo *ConnInfo) SetAuth(username, password string) { +// SetAuth stores username, password. +func (connInfo *ConnInfo) SetAuth(username, password, mechanism string) { connInfo.rw.Lock() defer connInfo.rw.Unlock() connInfo.username = username connInfo.password = password + connInfo.mechanism = mechanism } // Conv returns stored SCRAM server conversation. diff --git a/internal/handler/authenticate.go b/internal/handler/authenticate.go index f6b92c7e9128..3ba5102db34b 100644 --- a/internal/handler/authenticate.go +++ b/internal/handler/authenticate.go @@ -54,7 +54,18 @@ func (h *Handler) authenticate(ctx context.Context) error { return lazyerrors.Error(err) } - username, userPassword := conninfo.Get(ctx).Auth() + username, userPassword, mechanism := conninfo.Get(ctx).Auth() + + switch mechanism { + case "SCRAM-SHA-256", "SCRAM-SHA-1": + // SCRAM calls back scramCredentialLookup each time Step is called, + // and that checks the authentication. + return nil + case "PLAIN": + break + default: + return lazyerrors.Errorf("Unsupported authentication mechanism %q", mechanism) + } // For `PLAIN` mechanism $db field is always `$external` upon saslStart. // For `SCRAM-SHA-1` and `SCRAM-SHA-256` mechanisms $db field contains @@ -121,19 +132,15 @@ func (h *Handler) authenticate(ctx context.Context) error { credentials := must.NotFail(storedUser.Get("credentials")).(*types.Document) - switch { - case credentials.Has("SCRAM-SHA-256"), credentials.Has("SCRAM-SHA-1"): - // SCRAM calls back scramCredentialLookup each time Step is called, - // and that checks the authentication. - return nil - case credentials.Has("PLAIN"): - break - default: - panic("credentials does not contain a known mechanism") + v, _ := credentials.Get("PLAIN") + if v == nil { + return handlererrors.NewCommandErrorMsgWithArgument( + handlererrors.ErrMechanismUnavailable, + "Unable to use PLAIN based authentication for user without any PLAIN credentials registered", + "authenticate", + ) } - v := must.NotFail(credentials.Get("PLAIN")) - doc, ok := v.(*types.Document) if !ok { return lazyerrors.Errorf("field 'PLAIN' has type %T, expected Document", v) diff --git a/internal/handler/msg_logout.go b/internal/handler/msg_logout.go index 1e40e48a23c5..f079de6d8fde 100644 --- a/internal/handler/msg_logout.go +++ b/internal/handler/msg_logout.go @@ -25,7 +25,7 @@ import ( // MsgLogout implements `logout` command. func (h *Handler) MsgLogout(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, error) { - conninfo.Get(ctx).SetAuth("", "") + conninfo.Get(ctx).SetAuth("", "", "") var reply wire.OpMsg must.NoError(reply.SetSections(wire.MakeOpMsgSection( diff --git a/internal/handler/msg_saslstart.go b/internal/handler/msg_saslstart.go index 7e1208f4a790..fb4be24aa72d 100644 --- a/internal/handler/msg_saslstart.go +++ b/internal/handler/msg_saslstart.go @@ -78,7 +78,7 @@ func (h *Handler) saslStart(ctx context.Context, dbName string, document *types. conninfo.Get(ctx).SetBypassBackendAuth() } - conninfo.Get(ctx).SetAuth(username, password) + conninfo.Get(ctx).SetAuth(username, password, mechanism) var emptyPayload types.Binary @@ -95,6 +95,8 @@ func (h *Handler) saslStart(ctx context.Context, dbName string, document *types. ) } + conninfo.Get(ctx).SetAuth("", "", mechanism) + response, err := h.saslStartSCRAM(ctx, dbName, mechanism, document) if err != nil { return nil, err