Skip to content

Commit

Permalink
Adds LDM and KICK server $SYS requests
Browse files Browse the repository at this point in the history
Signed-off-by: Jean-Noël Moyne <jnmoyne@gmail.com>
  • Loading branch information
jnmoyne committed Aug 11, 2023
1 parent 37d3220 commit fc41ab1
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 2 deletions.
1 change: 1 addition & 0 deletions server/client.go
Expand Up @@ -213,6 +213,7 @@ const (
DuplicateServerName
MinimumVersionRequired
ClusterNamesIdentical
Kicked
)

// Some flags passed to processMsgResults
Expand Down
56 changes: 56 additions & 0 deletions server/events.go
Expand Up @@ -57,6 +57,8 @@ const (
accConnsEventSubjNew = "$SYS.ACCOUNT.%s.SERVER.CONNS"
accConnsEventSubjOld = "$SYS.SERVER.ACCOUNT.%s.CONNS" // kept for backward compatibility
shutdownEventSubj = "$SYS.SERVER.%s.SHUTDOWN"
clientKickReqSubj = "$SYS.REQ.SERVER.%s.KICK"
clientLDMReqSubj = "$SYS.REQ.SERVER.%s.LDM"
authErrorEventSubj = "$SYS.SERVER.%s.CLIENT.AUTH.ERR"
authErrorAccountEventSubj = "$SYS.ACCOUNT.CLIENT.AUTH.ERR"
serverStatsSubj = "$SYS.SERVER.%s.STATSZ"
Expand Down Expand Up @@ -1228,6 +1230,17 @@ func (s *Server) initEventTracking() {
if _, err := s.sysSubscribe(subject, s.noInlineCallback(s.reloadConfig)); err != nil {
s.Errorf("Error setting up server reload handler: %v", err)
}

// Client connection kick
subject = fmt.Sprintf(clientKickReqSubj, s.info.ID)
if _, err := s.sysSubscribe(subject, s.noInlineCallback(s.kickClient)); err != nil {
s.Errorf("Error setting up client kick service: %v", err)
}
// Client connection LDM
subject = fmt.Sprintf(clientLDMReqSubj, s.info.ID)
if _, err := s.sysSubscribe(subject, s.noInlineCallback(s.ldmClient)); err != nil {
s.Errorf("Error setting up client LDM service: %v", err)
}
}

// UserInfo returns basic information to a user about bound account and user permissions.
Expand Down Expand Up @@ -2714,6 +2727,49 @@ func (s *Server) reloadConfig(sub *subscription, c *client, _ *Account, subject,
})
}

type KickClientReq struct {
CID uint64 `json:"cid"`
}

type LDMClientReq struct {
CID uint64 `json:"cid"`
}

func (s *Server) kickClient(_ *subscription, c *client, _ *Account, subject, reply string, hdr, msg []byte) {
if !s.eventsRunning() {
return
}

var req KickClientReq
if err := json.Unmarshal(msg, &req); err != nil {
s.sys.client.Errorf("Error unmarshalling kick client request: %v", err)
return
}

optz := &EventFilterOptions{}
s.zReq(c, reply, hdr, msg, optz, optz, func() (interface{}, error) {
return nil, s.DisconnectClientByID(req.CID)
})

}

func (s *Server) ldmClient(_ *subscription, c *client, _ *Account, subject, reply string, hdr, msg []byte) {
if !s.eventsRunning() {
return
}

var req LDMClientReq
if err := json.Unmarshal(msg, &req); err != nil {
s.sys.client.Errorf("Error unmarshalling kick client request: %v", err)
return
}

optz := &EventFilterOptions{}
s.zReq(c, reply, hdr, msg, optz, optz, func() (interface{}, error) {
return nil, s.LDMClientByID(req.CID)
})
}

// Helper to grab account name for a client.
func accForClient(c *client) string {
if c.acc != nil {
Expand Down
62 changes: 61 additions & 1 deletion server/events_test.go
Expand Up @@ -1668,7 +1668,7 @@ func TestSystemAccountWithGateways(t *testing.T) {

// If this tests fails with wrong number after 10 seconds we may have
// added a new inititial subscription for the eventing system.
checkExpectedSubs(t, 53, sa)
checkExpectedSubs(t, 55, sa)

// Create a client on B and see if we receive the event
urlb := fmt.Sprintf("nats://%s:%d", ob.Host, ob.Port)
Expand Down Expand Up @@ -3412,6 +3412,66 @@ func TestServerEventsReload(t *testing.T) {
require_True(t, s.getOpts().PingInterval == 200*time.Millisecond)
}

func TestServerEventsLDMKick(t *testing.T) {
ldmed := make(chan bool, 1)
disconnected := make(chan bool, 1)

s, opts := runTrustedServer(t)
defer s.Shutdown()

acc, akp := createAccount(s)
s.setSystemAccount(acc)

url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)
ncs, err := nats.Connect(url, createUserCreds(t, s, akp))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer ncs.Close()

_, akp2 := createAccount(s)

nc, err := nats.Connect(url, createUserCreds(t, s, akp2), nats.Name("TEST EVENTS LDM+KICK"), nats.LameDuckModeHandler(func(_ *nats.Conn) { ldmed <- true }))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer nc.Close()

nc.SetDisconnectErrHandler(func(_ *nats.Conn, err error) { disconnected <- true })

cid, err := nc.GetClientID()
if err != nil {
t.Fatalf("Error on getting the CID: %v", err)
}

reqldm := LDMClientReq{CID: cid}
reqldmpayload, _ := json.Marshal(reqldm)
reqkick := KickClientReq{CID: cid}
reqkickpayload, _ := json.Marshal(reqkick)

_, err = ncs.Request(fmt.Sprintf("$SYS.REQ.SERVER.%s.LDM", s.ID()), reqldmpayload, time.Second)
if err != nil {
t.Fatalf("Error trying to publish the LDM request: %v", err)
}

select {
case <-ldmed:
case <-time.After(time.Second):
t.Fatalf("timeout waiting for the connection to receive the LDM signal")
}

_, err = ncs.Request(fmt.Sprintf("$SYS.REQ.SERVER.%s.KICK", s.ID()), reqkickpayload, time.Second)
if err != nil {
t.Fatalf("Error trying to publish the KICK request: %v", err)
}

select {
case <-disconnected:
case <-time.After(time.Second):
t.Fatalf("timeout waiting for the client to get disconnected")
}
}

func Benchmark_GetHash(b *testing.B) {
b.StopTimer()
// Get 100 random names
Expand Down
2 changes: 2 additions & 0 deletions server/monitor.go
Expand Up @@ -2416,6 +2416,8 @@ func (reason ClosedState) String() string {
return "Minimum Version Required"
case ClusterNamesIdentical:
return "Cluster Names Identical"
case Kicked:
return "Kicked"
}

return "Unknown State"
Expand Down
2 changes: 1 addition & 1 deletion server/monitor_test.go
Expand Up @@ -3946,7 +3946,7 @@ func TestMonitorAccountz(t *testing.T) {
body = string(readBody(t, fmt.Sprintf("http://127.0.0.1:%d%s?acc=$SYS", s.MonitorAddr().Port, AccountzPath)))
require_Contains(t, body, `"account_detail": {`)
require_Contains(t, body, `"account_name": "$SYS",`)
require_Contains(t, body, `"subscriptions": 47,`)
require_Contains(t, body, `"subscriptions": 49,`)
require_Contains(t, body, `"is_system": true,`)
require_Contains(t, body, `"system_account": "$SYS"`)

Expand Down
33 changes: 33 additions & 0 deletions server/server.go
Expand Up @@ -4252,3 +4252,36 @@ func (s *Server) changeRateLimitLogInterval(d time.Duration) {
default:
}
}

// DisconnectClientByID disconnects a client by connection ID
func (s *Server) DisconnectClientByID(id uint64) error {
client := s.clients[id]
if client != nil {
client.closeConnection(Kicked)
return nil
}
return errors.New("no such client id")
}

// LDMClientByID sends a Lame Duck Mode info message to a client by connection ID
func (s *Server) LDMClientByID(id uint64) error {
info := s.copyInfo()
info.LameDuckMode = true

c := s.clients[id]
if c != nil {
c.mu.Lock()
defer c.mu.Unlock()
if c.opts.Protocol >= ClientProtoInfo &&
c.flags.isSet(firstPongSent) {
// sendInfo takes care of checking if the connection is still
// valid or not, so don't duplicate tests here.
c.Debugf("sending Lame Duck Mode info to client")
c.enqueueProto(c.generateClientInfoJSON(info))
return nil
} else {
return errors.New("ClientProtoInfo < ClientOps.Protocol or first pong not sent")
}
}
return errors.New("no such client id")
}

0 comments on commit fc41ab1

Please sign in to comment.