Skip to content

Commit 54fdc7a

Browse files
timandyerikdubbelboer
andauthoredAug 10, 2023
Abstracts the RoundTripper interface and provides a default implement (#1602)
* Abstracts the RoundTripper interface and provides a default implementation for enhanced extensibility (#1601) * test: Add custom transport test case (#1601) * Make default RoundTripper implmention none public Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com> --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
1 parent e181af1 commit 54fdc7a

File tree

2 files changed

+227
-129
lines changed

2 files changed

+227
-129
lines changed
 

‎client.go

+129-113
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,10 @@ type DialFunc func(addr string) (net.Conn, error)
628628
// Request argument passed to RetryIfFunc, if there are any request errors.
629629
type RetryIfFunc func(request *Request) bool
630630

631-
// TransportFunc wraps every request/response.
632-
type TransportFunc func(*Request, *Response) error
631+
// RoundTripper wraps every request/response.
632+
type RoundTripper interface {
633+
RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
634+
}
633635

634636
// ConnPoolStrategyType define strategy of connection pool enqueue/dequeue
635637
type ConnPoolStrategyType int
@@ -791,7 +793,7 @@ type HostClient struct {
791793
RetryIf RetryIfFunc
792794

793795
// Transport defines a transport-like mechanism that wraps every request/response.
794-
Transport TransportFunc
796+
Transport RoundTripper
795797

796798
// Connection pool strategy. Can be either LIFO or FIFO (default).
797799
ConnPoolStrategy ConnPoolStrategyType
@@ -1343,119 +1345,15 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
13431345
req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
13441346
}
13451347
}
1346-
if c.Transport != nil {
1347-
err := c.Transport(req, resp)
1348-
return err == nil, err
1349-
}
1350-
1351-
var deadline time.Time
1352-
if req.timeout > 0 {
1353-
deadline = time.Now().Add(req.timeout)
1354-
}
1355-
1356-
cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
1357-
if err != nil {
1358-
return false, err
1359-
}
1360-
conn := cc.c
1361-
1362-
resp.parseNetConn(conn)
1363-
1364-
writeDeadline := deadline
1365-
if c.WriteTimeout > 0 {
1366-
tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
1367-
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
1368-
writeDeadline = tmpWriteDeadline
1369-
}
1370-
}
13711348

1372-
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
1373-
c.closeConn(cc)
1374-
return true, err
1375-
}
1376-
1377-
resetConnection := false
1378-
if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
1379-
req.SetConnectionClose()
1380-
resetConnection = true
1381-
}
1382-
1383-
bw := c.acquireWriter(conn)
1384-
err = req.Write(bw)
1385-
1386-
if resetConnection {
1387-
req.Header.ResetConnectionClose()
1388-
}
1389-
1390-
if err == nil {
1391-
err = bw.Flush()
1392-
}
1393-
c.releaseWriter(bw)
1394-
1395-
// Return ErrTimeout on any timeout.
1396-
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
1397-
err = ErrTimeout
1398-
}
1399-
1400-
isConnRST := isConnectionReset(err)
1401-
if err != nil && !isConnRST {
1402-
c.closeConn(cc)
1403-
return true, err
1404-
}
1405-
1406-
readDeadline := deadline
1407-
if c.ReadTimeout > 0 {
1408-
tmpReadDeadline := time.Now().Add(c.ReadTimeout)
1409-
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
1410-
readDeadline = tmpReadDeadline
1411-
}
1412-
}
1413-
1414-
if err = conn.SetReadDeadline(readDeadline); err != nil {
1415-
c.closeConn(cc)
1416-
return true, err
1417-
}
1418-
1419-
if customSkipBody || req.Header.IsHead() {
1420-
resp.SkipBody = true
1421-
}
1422-
if c.DisableHeaderNamesNormalizing {
1423-
resp.Header.DisableNormalizing()
1424-
}
1425-
1426-
br := c.acquireReader(conn)
1427-
err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
1428-
c.releaseReader(br)
1429-
if err != nil {
1430-
c.closeConn(cc)
1431-
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
1432-
retry := err != ErrBodyTooLarge
1433-
return retry, err
1434-
}
1435-
1436-
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
1437-
if customStreamBody && resp.bodyStream != nil {
1438-
rbs := resp.bodyStream
1439-
resp.bodyStream = newCloseReader(rbs, func() error {
1440-
if r, ok := rbs.(*requestStream); ok {
1441-
releaseRequestStream(r)
1442-
}
1443-
if closeConn {
1444-
c.closeConn(cc)
1445-
} else {
1446-
c.releaseConn(cc)
1447-
}
1448-
return nil
1449-
})
1450-
return false, nil
1451-
}
1349+
return c.transport().RoundTrip(c, req, resp)
1350+
}
14521351

1453-
if closeConn {
1454-
c.closeConn(cc)
1455-
} else {
1456-
c.releaseConn(cc)
1352+
func (c *HostClient) transport() RoundTripper {
1353+
if c.Transport == nil {
1354+
return DefaultTransport
14571355
}
1458-
return false, nil
1356+
return c.Transport
14591357
}
14601358

14611359
var (
@@ -2909,3 +2807,121 @@ func (c *pipelineConnClient) PendingRequests() int {
29092807
}
29102808

29112809
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
2810+
2811+
var DefaultTransport RoundTripper = &transport{}
2812+
2813+
type transport struct{}
2814+
2815+
func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
2816+
customSkipBody := resp.SkipBody
2817+
customStreamBody := resp.StreamBody
2818+
2819+
var deadline time.Time
2820+
if req.timeout > 0 {
2821+
deadline = time.Now().Add(req.timeout)
2822+
}
2823+
2824+
cc, err := hc.acquireConn(req.timeout, req.ConnectionClose())
2825+
if err != nil {
2826+
return false, err
2827+
}
2828+
conn := cc.c
2829+
2830+
resp.parseNetConn(conn)
2831+
2832+
writeDeadline := deadline
2833+
if hc.WriteTimeout > 0 {
2834+
tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
2835+
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
2836+
writeDeadline = tmpWriteDeadline
2837+
}
2838+
}
2839+
2840+
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
2841+
hc.closeConn(cc)
2842+
return true, err
2843+
}
2844+
2845+
resetConnection := false
2846+
if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
2847+
req.SetConnectionClose()
2848+
resetConnection = true
2849+
}
2850+
2851+
bw := hc.acquireWriter(conn)
2852+
err = req.Write(bw)
2853+
2854+
if resetConnection {
2855+
req.Header.ResetConnectionClose()
2856+
}
2857+
2858+
if err == nil {
2859+
err = bw.Flush()
2860+
}
2861+
hc.releaseWriter(bw)
2862+
2863+
// Return ErrTimeout on any timeout.
2864+
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
2865+
err = ErrTimeout
2866+
}
2867+
2868+
isConnRST := isConnectionReset(err)
2869+
if err != nil && !isConnRST {
2870+
hc.closeConn(cc)
2871+
return true, err
2872+
}
2873+
2874+
readDeadline := deadline
2875+
if hc.ReadTimeout > 0 {
2876+
tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
2877+
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
2878+
readDeadline = tmpReadDeadline
2879+
}
2880+
}
2881+
2882+
if err = conn.SetReadDeadline(readDeadline); err != nil {
2883+
hc.closeConn(cc)
2884+
return true, err
2885+
}
2886+
2887+
if customSkipBody || req.Header.IsHead() {
2888+
resp.SkipBody = true
2889+
}
2890+
if hc.DisableHeaderNamesNormalizing {
2891+
resp.Header.DisableNormalizing()
2892+
}
2893+
2894+
br := hc.acquireReader(conn)
2895+
err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
2896+
hc.releaseReader(br)
2897+
if err != nil {
2898+
hc.closeConn(cc)
2899+
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
2900+
needRetry := err != ErrBodyTooLarge
2901+
return needRetry, err
2902+
}
2903+
2904+
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
2905+
if customStreamBody && resp.bodyStream != nil {
2906+
rbs := resp.bodyStream
2907+
resp.bodyStream = newCloseReader(rbs, func() error {
2908+
if r, ok := rbs.(*requestStream); ok {
2909+
releaseRequestStream(r)
2910+
}
2911+
if closeConn {
2912+
hc.closeConn(cc)
2913+
} else {
2914+
hc.releaseConn(cc)
2915+
}
2916+
return nil
2917+
})
2918+
return false, nil
2919+
}
2920+
2921+
if closeConn {
2922+
hc.closeConn(cc)
2923+
} else {
2924+
hc.releaseConn(cc)
2925+
}
2926+
return false, nil
2927+
}

‎client_test.go

+98-16
Original file line numberDiff line numberDiff line change
@@ -2111,6 +2111,22 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {
21112111
}
21122112
}
21132113

2114+
type TransportDemo struct {
2115+
br *bufio.Reader
2116+
bw *bufio.Writer
2117+
}
2118+
2119+
func (t TransportDemo) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
2120+
if err = req.Write(t.bw); err != nil {
2121+
return false, err
2122+
}
2123+
if err = t.bw.Flush(); err != nil {
2124+
return false, err
2125+
}
2126+
err = res.Read(t.br)
2127+
return err != nil, err
2128+
}
2129+
21142130
func TestHostClientTransport(t *testing.T) {
21152131
t.Parallel()
21162132

@@ -2131,23 +2147,13 @@ func TestHostClientTransport(t *testing.T) {
21312147

21322148
c := &HostClient{
21332149
Addr: "foobar",
2134-
Transport: func() TransportFunc {
2150+
Transport: func() RoundTripper {
21352151
c, _ := ln.Dial()
21362152

21372153
br := bufio.NewReader(c)
21382154
bw := bufio.NewWriter(c)
21392155

2140-
return func(req *Request, res *Response) error {
2141-
if err := req.Write(bw); err != nil {
2142-
return err
2143-
}
2144-
2145-
if err := bw.Flush(); err != nil {
2146-
return err
2147-
}
2148-
2149-
return res.Read(br)
2150-
}
2156+
return TransportDemo{br: br, bw: bw}
21512157
}(),
21522158
}
21532159

@@ -3060,14 +3066,18 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
30603066
}
30613067
}
30623068

3069+
type TransportEmpty struct{}
3070+
3071+
func (t TransportEmpty) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
3072+
return false, nil
3073+
}
3074+
30633075
func TestHttpsRequestWithoutParsedURL(t *testing.T) {
30643076
t.Parallel()
30653077

30663078
client := HostClient{
3067-
IsTLS: true,
3068-
Transport: func(r1 *Request, r2 *Response) error {
3069-
return nil
3070-
},
3079+
IsTLS: true,
3080+
Transport: TransportEmpty{},
30713081
}
30723082

30733083
req := &Request{}
@@ -3182,3 +3192,75 @@ func Test_AddMissingPort(t *testing.T) {
31823192
})
31833193
}
31843194
}
3195+
3196+
type TransportWrapper struct {
3197+
base RoundTripper
3198+
count *int
3199+
t *testing.T
3200+
}
3201+
3202+
func (tw *TransportWrapper) RoundTrip(hc *HostClient, req *Request, resp *Response) (bool, error) {
3203+
req.Header.Set("trace-id", "123")
3204+
tw.assertRequestLog(req.String())
3205+
retry, err := tw.transport().RoundTrip(hc, req, resp)
3206+
resp.Header.Set("trace-id", "124")
3207+
tw.assertResponseLog(resp.String())
3208+
*tw.count++
3209+
return retry, err
3210+
}
3211+
3212+
func (tw *TransportWrapper) transport() RoundTripper {
3213+
if tw.base == nil {
3214+
return DefaultTransport
3215+
}
3216+
return tw.base
3217+
}
3218+
3219+
func (tw *TransportWrapper) assertRequestLog(reqLog string) {
3220+
if !strings.Contains(reqLog, "Trace-Id: 123") {
3221+
tw.t.Errorf("request log should contains: %v", "Trace-Id: 123")
3222+
}
3223+
}
3224+
3225+
func (tw *TransportWrapper) assertResponseLog(respLog string) {
3226+
if !strings.Contains(respLog, "Trace-Id: 124") {
3227+
tw.t.Errorf("response log should contains: %v", "Trace-Id: 124")
3228+
}
3229+
}
3230+
3231+
func TestClientTransportEx(t *testing.T) {
3232+
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
3233+
defer sHTTP.Stop()
3234+
3235+
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
3236+
defer sHTTPS.Stop()
3237+
3238+
count := 0
3239+
c := &Client{
3240+
TLSConfig: &tls.Config{
3241+
InsecureSkipVerify: true,
3242+
},
3243+
ConfigureClient: func(hc *HostClient) error {
3244+
hc.Transport = &TransportWrapper{base: hc.Transport, count: &count, t: t}
3245+
return nil
3246+
},
3247+
}
3248+
// test transport
3249+
const loopCount = 4
3250+
const getCount = 20
3251+
const postCount = 10
3252+
for i := 0; i < loopCount; i++ {
3253+
addr := "http://" + sHTTP.Addr()
3254+
if i&1 != 0 {
3255+
addr = "https://" + sHTTPS.Addr()
3256+
}
3257+
// test get
3258+
testClientGet(t, c, addr, getCount)
3259+
// test post
3260+
testClientPost(t, c, addr, postCount)
3261+
}
3262+
roundTripCount := loopCount * (getCount + postCount)
3263+
if count != roundTripCount {
3264+
t.Errorf("round trip count should be: %v", roundTripCount)
3265+
}
3266+
}

0 commit comments

Comments
 (0)
Please sign in to comment.