Skip to content

Commit c549e0d

Browse files
Smirlonsi
authored andcommittedJul 14, 2024
Add RoundTripper method to ghttp.Server
1 parent 0e69083 commit c549e0d

File tree

3 files changed

+136
-29
lines changed

3 files changed

+136
-29
lines changed
 

‎docs/index.md

+27
Original file line numberDiff line numberDiff line change
@@ -2424,6 +2424,33 @@ To bring it all together: there are three ways to instruct a `ghttp` server to h
24242424

24252425
When a `ghttp` server receives a request it first checks against the set of handlers registered via `RouteToHandler` if there is no such handler it proceeds to pop an `AppendHandlers` handler off the stack, if the stack of ordered handlers is empty, it will check whether `GetAllowUnhandledRequests` returns `true` or `false`. If `false` the test fails. If `true`, a response is sent with whatever `GetUnhandledRequestStatusCode` returns.
24262426

2427+
### Using a RoundTripper to route requests to the test Server
2428+
2429+
So far you have seen examples of using `server.URL()` to get the string URL of the test server. This is ok if you are testing code where you can pass the URL. In some cases you might need to pass a `http.Client` or similar.
2430+
2431+
You can use `server.RounderTripper(nil)` to create a `http.RounderTripper` which will redirect requests to the test server.
2432+
2433+
The method takes another `http.RounderTripper` to make the request to the test server, this allows chaining `http.Transports` or otherwise.
2434+
2435+
If passed `nil`, then `http.DefaultTransport` is used to make the request.
2436+
2437+
```go
2438+
Describe("The http client", func() {
2439+
var server *ghttp.Server
2440+
var httpClient *http.Client
2441+
2442+
BeforeEach(func() {
2443+
server = ghttp.NewServer()
2444+
httpClient = &http.Client{Transport: server.RounderTripper(nil)}
2445+
})
2446+
2447+
AfterEach(func() {
2448+
//shut down the server between tests
2449+
server.Close()
2450+
})
2451+
})
2452+
```
2453+
24272454
## `gbytes`: Testing Streaming Buffers
24282455

24292456
`gbytes` implements `gbytes.Buffer` - an `io.WriteCloser` that captures all input to an in-memory buffer.

‎ghttp/test_server.go

+50-29
Original file line numberDiff line numberDiff line change
@@ -186,26 +186,26 @@ type Server struct {
186186
calls int
187187
}
188188

189-
//Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest).
189+
// Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest).
190190
func (s *Server) Start() {
191191
s.HTTPTestServer.Start()
192192
}
193193

194-
//URL() returns a url that will hit the server
194+
// URL() returns a url that will hit the server
195195
func (s *Server) URL() string {
196196
s.rwMutex.RLock()
197197
defer s.rwMutex.RUnlock()
198198
return s.HTTPTestServer.URL
199199
}
200200

201-
//Addr() returns the address on which the server is listening.
201+
// Addr() returns the address on which the server is listening.
202202
func (s *Server) Addr() string {
203203
s.rwMutex.RLock()
204204
defer s.rwMutex.RUnlock()
205205
return s.HTTPTestServer.Listener.Addr().String()
206206
}
207207

208-
//Close() should be called at the end of each test. It spins down and cleans up the test server.
208+
// Close() should be called at the end of each test. It spins down and cleans up the test server.
209209
func (s *Server) Close() {
210210
s.rwMutex.Lock()
211211
server := s.HTTPTestServer
@@ -217,14 +217,14 @@ func (s *Server) Close() {
217217
}
218218
}
219219

220-
//ServeHTTP() makes Server an http.Handler
221-
//When the server receives a request it handles the request in the following order:
220+
// ServeHTTP() makes Server an http.Handler
221+
// When the server receives a request it handles the request in the following order:
222222
//
223-
//1. If the request matches a handler registered with RouteToHandler, that handler is called.
224-
//2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order.
225-
//3. If all registered handlers have been called then:
226-
// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode
227-
// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed.
223+
// 1. If the request matches a handler registered with RouteToHandler, that handler is called.
224+
// 2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order.
225+
// 3. If all registered handlers have been called then:
226+
// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode
227+
// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed.
228228
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
229229
s.rwMutex.Lock()
230230
defer func() {
@@ -280,18 +280,18 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
280280
}
281281
}
282282

283-
//ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests)
283+
// ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests)
284284
func (s *Server) ReceivedRequests() []*http.Request {
285285
s.rwMutex.RLock()
286286
defer s.rwMutex.RUnlock()
287287

288288
return s.receivedRequests
289289
}
290290

291-
//RouteToHandler can be used to register handlers that will always handle requests that match
292-
//the passed in method and path.
291+
// RouteToHandler can be used to register handlers that will always handle requests that match
292+
// the passed in method and path.
293293
//
294-
//The path may be either a string object or a *regexp.Regexp.
294+
// The path may be either a string object or a *regexp.Regexp.
295295
func (s *Server) RouteToHandler(method string, path interface{}, handler http.HandlerFunc) {
296296
s.rwMutex.Lock()
297297
defer s.rwMutex.Unlock()
@@ -337,25 +337,25 @@ func (s *Server) handlerForRoute(method string, path string) (http.HandlerFunc,
337337
return nil, false
338338
}
339339

340-
//AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc...
340+
// AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc...
341341
func (s *Server) AppendHandlers(handlers ...http.HandlerFunc) {
342342
s.rwMutex.Lock()
343343
defer s.rwMutex.Unlock()
344344

345345
s.requestHandlers = append(s.requestHandlers, handlers...)
346346
}
347347

348-
//SetHandler overrides the registered handler at the passed in index with the passed in handler
349-
//This is useful, for example, when a server has been set up in a shared context, but must be tweaked
350-
//for a particular test.
348+
// SetHandler overrides the registered handler at the passed in index with the passed in handler
349+
// This is useful, for example, when a server has been set up in a shared context, but must be tweaked
350+
// for a particular test.
351351
func (s *Server) SetHandler(index int, handler http.HandlerFunc) {
352352
s.rwMutex.Lock()
353353
defer s.rwMutex.Unlock()
354354

355355
s.requestHandlers[index] = handler
356356
}
357357

358-
//GetHandler returns the handler registered at the passed in index.
358+
// GetHandler returns the handler registered at the passed in index.
359359
func (s *Server) GetHandler(index int) http.HandlerFunc {
360360
s.rwMutex.RLock()
361361
defer s.rwMutex.RUnlock()
@@ -374,12 +374,12 @@ func (s *Server) Reset() {
374374
s.routedHandlers = nil
375375
}
376376

377-
//WrapHandler combines the passed in handler with the handler registered at the passed in index.
378-
//This is useful, for example, when a server has been set up in a shared context but must be tweaked
379-
//for a particular test.
377+
// WrapHandler combines the passed in handler with the handler registered at the passed in index.
378+
// This is useful, for example, when a server has been set up in a shared context but must be tweaked
379+
// for a particular test.
380380
//
381-
//If the currently registered handler is A, and the new passed in handler is B then
382-
//WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index
381+
// If the currently registered handler is A, and the new passed in handler is B then
382+
// WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index
383383
func (s *Server) WrapHandler(index int, handler http.HandlerFunc) {
384384
existingHandler := s.GetHandler(index)
385385
s.SetHandler(index, CombineHandlers(existingHandler, handler))
@@ -392,34 +392,55 @@ func (s *Server) CloseClientConnections() {
392392
s.HTTPTestServer.CloseClientConnections()
393393
}
394394

395-
//SetAllowUnhandledRequests enables the server to accept unhandled requests.
395+
// SetAllowUnhandledRequests enables the server to accept unhandled requests.
396396
func (s *Server) SetAllowUnhandledRequests(allowUnhandledRequests bool) {
397397
s.rwMutex.Lock()
398398
defer s.rwMutex.Unlock()
399399

400400
s.AllowUnhandledRequests = allowUnhandledRequests
401401
}
402402

403-
//GetAllowUnhandledRequests returns true if the server accepts unhandled requests.
403+
// GetAllowUnhandledRequests returns true if the server accepts unhandled requests.
404404
func (s *Server) GetAllowUnhandledRequests() bool {
405405
s.rwMutex.RLock()
406406
defer s.rwMutex.RUnlock()
407407

408408
return s.AllowUnhandledRequests
409409
}
410410

411-
//SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests
411+
// SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests
412412
func (s *Server) SetUnhandledRequestStatusCode(statusCode int) {
413413
s.rwMutex.Lock()
414414
defer s.rwMutex.Unlock()
415415

416416
s.UnhandledRequestStatusCode = statusCode
417417
}
418418

419-
//GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests
419+
// GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests
420420
func (s *Server) GetUnhandledRequestStatusCode() int {
421421
s.rwMutex.RLock()
422422
defer s.rwMutex.RUnlock()
423423

424424
return s.UnhandledRequestStatusCode
425425
}
426+
427+
// RoundTripper returns a RoundTripper which updates requests to point to the server.
428+
// This is useful when you want to use the server as a RoundTripper in an http.Client.
429+
// If rt is nil, http.DefaultTransport is used.
430+
func (s *Server) RoundTripper(rt http.RoundTripper) http.RoundTripper {
431+
if rt == nil {
432+
rt = http.DefaultTransport
433+
}
434+
return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
435+
r.URL.Scheme = "http"
436+
r.URL.Host = s.Addr()
437+
return rt.RoundTrip(r)
438+
})
439+
}
440+
441+
// Helper type for creating a RoundTripper from a function
442+
type RoundTripperFunc func(*http.Request) (*http.Response, error)
443+
444+
func (fn RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
445+
return fn(r)
446+
}

‎ghttp/test_server_test.go

+59
Original file line numberDiff line numberDiff line change
@@ -1190,4 +1190,63 @@ var _ = Describe("TestServer", func() {
11901190
})
11911191
})
11921192
})
1193+
1194+
Describe("RoundTripper", func() {
1195+
var called []string
1196+
BeforeEach(func() {
1197+
called = []string{}
1198+
s.RouteToHandler("GET", "/routed", func(w http.ResponseWriter, req *http.Request) {
1199+
called = append(called, "get")
1200+
})
1201+
s.RouteToHandler("POST", "/routed", func(w http.ResponseWriter, req *http.Request) {
1202+
called = append(called, "post")
1203+
})
1204+
})
1205+
1206+
It("should send http traffic to test server with default transport", func() {
1207+
client := http.Client{Transport: s.RoundTripper(nil)}
1208+
client.Get("http://example.com/routed")
1209+
client.Post("http://example.com/routed", "application/json", nil)
1210+
client.Get("http://foo.bar/routed")
1211+
client.Post("http://foo.bar/routed", "application/json", nil)
1212+
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
1213+
})
1214+
1215+
It("should send https traffic to test server with default transport", func() {
1216+
client := http.Client{Transport: s.RoundTripper(nil)}
1217+
client.Get("https://example.com/routed")
1218+
client.Post("https://example.com/routed", "application/json", nil)
1219+
client.Get("https://foo.bar/routed")
1220+
client.Post("https://foo.bar/routed", "application/json", nil)
1221+
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
1222+
})
1223+
1224+
It("should send http traffic to test server with default transport", func() {
1225+
transport := http.Transport{}
1226+
client := http.Client{Transport: s.RoundTripper(&transport)}
1227+
client.Get("http://example.com/routed")
1228+
client.Post("http://example.com/routed", "application/json", nil)
1229+
client.Get("http://foo.bar/routed")
1230+
client.Post("http://foo.bar/routed", "application/json", nil)
1231+
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
1232+
})
1233+
1234+
It("should send http traffic to test server with default transport", func() {
1235+
transport := http.Transport{}
1236+
client := http.Client{Transport: s.RoundTripper(&transport)}
1237+
client.Get("https://example.com/routed")
1238+
client.Post("https://example.com/routed", "application/json", nil)
1239+
client.Get("https://foo.bar/routed")
1240+
client.Post("https://foo.bar/routed", "application/json", nil)
1241+
Expect(called).Should(Equal([]string{"get", "post", "get", "post"}))
1242+
})
1243+
1244+
It("should not change the path of the request", func() {
1245+
client := http.Client{Transport: s.RoundTripper(nil)}
1246+
client.Get("https://example.com/routed")
1247+
Expect(called).Should(Equal([]string{"get"}))
1248+
Expect(s.ReceivedRequests()).Should(HaveLen(1))
1249+
Expect(s.ReceivedRequests()[0].URL.Path).Should(Equal("/routed"))
1250+
})
1251+
})
11931252
})

0 commit comments

Comments
 (0)
Please sign in to comment.