Skip to content

Commit

Permalink
Merge pull request #3799 from nats-io/auth-callout-events
Browse files Browse the repository at this point in the history
Auth-Callout connect events
  • Loading branch information
derekcollison committed Jan 21, 2023
2 parents 3c4e47d + acad660 commit e582f01
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 15 deletions.
10 changes: 3 additions & 7 deletions server/accounts.go
@@ -1,4 +1,4 @@
// Copyright 2018-2022 The NATS Authors
// Copyright 2018-2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -958,13 +958,9 @@ func (a *Account) removeClient(c *client) int {
a.mu.Unlock()

if c != nil && c.srv != nil && removed {
c.srv.mu.Lock()
doRemove := a != c.srv.gacc
c.srv.mu.Unlock()
if doRemove {
c.srv.accConnsUpdate(a)
}
c.srv.accConnsUpdate(a)
}

return n
}

Expand Down
135 changes: 135 additions & 0 deletions server/auth_callout_test.go
Expand Up @@ -998,3 +998,138 @@ func TestAuthCalloutAuthErrEvents(t *testing.T) {
checkAuthErrEvent("dlc", "xxx", "WRONG PASSWORD")
checkAuthErrEvent("rip", "abc", "BAD CREDS")
}

func TestAuthCalloutConnectEvents(t *testing.T) {
conf := createConfFile(t, []byte(`
listen: "127.0.0.1:-1"
server_name: A
accounts {
AUTH { users [ {user: "auth", password: "pwd"} ] }
FOO {}
BAR {}
$SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] }
}
authorization {
auth_callout {
issuer: "ABJHLOVMPA4CI6R5KLNGOB4GSLNIY7IOUPAJC4YFNDLQVIOBYQGUWVLA"
account: AUTH
auth_users: [ auth, admin ]
}
}
`))
defer removeFile(t, conf)
s, _ := RunServerWithConfig(conf)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL(), nats.UserInfo("auth", "pwd"))
require_NoError(t, err)
defer nc.Close()

_, err = nc.Subscribe(AuthCalloutSubject, func(m *nats.Msg) {
user, si, _, opts, _ := decodeAuthRequest(t, m.Data)
// Allow dlc user and map to the BAZ account.
if opts.Username == "dlc" && opts.Password == "zzz" {
ujwt := createAuthUser(t, si.ID, user, _EMPTY_, "FOO", nil, 0, nil)
m.Respond([]byte(ujwt))
} else if opts.Username == "rip" && opts.Password == "xxx" {
ujwt := createAuthUser(t, si.ID, user, _EMPTY_, "BAR", nil, 0, nil)
m.Respond([]byte(ujwt))
} else {
errResp := createErrResponse(t, user, si.ID, "BAD CREDS", nil)
m.Respond([]byte(errResp))
}
})
require_NoError(t, err)

// Setup system user.
snc, err := nats.Connect(s.ClientURL(), nats.UserInfo("admin", "s3cr3t!"))
require_NoError(t, err)
defer nc.Close()

// Allow this connect event to pass us by..
time.Sleep(250 * time.Millisecond)

// Watch for connect events.
csub, err := snc.SubscribeSync(fmt.Sprintf(connectEventSubj, "*"))
require_NoError(t, err)

// Watch for disconnect events.
dsub, err := snc.SubscribeSync(fmt.Sprintf(disconnectEventSubj, "*"))
require_NoError(t, err)

// Connections updates. Old
acOldSub, err := snc.SubscribeSync(fmt.Sprintf(accConnsEventSubjOld, "*"))
require_NoError(t, err)

// Connections updates. New
acNewSub, err := snc.SubscribeSync(fmt.Sprintf(accConnsEventSubjNew, "*"))
require_NoError(t, err)

snc.Flush()

checkConnectEvents := func(user, pass, acc string) {
nc, err := nats.Connect(s.ClientURL(), nats.UserInfo(user, pass))
require_NoError(t, err)

m, err := csub.NextMsg(time.Second)
require_NoError(t, err)

var cm ConnectEventMsg
err = json.Unmarshal(m.Data, &cm)
require_NoError(t, err)
require_True(t, cm.Client.User == user)
require_True(t, cm.Client.Account == acc)

// Check that we have updates, 1 each, for the connections updates.
m, err = acOldSub.NextMsg(time.Second)
require_NoError(t, err)

var anc AccountNumConns
err = json.Unmarshal(m.Data, &anc)
require_NoError(t, err)
require_True(t, anc.AccountStat.Account == acc)
require_True(t, anc.AccountStat.Conns == 1)

m, err = acNewSub.NextMsg(time.Second)
require_NoError(t, err)

err = json.Unmarshal(m.Data, &anc)
require_NoError(t, err)
require_True(t, anc.AccountStat.Account == acc)
require_True(t, anc.AccountStat.Conns == 1)

// Force the disconnect.
nc.Close()

m, err = dsub.NextMsg(time.Second)
require_NoError(t, err)

var dm DisconnectEventMsg
err = json.Unmarshal(m.Data, &dm)
require_NoError(t, err)

m, err = acOldSub.NextMsg(time.Second)
require_NoError(t, err)
err = json.Unmarshal(m.Data, &anc)
require_NoError(t, err)
require_True(t, anc.AccountStat.Account == acc)
require_True(t, anc.AccountStat.Conns == 0)

m, err = acNewSub.NextMsg(time.Second)
require_NoError(t, err)
err = json.Unmarshal(m.Data, &anc)
require_NoError(t, err)
require_True(t, anc.AccountStat.Account == acc)
require_True(t, anc.AccountStat.Conns == 0)

// Make sure no double events sent.
time.Sleep(200 * time.Millisecond)
checkSubsPending(t, csub, 0)
checkSubsPending(t, dsub, 0)
checkSubsPending(t, acOldSub, 0)
checkSubsPending(t, acNewSub, 0)
}

checkConnectEvents("dlc", "zzz", "FOO")
checkConnectEvents("rip", "xxx", "BAR")
}
3 changes: 2 additions & 1 deletion server/events.go
Expand Up @@ -1919,10 +1919,11 @@ func (a *Account) statz() *AccountStat {

// accConnsUpdate is called whenever there is a change to the account's
// number of active connections, or during a heartbeat.
// We will not send for $G.
func (s *Server) accConnsUpdate(a *Account) {
s.mu.Lock()
defer s.mu.Unlock()
if !s.eventsEnabled() || a == nil {
if !s.eventsEnabled() || a == nil || a == s.gacc {
return
}
s.sendAccConnsUpdate(a, fmt.Sprintf(accConnsEventSubjOld, a.Name), fmt.Sprintf(accConnsEventSubjNew, a.Name))
Expand Down
8 changes: 1 addition & 7 deletions server/events_test.go
Expand Up @@ -1675,12 +1675,10 @@ func TestSystemAccountWithGateways(t *testing.T) {
require_NoError(t, err)
msgs[1], err = sub.NextMsg(time.Second)
require_NoError(t, err)
msgs[2], err = sub.NextMsg(time.Second)
require_NoError(t, err)
// TODO: There is a race currently that can cause the server to process the
// system event *after* the subscription on "A" has been registered, and so
// the "nca" client would receive its own CONNECT message.
msgs[3], _ = sub.NextMsg(250 * time.Millisecond)
msgs[2], _ = sub.NextMsg(250 * time.Millisecond)

findMsgs := func(sub string) []*nats.Msg {
rMsgs := []*nats.Msg{}
Expand Down Expand Up @@ -1710,10 +1708,6 @@ func TestSystemAccountWithGateways(t *testing.T) {
if len(connsMsgA) != 1 {
t.Fatal("Expected a message")
}
connsMsgG := findMsgs("$SYS.ACCOUNT.$G.SERVER.CONNS")
if len(connsMsgG) != 1 {
t.Fatal("Expected a message")
}
}

func TestSystemAccountNoAuthUser(t *testing.T) {
Expand Down

0 comments on commit e582f01

Please sign in to comment.