diff --git a/middleware/proxy/README.md b/middleware/proxy/README.md index c3ceee9dfdc..7a2b370a472 100644 --- a/middleware/proxy/README.md +++ b/middleware/proxy/README.md @@ -37,6 +37,12 @@ proxy.WithTlsConfig(&tls.Config{ InsecureSkipVerify: true, }) +// if you need to use self-custom client,you should use proxy.WithClient +proxy.WithClient(&fasthttp.Client{ + NoDefaultUserAgentHeader: true, + DisablePathNormalizing: true, +}) + // Forward to url app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index bb96b03382e..6ca3c01df0f 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -97,17 +97,24 @@ func Balancer(config Config) fiber.Handler { } } -var client = fasthttp.Client{ +var client = &fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, } // WithTlsConfig update http client with a user specified tls.config // This function should be called before Do and Forward. +// Deprecated: use WithClient instead. func WithTlsConfig(tlsConfig *tls.Config) { client.TLSConfig = tlsConfig } +// WithClient sets the proxy client. +// This function should be called before Do and Forward. +func WithClient(cli *fasthttp.Client) { + client = cli +} + // Forward performs the given http request and fills the given http response. // This method will return an fiber.Handler func Forward(addr string) fiber.Handler { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 08160b2e868..5561ea0dcf4 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -13,6 +13,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/internal/tlstest" "github.com/gofiber/fiber/v2/utils" + "github.com/valyala/fasthttp" ) func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) { @@ -364,6 +365,7 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) { utils.AssertEqual(t, nil, err2) } +// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) { t.Parallel() @@ -390,3 +392,36 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "hello world", string(s)) } + +// go test -race -run Test_Proxy_Forward_With_Client +func Test_Proxy_Forward_With_Client(t *testing.T) { + t.Parallel() + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:7777") + utils.AssertEqual(t, nil, err) + WithClient(&fasthttp.Client{ + NoDefaultUserAgentHeader: true, + DisablePathNormalizing: true, + Dial: func(addr string) (net.Conn, error) { + utils.AssertEqual(t, "127.0.0.1:7777", addr) + return fasthttp.Dial(addr) + }, + }) + // reset global client + defer WithClient(&fasthttp.Client{ + NoDefaultUserAgentHeader: true, + DisablePathNormalizing: true, + }) + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app.Get("/test_client", func(c *fiber.Ctx) error { + return c.SendString("test_client") + }) + + addr := ln.Addr().String() + app.Use(Forward("http://" + addr + "/test_client")) + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + code, body, errs := fiber.Get("http://" + addr).String() + utils.AssertEqual(t, 0, len(errs)) + utils.AssertEqual(t, fiber.StatusOK, code) + utils.AssertEqual(t, "test_client", body) +}