Skip to content

Commit

Permalink
use lock less
Browse files Browse the repository at this point in the history
  • Loading branch information
chilagrow committed Mar 11, 2024
1 parent 5fa1d3e commit 4748386
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 30 deletions.
2 changes: 1 addition & 1 deletion internal/backends/mysql/metadata/registry.go
Expand Up @@ -124,7 +124,7 @@ func (r *Registry) getPool(ctx context.Context) (*fsql.DB, error) {
return nil, lazyerrors.New("no connection pool")
}
} else {
username, password := connInfo.Auth()
username, password, _ := connInfo.Auth()

var err error
if p, err = r.p.Get(username, password); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/backends/postgresql/metadata/registry.go
Expand Up @@ -129,7 +129,7 @@ func (r *Registry) getPool(ctx context.Context) (*pgxpool.Pool, error) {
}
}
} else {
username, password := connInfo.Auth()
username, password, _ := connInfo.Auth()

var err error
if p, err = r.p.Get(username, password); err != nil {
Expand Down
25 changes: 5 additions & 20 deletions internal/clientconn/conninfo/conn_info.go
Expand Up @@ -63,36 +63,21 @@ func (connInfo *ConnInfo) Username() string {
return connInfo.username
}

// Auth returns stored username and password.
func (connInfo *ConnInfo) Auth() (username, password string) {
// Auth returns stored username, password and mechanism.
func (connInfo *ConnInfo) Auth() (username, password, mechanism string) {
connInfo.rw.RLock()
defer connInfo.rw.RUnlock()

return connInfo.username, connInfo.password
return connInfo.username, connInfo.password, connInfo.mechanism
}

// SetAuth stores username and password.
func (connInfo *ConnInfo) SetAuth(username, password string) {
// SetAuth stores username, password.
func (connInfo *ConnInfo) SetAuth(username, password, mechanism string) {
connInfo.rw.Lock()
defer connInfo.rw.Unlock()

connInfo.username = username
connInfo.password = password
}

// Mechanism returns stored mechanism.
func (connInfo *ConnInfo) Mechanism() (mechanism string) {
connInfo.rw.RLock()
defer connInfo.rw.RUnlock()

return connInfo.mechanism
}

// SetMechanism stores mechanism.
func (connInfo *ConnInfo) SetMechanism(mechanism string) {
connInfo.rw.Lock()
defer connInfo.rw.Unlock()

connInfo.mechanism = mechanism
}

Expand Down
4 changes: 1 addition & 3 deletions internal/handler/authenticate.go
Expand Up @@ -54,7 +54,7 @@ func (h *Handler) authenticate(ctx context.Context) error {
return lazyerrors.Error(err)
}

mechanism := conninfo.Get(ctx).Mechanism()
username, userPassword, mechanism := conninfo.Get(ctx).Auth()

switch mechanism {
case "SCRAM-SHA-256", "SCRAM-SHA-1":
Expand All @@ -67,8 +67,6 @@ func (h *Handler) authenticate(ctx context.Context) error {
return lazyerrors.Errorf("Unsupported authentication mechanism %q", mechanism)

Check warning on line 67 in internal/handler/authenticate.go

View check run for this annotation

Codecov / codecov/patch

internal/handler/authenticate.go#L66-L67

Added lines #L66 - L67 were not covered by tests
}

username, userPassword := conninfo.Get(ctx).Auth()

// For `PLAIN` mechanism $db field is always `$external` upon saslStart.
// For `SCRAM-SHA-1` and `SCRAM-SHA-256` mechanisms $db field contains
// authSource option of the client.
Expand Down
3 changes: 1 addition & 2 deletions internal/handler/msg_logout.go
Expand Up @@ -25,8 +25,7 @@ import (

// MsgLogout implements `logout` command.
func (h *Handler) MsgLogout(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, error) {
conninfo.Get(ctx).SetAuth("", "")
conninfo.Get(ctx).SetMechanism("")
conninfo.Get(ctx).SetAuth("", "", "")

var reply wire.OpMsg
must.NoError(reply.SetSections(wire.MakeOpMsgSection(
Expand Down
6 changes: 3 additions & 3 deletions internal/handler/msg_saslstart.go
Expand Up @@ -67,8 +67,6 @@ func (h *Handler) saslStart(ctx context.Context, dbName string, document *types.
return nil, lazyerrors.Error(err)
}

conninfo.Get(ctx).SetMechanism(mechanism)

switch mechanism {
case "PLAIN":
username, password, err := saslStartPlain(document)
Expand All @@ -80,7 +78,7 @@ func (h *Handler) saslStart(ctx context.Context, dbName string, document *types.
conninfo.Get(ctx).SetBypassBackendAuth()
}

conninfo.Get(ctx).SetAuth(username, password)
conninfo.Get(ctx).SetAuth(username, password, mechanism)

var emptyPayload types.Binary

Expand All @@ -97,6 +95,8 @@ func (h *Handler) saslStart(ctx context.Context, dbName string, document *types.
)
}

conninfo.Get(ctx).SetAuth("", "", mechanism)

response, err := h.saslStartSCRAM(ctx, dbName, mechanism, document)
if err != nil {
return nil, err
Expand Down

0 comments on commit 4748386

Please sign in to comment.