Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation for server enforcement of keepalive policy. #1147

Merged
merged 5 commits into from
Mar 31, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions keepalive/keepalive.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,12 @@ type ServerParameters struct {
// the connection is closed.
Timeout time.Duration
}

// EnforcementPolicy is used to set keepalive enforcement policy on the server-side.
// Server will close connection with a client that violates this policy.
type EnforcementPolicy struct {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have some notes in the ClientParameters to the effect of "make sure you set these parameters in coordination with the service owners, as incompatible settings can result in RPC failures"?

Especially since there's a default enforcement policy of 5m and no pings without streams, even if the service doesn't declare a policy.

// MinTime is the minimum amount of time a client should wait before sending a keepalive ping.
MinTime time.Duration
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind documenting the default values for all of the settings in this file?

// If true, server expects keepalive pings even when there are no active streams(RPCs).
PermitWithoutStream bool
}
9 changes: 9 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ type options struct {
useHandlerImpl bool // use http.Handler-based server
unknownStreamDesc *StreamDesc
keepaliveParams keepalive.ServerParameters
keepalivePolicy keepalive.EnforcementPolicy
}

var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
Expand All @@ -133,6 +134,13 @@ func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
}
}

// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server.
func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
return func(o *options) {
o.keepalivePolicy = kep
}
}

// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
func CustomCodec(codec Codec) ServerOption {
return func(o *options) {
Expand Down Expand Up @@ -479,6 +487,7 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo)
InTapHandle: s.opts.inTapHandle,
StatsHandler: s.opts.statsHandler,
KeepaliveParams: s.opts.keepaliveParams,
KeepalivePolicy: s.opts.keepalivePolicy,
}
st, err := transport.NewServerTransport("http2", c, config)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions transport/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const (
defaultMaxConnectionAgeGrace = infinity
defaultServerKeepaliveTime = time.Duration(2 * time.Hour)
defaultServerKeepaliveTimeout = time.Duration(20 * time.Second)
defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute)
)

// The following defines various control items which could flow through
Expand Down Expand Up @@ -84,6 +85,8 @@ type resetStream struct {
func (*resetStream) item() {}

type goAway struct {
code http2.ErrCode
debugData []byte
}

func (*goAway) item() {}
Expand Down
75 changes: 72 additions & 3 deletions transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ type http2Server struct {
// Keepalive and max-age parameters for the server.
kp keepalive.ServerParameters

// Keepalive enforcement policy.
kep keepalive.EnforcementPolicy
// The time instance last ping was received.
lastPingAt time.Time
// Number of times the client has violated keepalive ping policy so far.
pingStrikes uint8
// Flag to signify that number of ping strikes should be reset to 0.
// This is set whenever data or header frames are sent.
// 1 means yes.
resetPingStrikes uint32 // Accessed atomically.

mu sync.Mutex // guard the following
state transportState
activeStreams map[uint32]*Stream
Expand Down Expand Up @@ -161,6 +172,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
if kp.Timeout == 0 {
kp.Timeout = defaultServerKeepaliveTimeout
}
kep := config.KeepalivePolicy
if kep.MinTime == 0 {
kep.MinTime = defaultKeepalivePolicyMinTime
}
var buf bytes.Buffer
t := &http2Server{
ctx: context.Background(),
Expand All @@ -184,6 +199,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
stats: config.StatsHandler,
kp: kp,
idle: time.Now(),
kep: kep,
}
if t.stats != nil {
t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{
Expand Down Expand Up @@ -504,13 +520,49 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
t.controlBuf.put(&settings{ack: true, ss: ss})
}

const (
maxPingStrikes = 2
)

func (t *http2Server) handlePing(f *http2.PingFrame) {
if f.IsAck() { // Do nothing.
return
}
pingAck := &ping{ack: true}
copy(pingAck.data[:], f.Data[:])
t.controlBuf.put(pingAck)

now := time.Now()
defer func() {
t.lastPingAt = now
}()
// A reset ping strikes means that we don't need to check for policy
// violation for this ping and the pingStrikes counter should be set
// to 0.
if atomic.CompareAndSwapUint32(&t.resetPingStrikes, 1, 0) {
t.pingStrikes = 0
return
}
t.mu.Lock()
ns := len(t.activeStreams)
t.mu.Unlock()
if ns < 1 && !t.kep.PermitWithoutStream {
// Keepalive shouldn't be active thus, this new ping should
// have come after atleast two hours.
if t.lastPingAt.Add(2 * time.Hour).After(now) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 2*time.Hour should probably be a constant (next to maxPingStrikes?).

t.pingStrikes++
}
} else {
// Check if keepalive policy is respected.
if t.lastPingAt.Add(t.kep.MinTime).After(now) {
t.pingStrikes++
}
}

if t.pingStrikes > maxPingStrikes {
// Send goaway and close the connection.
t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings")})
}
}

func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) {
Expand All @@ -529,6 +581,13 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
first := true
endHeaders := false
var err error
defer func() {
if err == nil {
// Reset ping strikes when seding headers since that might cause the
// peer to send ping.
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
}()
// Sends the headers in a single batch.
for !endHeaders {
size := t.hBuf.Len()
Expand Down Expand Up @@ -672,7 +731,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s

// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) {
// TODO(zhaoq): Support multi-writers for a single stream.
var writeHeaderFrame bool
s.mu.Lock()
Expand All @@ -687,6 +746,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
if writeHeaderFrame {
t.WriteHeader(s, nil)
}
defer func() {
if err == nil {
// Reset ping strikes when sending data since this might cause
// the peer to send ping.
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
}()
r := bytes.NewBuffer(data)
for {
if r.Len() == 0 {
Expand Down Expand Up @@ -892,7 +958,10 @@ func (t *http2Server) controller() {
sid := t.maxStreamID
t.state = draining
t.mu.Unlock()
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
t.framer.writeGoAway(true, sid, i.code, i.debugData)
if i.code == http2.ErrCodeEnhanceYourCalm {
t.Close()
}
case *flushIO:
t.framer.flushWrite()
case *ping:
Expand Down Expand Up @@ -972,7 +1041,7 @@ func (t *http2Server) RemoteAddr() net.Addr {
}

func (t *http2Server) Drain() {
t.controlBuf.put(&goAway{})
t.controlBuf.put(&goAway{code: http2.ErrCodeNo})
}

var rgen = rand.New(rand.NewSource(time.Now().UnixNano()))
Expand Down
1 change: 1 addition & 0 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ type ServerConfig struct {
InTapHandle tap.ServerInHandle
StatsHandler stats.Handler
KeepaliveParams keepalive.ServerParameters
KeepalivePolicy keepalive.EnforcementPolicy
}

// NewServerTransport creates a ServerTransport with conn or non-nil error
Expand Down
129 changes: 129 additions & 0 deletions transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ func TestMaxConnectionIdle(t *testing.T) {
timeout := time.NewTimer(time.Second * 4)
select {
case <-client.GoAway():
if !timeout.Stop() {
<-timeout.C
}
case <-timeout.C:
t.Fatalf("Test timed out, expected a GoAway from the server.")
}
Expand All @@ -345,6 +348,9 @@ func TestMaxConnectionIdleNegative(t *testing.T) {
timeout := time.NewTimer(time.Second * 4)
select {
case <-client.GoAway():
if !timeout.Stop() {
<-timeout.C
}
t.Fatalf("A non-idle client received a GoAway.")
case <-timeout.C:
}
Expand All @@ -369,6 +375,9 @@ func TestMaxConnectionAge(t *testing.T) {
timeout := time.NewTimer(4 * time.Second)
select {
case <-client.GoAway():
if !timeout.Stop() {
<-timeout.C
}
case <-timeout.C:
t.Fatalf("Test timer out, expected a GoAway from the server.")
}
Expand Down Expand Up @@ -523,6 +532,126 @@ func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) {
}
}

func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) {
serverConfig := &ServerConfig{
KeepalivePolicy: keepalive.EnforcementPolicy{
MinTime: 2 * time.Second,
},
}
clientOptions := ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: 50 * time.Millisecond,
Timeout: 50 * time.Millisecond,
PermitWithoutStream: true,
},
}
server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions)
defer server.stop()
defer client.Close()

timeout := time.NewTimer(2 * time.Second)
select {
case <-client.GoAway():
if !timeout.Stop() {
<-timeout.C
}
case <-timeout.C:
t.Fatalf("Test failed: Expected a GoAway from server.")
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check for connection closed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) {
serverConfig := &ServerConfig{
KeepalivePolicy: keepalive.EnforcementPolicy{
MinTime: 2 * time.Second,
},
}
clientOptions := ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: 50 * time.Millisecond,
Timeout: 50 * time.Millisecond,
},
}
server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions)
defer server.stop()
defer client.Close()

if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil {
t.Fatalf("Client failed to create stream.")
}
timeout := time.NewTimer(2 * time.Second)
select {
case <-client.GoAway():
if !timeout.Stop() {
<-timeout.C
}
case <-timeout.C:
t.Fatalf("Test failed: Expected a GoAway from server.")
}

}

func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) {
serverConfig := &ServerConfig{
KeepalivePolicy: keepalive.EnforcementPolicy{
MinTime: 100 * time.Millisecond,
PermitWithoutStream: true,
},
}
clientOptions := ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: 100 * time.Millisecond,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be 101ms (or more) to avoid a race between the timers?

Timeout: 50 * time.Millisecond,
PermitWithoutStream: true,
},
}
server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions)
defer server.stop()
defer client.Close()

// Give keepalive enough time.
time.Sleep(2 * time.Second)
// Asser that connection is healthy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AsserT

ct := client.(*http2Client)
ct.mu.Lock()
defer ct.mu.Unlock()
if ct.state != reachable {
t.Fatalf("Test failed: Expected connection to be healthy.")
}
}

func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) {
serverConfig := &ServerConfig{
KeepalivePolicy: keepalive.EnforcementPolicy{
MinTime: 100 * time.Millisecond,
},
}
clientOptions := ConnectOptions{
KeepaliveParams: keepalive.ClientParameters{
Time: 100 * time.Millisecond,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same concern.

Timeout: 50 * time.Millisecond,
},
}
server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions)
defer server.stop()
defer client.Close()

if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil {
t.Fatalf("Client failed to create stream.")
}

// Give keepalive enough time.
time.Sleep(2 * time.Second)
// Asser that connection is healthy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AsserT

ct := client.(*http2Client)
ct.mu.Lock()
defer ct.mu.Unlock()
if ct.state != reachable {
t.Fatalf("Test failed: Expected connection to be healthy.")
}

}

func TestClientSendAndReceive(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{
Expand Down