Skip to content

Commit 54fdc7a

Browse files
timandyerikdubbelboer
andauthoredAug 10, 2023
Abstracts the RoundTripper interface and provides a default implement (#1602)
* Abstracts the RoundTripper interface and provides a default implementation for enhanced extensibility (#1601) * test: Add custom transport test case (#1601) * Make default RoundTripper implmention none public Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com> --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
1 parent e181af1 commit 54fdc7a

File tree

2 files changed

+227
-129
lines changed

2 files changed

+227
-129
lines changed
 

‎client.go

+129-113
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,10 @@ type DialFunc func(addr string) (net.Conn, error)
628628
// Request argument passed to RetryIfFunc, if there are any request errors.
629629
type RetryIfFunc func(request *Request) bool
630630

631-
// TransportFunc wraps every request/response.
632-
type TransportFunc func(*Request, *Response) error
631+
// RoundTripper wraps every request/response.
632+
type RoundTripper interface {
633+
RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
634+
}
633635

634636
// ConnPoolStrategyType define strategy of connection pool enqueue/dequeue
635637
type ConnPoolStrategyType int
@@ -791,7 +793,7 @@ type HostClient struct {
791793
RetryIf RetryIfFunc
792794

793795
// Transport defines a transport-like mechanism that wraps every request/response.
794-
Transport TransportFunc
796+
Transport RoundTripper
795797

796798
// Connection pool strategy. Can be either LIFO or FIFO (default).
797799
ConnPoolStrategy ConnPoolStrategyType
@@ -1343,119 +1345,15 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
13431345
req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
13441346
}
13451347
}
1346-
if c.Transport != nil {
1347-
err := c.Transport(req, resp)
1348-
return err == nil, err
1349-
}
1350-
1351-
var deadline time.Time
1352-
if req.timeout > 0 {
1353-
deadline = time.Now().Add(req.timeout)
1354-
}
1355-
1356-
cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
1357-
if err != nil {
1358-
return false, err
1359-
}
1360-
conn := cc.c
1361-
1362-
resp.parseNetConn(conn)
1363-
1364-
writeDeadline := deadline
1365-
if c.WriteTimeout > 0 {
1366-
tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
1367-
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
1368-
writeDeadline = tmpWriteDeadline
1369-
}
1370-
}
13711348

1372-
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
1373-
c.closeConn(cc)
1374-
return true, err
1375-
}
1376-
1377-
resetConnection := false
1378-
if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
1379-
req.SetConnectionClose()
1380-
resetConnection = true
1381-
}
1382-
1383-
bw := c.acquireWriter(conn)
1384-
err = req.Write(bw)
1385-
1386-
if resetConnection {
1387-
req.Header.ResetConnectionClose()
1388-
}
1389-
1390-
if err == nil {
1391-
err = bw.Flush()
1392-
}
1393-
c.releaseWriter(bw)
1394-
1395-
// Return ErrTimeout on any timeout.
1396-
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
1397-
err = ErrTimeout
1398-
}
1399-
1400-
isConnRST := isConnectionReset(err)
1401-
if err != nil && !isConnRST {
1402-
c.closeConn(cc)
1403-
return true, err
1404-
}
1405-
1406-
readDeadline := deadline
1407-
if c.ReadTimeout > 0 {
1408-
tmpReadDeadline := time.Now().Add(c.ReadTimeout)
1409-
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
1410-
readDeadline = tmpReadDeadline
1411-
}
1412-
}
1413-
1414-
if err = conn.SetReadDeadline(readDeadline); err != nil {
1415-
c.closeConn(cc)
1416-
return true, err
1417-
}
1418-
1419-
if customSkipBody || req.Header.IsHead() {
1420-
resp.SkipBody = true
1421-
}
1422-
if c.DisableHeaderNamesNormalizing {
1423-
resp.Header.DisableNormalizing()
1424-
}
1425-
1426-
br := c.acquireReader(conn)
1427-
err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
1428-
c.releaseReader(br)
1429-
if err != nil {
1430-
c.closeConn(cc)
1431-
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
1432-
retry := err != ErrBodyTooLarge
1433-
return retry, err
1434-
}
1435-
1436-
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
1437-
if customStreamBody && resp.bodyStream != nil {
1438-
rbs := resp.bodyStream
1439-
resp.bodyStream = newCloseReader(rbs, func() error {
1440-
if r, ok := rbs.(*requestStream); ok {
1441-
releaseRequestStream(r)
1442-
}
1443-
if closeConn {
1444-
c.closeConn(cc)
1445-
} else {
1446-
c.releaseConn(cc)
1447-
}
1448-
return nil
1449-
})
1450-
return false, nil
1451-
}
1349+
return c.transport().RoundTrip(c, req, resp)
1350+
}
14521351

1453-
if closeConn {
1454-
c.closeConn(cc)
1455-
} else {
1456-
c.releaseConn(cc)
1352+
func (c *HostClient) transport() RoundTripper {
1353+
if c.Transport == nil {
1354+
return DefaultTransport
14571355
}
1458-
return false, nil
1356+
return c.Transport
14591357
}
14601358

14611359
var (
@@ -2909,3 +2807,121 @@ func (c *pipelineConnClient) PendingRequests() int {
29092807
}
29102808

29112809
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
2810+
2811+
var DefaultTransport RoundTripper = &transport{}
2812+
2813+
type transport struct{}
2814+
2815+
func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
2816+
customSkipBody := resp.SkipBody
2817+
customStreamBody := resp.StreamBody
2818+
2819+
var deadline time.Time
2820+
if req.timeout > 0 {
2821+
deadline = time.Now().Add(req.timeout)
2822+
}
2823+
2824+
cc, err := hc.acquireConn(req.timeout, req.ConnectionClose())
2825+
if err != nil {
2826+
return false, err
2827+
}
2828+
conn := cc.c
2829+
2830+
resp.parseNetConn(conn)
2831+
2832+
writeDeadline := deadline
2833+
if hc.WriteTimeout > 0 {
2834+
tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
2835+
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
2836+
writeDeadline = tmpWriteDeadline
2837+
}
2838+
}
2839+
2840+
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
2841+
hc.closeConn(cc)
2842+
return true, err
2843+
}
2844+
2845+
resetConnection := false
2846+
if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
2847+
req.SetConnectionClose()
2848+
resetConnection = true
2849+
}
2850+
2851+
bw := hc.acquireWriter(conn)
2852+
err = req.Write(bw)
2853+
2854+
if resetConnection {
2855+
req.Header.ResetConnectionClose()
2856+
}
2857+
2858+
if err == nil {
2859+
err = bw.Flush()
2860+
}
2861+
hc.releaseWriter(bw)
2862+
2863+
// Return ErrTimeout on any timeout.
2864+
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
2865+
err = ErrTimeout
2866+
}
2867+
2868+
isConnRST := isConnectionReset(err)
2869+
if err != nil && !isConnRST {
2870+
hc.closeConn(cc)
2871+
return true, err
2872+
}
2873+
2874+
readDeadline := deadline
2875+
if hc.ReadTimeout > 0 {
2876+
tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
2877+
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
2878+
readDeadline = tmpReadDeadline
2879+
}
2880+
}
2881+
2882+
if err = conn.SetReadDeadline(readDeadline); err != nil {
2883+
hc.closeConn(cc)
2884+
return true, err
2885+
}
2886+
2887+
if customSkipBody || req.Header.IsHead() {
2888+
resp.SkipBody = true
2889+
}
2890+
if hc.DisableHeaderNamesNormalizing {
2891+
resp.Header.DisableNormalizing()
2892+
}
2893+
2894+
br := hc.acquireReader(conn)
2895+
err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
2896+
hc.releaseReader(br)
2897+
if err != nil {
2898+
hc.closeConn(cc)
2899+
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
2900+
needRetry := err != ErrBodyTooLarge
2901+
return needRetry, err
2902+
}
2903+
2904+
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
2905+
if customStreamBody && resp.bodyStream != nil {
2906+
rbs := resp.bodyStream
2907+
resp.bodyStream = newCloseReader(rbs, func() error {
2908+
if r, ok := rbs.(*requestStream); ok {
2909+
releaseRequestStream(r)
2910+
}
2911+
if closeConn {
2912+
hc.closeConn(cc)
2913+
} else {
2914+
hc.releaseConn(cc)
2915+
}
2916+
return nil
2917+
})
2918+
return false, nil
2919+
}
2920+
2921+
if closeConn {
2922+
hc.closeConn(cc)
2923+
} else {
2924+
hc.releaseConn(cc)
2925+
}
2926+
return false, nil
2927+
}

0 commit comments

Comments
 (0)
Please sign in to comment.