Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADDED] $SYS server request to 'kick' or 'LDM' a client connection #4298

Merged
merged 1 commit into from Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions server/client.go
Expand Up @@ -213,6 +213,7 @@ const (
DuplicateServerName
MinimumVersionRequired
ClusterNamesIdentical
Kicked
)

// Some flags passed to processMsgResults
Expand Down
56 changes: 56 additions & 0 deletions server/events.go
Expand Up @@ -57,6 +57,8 @@ const (
accConnsEventSubjNew = "$SYS.ACCOUNT.%s.SERVER.CONNS"
accConnsEventSubjOld = "$SYS.SERVER.ACCOUNT.%s.CONNS" // kept for backward compatibility
shutdownEventSubj = "$SYS.SERVER.%s.SHUTDOWN"
clientKickReqSubj = "$SYS.REQ.SERVER.%s.KICK"
clientLDMReqSubj = "$SYS.REQ.SERVER.%s.LDM"
authErrorEventSubj = "$SYS.SERVER.%s.CLIENT.AUTH.ERR"
authErrorAccountEventSubj = "$SYS.ACCOUNT.CLIENT.AUTH.ERR"
serverStatsSubj = "$SYS.SERVER.%s.STATSZ"
Expand Down Expand Up @@ -1228,6 +1230,17 @@ func (s *Server) initEventTracking() {
if _, err := s.sysSubscribe(subject, s.noInlineCallback(s.reloadConfig)); err != nil {
s.Errorf("Error setting up server reload handler: %v", err)
}

// Client connection kick
subject = fmt.Sprintf(clientKickReqSubj, s.info.ID)
if _, err := s.sysSubscribe(subject, s.noInlineCallback(s.kickClient)); err != nil {
s.Errorf("Error setting up client kick service: %v", err)
}
// Client connection LDM
subject = fmt.Sprintf(clientLDMReqSubj, s.info.ID)
if _, err := s.sysSubscribe(subject, s.noInlineCallback(s.ldmClient)); err != nil {
s.Errorf("Error setting up client LDM service: %v", err)
}
}

// UserInfo returns basic information to a user about bound account and user permissions.
Expand Down Expand Up @@ -2714,6 +2727,49 @@ func (s *Server) reloadConfig(sub *subscription, c *client, _ *Account, subject,
})
}

type KickClientReq struct {
CID uint64 `json:"cid"`
}

type LDMClientReq struct {
CID uint64 `json:"cid"`
}

func (s *Server) kickClient(_ *subscription, c *client, _ *Account, subject, reply string, hdr, msg []byte) {
if !s.eventsRunning() {
return
}

var req KickClientReq
if err := json.Unmarshal(msg, &req); err != nil {
s.sys.client.Errorf("Error unmarshalling kick client request: %v", err)
return
}

optz := &EventFilterOptions{}
s.zReq(c, reply, hdr, msg, optz, optz, func() (interface{}, error) {
return nil, s.DisconnectClientByID(req.CID)
})

}

func (s *Server) ldmClient(_ *subscription, c *client, _ *Account, subject, reply string, hdr, msg []byte) {
if !s.eventsRunning() {
return
}

var req LDMClientReq
if err := json.Unmarshal(msg, &req); err != nil {
s.sys.client.Errorf("Error unmarshalling kick client request: %v", err)
return
}

optz := &EventFilterOptions{}
s.zReq(c, reply, hdr, msg, optz, optz, func() (interface{}, error) {
return nil, s.LDMClientByID(req.CID)
})
}

// Helper to grab account name for a client.
func accForClient(c *client) string {
if c.acc != nil {
Expand Down
62 changes: 61 additions & 1 deletion server/events_test.go
Expand Up @@ -1668,7 +1668,7 @@ func TestSystemAccountWithGateways(t *testing.T) {

// If this tests fails with wrong number after 10 seconds we may have
// added a new inititial subscription for the eventing system.
checkExpectedSubs(t, 53, sa)
checkExpectedSubs(t, 55, sa)

// Create a client on B and see if we receive the event
urlb := fmt.Sprintf("nats://%s:%d", ob.Host, ob.Port)
Expand Down Expand Up @@ -3412,6 +3412,66 @@ func TestServerEventsReload(t *testing.T) {
require_True(t, s.getOpts().PingInterval == 200*time.Millisecond)
}

func TestServerEventsLDMKick(t *testing.T) {
ldmed := make(chan bool, 1)
disconnected := make(chan bool, 1)

s, opts := runTrustedServer(t)
defer s.Shutdown()

acc, akp := createAccount(s)
s.setSystemAccount(acc)

url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)
ncs, err := nats.Connect(url, createUserCreds(t, s, akp))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer ncs.Close()

_, akp2 := createAccount(s)

nc, err := nats.Connect(url, createUserCreds(t, s, akp2), nats.Name("TEST EVENTS LDM+KICK"), nats.LameDuckModeHandler(func(_ *nats.Conn) { ldmed <- true }))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer nc.Close()

nc.SetDisconnectErrHandler(func(_ *nats.Conn, err error) { disconnected <- true })

cid, err := nc.GetClientID()
if err != nil {
t.Fatalf("Error on getting the CID: %v", err)
}

reqldm := LDMClientReq{CID: cid}
reqldmpayload, _ := json.Marshal(reqldm)
reqkick := KickClientReq{CID: cid}
reqkickpayload, _ := json.Marshal(reqkick)

_, err = ncs.Request(fmt.Sprintf("$SYS.REQ.SERVER.%s.LDM", s.ID()), reqldmpayload, time.Second)
if err != nil {
t.Fatalf("Error trying to publish the LDM request: %v", err)
}

select {
case <-ldmed:
case <-time.After(time.Second):
t.Fatalf("timeout waiting for the connection to receive the LDM signal")
}

_, err = ncs.Request(fmt.Sprintf("$SYS.REQ.SERVER.%s.KICK", s.ID()), reqkickpayload, time.Second)
if err != nil {
t.Fatalf("Error trying to publish the KICK request: %v", err)
}

select {
case <-disconnected:
case <-time.After(time.Second):
t.Fatalf("timeout waiting for the client to get disconnected")
}
}

func Benchmark_GetHash(b *testing.B) {
b.StopTimer()
// Get 100 random names
Expand Down
2 changes: 2 additions & 0 deletions server/monitor.go
Expand Up @@ -2416,6 +2416,8 @@ func (reason ClosedState) String() string {
return "Minimum Version Required"
case ClusterNamesIdentical:
return "Cluster Names Identical"
case Kicked:
return "Kicked"
}

return "Unknown State"
Expand Down
2 changes: 1 addition & 1 deletion server/monitor_test.go
Expand Up @@ -3946,7 +3946,7 @@ func TestMonitorAccountz(t *testing.T) {
body = string(readBody(t, fmt.Sprintf("http://127.0.0.1:%d%s?acc=$SYS", s.MonitorAddr().Port, AccountzPath)))
require_Contains(t, body, `"account_detail": {`)
require_Contains(t, body, `"account_name": "$SYS",`)
require_Contains(t, body, `"subscriptions": 47,`)
require_Contains(t, body, `"subscriptions": 49,`)
require_Contains(t, body, `"is_system": true,`)
require_Contains(t, body, `"system_account": "$SYS"`)

Expand Down
33 changes: 33 additions & 0 deletions server/server.go
Expand Up @@ -4252,3 +4252,36 @@ func (s *Server) changeRateLimitLogInterval(d time.Duration) {
default:
}
}

// DisconnectClientByID disconnects a client by connection ID
func (s *Server) DisconnectClientByID(id uint64) error {
client := s.clients[id]
if client != nil {
client.closeConnection(Kicked)
return nil
}
return errors.New("no such client id")
}

// LDMClientByID sends a Lame Duck Mode info message to a client by connection ID
func (s *Server) LDMClientByID(id uint64) error {
jnmoyne marked this conversation as resolved.
Show resolved Hide resolved
info := s.copyInfo()
info.LameDuckMode = true

c := s.clients[id]
if c != nil {
c.mu.Lock()
defer c.mu.Unlock()
if c.opts.Protocol >= ClientProtoInfo &&
c.flags.isSet(firstPongSent) {
// sendInfo takes care of checking if the connection is still
// valid or not, so don't duplicate tests here.
c.Debugf("sending Lame Duck Mode info to client")
c.enqueueProto(c.generateClientInfoJSON(info))
return nil
} else {
return errors.New("ClientProtoInfo < ClientOps.Protocol or first pong not sent")
}
}
return errors.New("no such client id")
}