Skip to content

Commit

Permalink
Merge pull request #992 from fianulabs/main
Browse files Browse the repository at this point in the history
Allow Context to Configure Default Timeout
  • Loading branch information
embano1 committed May 1, 2024
2 parents 9c00184 + 7c026b6 commit f97061a
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 4 deletions.
32 changes: 32 additions & 0 deletions v2/protocol/http/options.go
Expand Up @@ -83,6 +83,38 @@ func WithShutdownTimeout(timeout time.Duration) Option {
}
}

// WithReadTimeout overwrites the default read timeout (600s) of the http
// server. The specified timeout must not be negative. A timeout of 0 disables
// read timeouts in the http server.
func WithReadTimeout(timeout time.Duration) Option {
return func(p *Protocol) error {
if p == nil {
return fmt.Errorf("http read timeout option can not set nil protocol")
}
if timeout < 0 {
return fmt.Errorf("http read timeout must not be negative")
}
p.readTimeout = &timeout
return nil
}
}

// WithWriteTimeout overwrites the default write timeout (600s) of the http
// server. The specified timeout must not be negative. A timeout of 0 disables
// write timeouts in the http server.
func WithWriteTimeout(timeout time.Duration) Option {
return func(p *Protocol) error {
if p == nil {
return fmt.Errorf("http write timeout option can not set nil protocol")
}
if timeout < 0 {
return fmt.Errorf("http write timeout must not be negative")
}
p.writeTimeout = &timeout
return nil
}
}

func checkListen(p *Protocol, prefix string) error {
switch {
case p.listener.Load() != nil:
Expand Down
114 changes: 112 additions & 2 deletions v2/protocol/http/options_test.go
Expand Up @@ -315,6 +315,106 @@ func TestWithShutdownTimeout(t *testing.T) {
}
}

func TestWithReadTimeout(t *testing.T) {
expected := time.Minute * 4
testCases := map[string]struct {
t *Protocol
timeout time.Duration
want *Protocol
wantErr string
}{
"valid timeout": {
t: &Protocol{},
timeout: time.Minute * 4,
want: &Protocol{
readTimeout: &expected,
},
},
"negative timeout": {
t: &Protocol{},
timeout: -1,
wantErr: "http read timeout must not be negative",
},
"nil protocol": {
wantErr: "http read timeout option can not set nil protocol",
},
}
for n, tc := range testCases {
t.Run(n, func(t *testing.T) {

err := tc.t.applyOptions(WithReadTimeout(tc.timeout))

if tc.wantErr != "" || err != nil {
var gotErr string
if err != nil {
gotErr = err.Error()
}
if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" {
t.Errorf("unexpected error (-want, +got) = %v", diff)
}
return
}

got := tc.t

if diff := cmp.Diff(tc.want, got,
cmpopts.IgnoreUnexported(Protocol{})); diff != "" {
t.Errorf("unexpected (-want, +got) = %v", diff)
}
})
}
}

func TestWithWriteTimeout(t *testing.T) {
expected := time.Minute * 4

testCases := map[string]struct {
t *Protocol
timeout time.Duration
want *Protocol
wantErr string
}{
"valid timeout": {
t: &Protocol{},
timeout: time.Minute * 4,
want: &Protocol{
writeTimeout: &expected,
},
},
"negative timeout": {
t: &Protocol{},
timeout: -1,
wantErr: "http write timeout must not be negative",
},
"nil protocol": {
wantErr: "http write timeout option can not set nil protocol",
},
}
for n, tc := range testCases {
t.Run(n, func(t *testing.T) {

err := tc.t.applyOptions(WithWriteTimeout(tc.timeout))

if tc.wantErr != "" || err != nil {
var gotErr string
if err != nil {
gotErr = err.Error()
}
if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" {
t.Errorf("unexpected error (-want, +got) = %v", diff)
}
return
}

got := tc.t

if diff := cmp.Diff(tc.want, got,
cmpopts.IgnoreUnexported(Protocol{})); diff != "" {
t.Errorf("unexpected (-want, +got) = %v", diff)
}
})
}
}
func TestWithPort(t *testing.T) {
testCases := map[string]struct {
t *Protocol
Expand Down Expand Up @@ -389,9 +489,19 @@ func forceClose(tr *Protocol) {
}

func TestWithPort0(t *testing.T) {
noReadWriteTimeout := time.Duration(0)

testCases := map[string]func() (*Protocol, error){
"WithPort0": func() (*Protocol, error) { return New(WithPort(0)) },
"SetPort0": func() (*Protocol, error) { return &Protocol{Port: 0}, nil },
"WithPort0": func() (*Protocol, error) {
return New(WithPort(0))
},
"SetPort0": func() (*Protocol, error) {
return &Protocol{
Port: 0,
readTimeout: &noReadWriteTimeout,
writeTimeout: &noReadWriteTimeout,
}, nil
},
}
for name, f := range testCases {
t.Run(name, func(t *testing.T) {
Expand Down
23 changes: 23 additions & 0 deletions v2/protocol/http/protocol.go
Expand Up @@ -70,6 +70,18 @@ type Protocol struct {
// If 0, DefaultShutdownTimeout is used.
ShutdownTimeout time.Duration

// readTimeout defines the http.Server ReadTimeout It is the maximum duration
// for reading the entire request, including the body. If not overwritten by an
// option, the default value (600s) is used
readTimeout *time.Duration

// writeTimeout defines the http.Server WriteTimeout It is the maximum duration
// before timing out writes of the response. It is reset whenever a new
// request's header is read. Like ReadTimeout, it does not let Handlers make
// decisions on a per-request basis. If not overwritten by an option, the
// default value (600s) is used
writeTimeout *time.Duration

// Port is the port configured to bind the receiver to. Defaults to 8080.
// If you want to know the effective port you're listening to, use GetListeningPort()
Port int
Expand Down Expand Up @@ -116,6 +128,17 @@ func New(opts ...Option) (*Protocol, error) {
p.ShutdownTimeout = DefaultShutdownTimeout
}

// use default timeout from abuse protection value
defaultTimeout := DefaultTimeout

if p.readTimeout == nil {
p.readTimeout = &defaultTimeout
}

if p.writeTimeout == nil {
p.writeTimeout = &defaultTimeout
}

if p.isRetriableFunc == nil {
p.isRetriableFunc = defaultIsRetriableFunc
}
Expand Down
4 changes: 2 additions & 2 deletions v2/protocol/http/protocol_lifecycle.go
Expand Up @@ -40,8 +40,8 @@ func (p *Protocol) OpenInbound(ctx context.Context) error {
p.server = &http.Server{
Addr: listener.Addr().String(),
Handler: attachMiddleware(p.Handler, p.middleware),
ReadTimeout: DefaultTimeout,
WriteTimeout: DefaultTimeout,
ReadTimeout: *p.readTimeout,
WriteTimeout: *p.writeTimeout,
}

// Shutdown
Expand Down
3 changes: 3 additions & 0 deletions v2/protocol/http/protocol_test.go
Expand Up @@ -26,6 +26,7 @@ import (

func TestNew(t *testing.T) {
dst := DefaultShutdownTimeout
ot := DefaultTimeout

testCases := map[string]struct {
opts []Option
Expand All @@ -36,6 +37,8 @@ func TestNew(t *testing.T) {
want: &Protocol{
Client: http.DefaultClient,
ShutdownTimeout: dst,
readTimeout: &ot,
writeTimeout: &ot,
Port: -1,
},
},
Expand Down

0 comments on commit f97061a

Please sign in to comment.