Skip to content

Commit

Permalink
optimize: add WithClient
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou committed Sep 26, 2022
1 parent d461bf2 commit 9860d34
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 35 deletions.
3 changes: 2 additions & 1 deletion internal/gopsutil/net/net_darwin.go
Expand Up @@ -7,11 +7,12 @@ import (
"context"
"errors"
"fmt"
"github.com/gofiber/fiber/v2/internal/gopsutil/common"
"os/exec"
"regexp"
"strconv"
"strings"

"github.com/gofiber/fiber/v2/internal/gopsutil/common"
)

var (
Expand Down
23 changes: 19 additions & 4 deletions middleware/proxy/README.md
Expand Up @@ -13,8 +13,8 @@ Proxy middleware for [Fiber](https://github.com/gofiber/fiber) that allows you t

```go
func Balancer(config Config) fiber.Handler
func Forward(addr string) fiber.Handler
func Do(c *fiber.Ctx, addr string) error
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error
```

### Examples
Expand All @@ -37,9 +37,21 @@ proxy.WithTlsConfig(&tls.Config{
InsecureSkipVerify: true,
})

// if you need to use global self-custom client, you should use proxy.WithClient.
proxy.WithClient(&fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
})

// Forward to url
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif"))

// Forward to url with local custom client
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}))

// Make request within handler
app.Get("/:id", func(c *fiber.Ctx) error {
url := "https://i.imgur.com/"+c.Params("id")+".gif"
Expand Down Expand Up @@ -120,8 +132,11 @@ type Config struct {
// Per-connection buffer size for responses' writing.
WriteBufferSize int

// tls config for the http client
TlsConfig *tls.Config
// tls config for the http client.
TlsConfig *tls.Config

// Client is custom client when client config is complex.
Client *fasthttp.LBClient
}
```

Expand Down
7 changes: 5 additions & 2 deletions middleware/proxy/config.go
Expand Up @@ -47,8 +47,11 @@ type Config struct {
// Per-connection buffer size for responses' writing.
WriteBufferSize int

// tls config for the http client
// tls config for the http client.
TlsConfig *tls.Config

// Client is custom client when client config is complex.
Client *fasthttp.LBClient
}

// ConfigDefault is the default config
Expand All @@ -75,7 +78,7 @@ func configDefault(config ...Config) Config {
}

// Set default values
if len(cfg.Servers) == 0 {
if len(cfg.Servers) == 0 && cfg.Client == nil {
panic("Servers cannot be empty")
}
return cfg
Expand Down
74 changes: 46 additions & 28 deletions middleware/proxy/proxy.go
Expand Up @@ -24,34 +24,37 @@ func Balancer(config Config) fiber.Handler {
cfg := configDefault(config)

// Load balanced client
var lbc fasthttp.LBClient
// Set timeout
lbc.Timeout = cfg.Timeout

// Scheme must be provided, falls back to http
// TODO add https support
for _, server := range cfg.Servers {
if !strings.HasPrefix(server, "http") {
server = "http://" + server
}
var lbc = &fasthttp.LBClient{}
if config.Client == nil {
// Set timeout
lbc.Timeout = cfg.Timeout
// Scheme must be provided, falls back to http
for _, server := range cfg.Servers {
if !strings.HasPrefix(server, "http") {
server = "http://" + server
}

u, err := url.Parse(server)
if err != nil {
panic(err)
}
u, err := url.Parse(server)
if err != nil {
panic(err)
}

client := &fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: u.Host,
client := &fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: u.Host,

ReadBufferSize: config.ReadBufferSize,
WriteBufferSize: config.WriteBufferSize,
ReadBufferSize: config.ReadBufferSize,
WriteBufferSize: config.WriteBufferSize,

TLSConfig: config.TlsConfig,
}
TLSConfig: config.TlsConfig,
}

lbc.Clients = append(lbc.Clients, client)
lbc.Clients = append(lbc.Clients, client)
}
} else {
// Set custom client
lbc = config.Client
}

// Return new handler
Expand Down Expand Up @@ -97,28 +100,43 @@ func Balancer(config Config) fiber.Handler {
}
}

var client = fasthttp.Client{
var client = &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}

// WithTlsConfig update http client with a user specified tls.config
// This function should be called before Do and Forward.
// Deprecated: use WithClient instead.
func WithTlsConfig(tlsConfig *tls.Config) {
client.TLSConfig = tlsConfig
}

// WithClient sets the global proxy client.
// This function should be called before Do and Forward.
func WithClient(cli *fasthttp.Client) {
client = cli
}

// Forward performs the given http request and fills the given http response.
// This method will return an fiber.Handler
func Forward(addr string) fiber.Handler {
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
return func(c *fiber.Ctx) error {
return Do(c, addr)
return Do(c, addr, clients...)
}
}

// Do performs the given http request and fills the given http response.
// This method can be used within a fiber.Handler
func Do(c *fiber.Ctx, addr string) error {
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
var cli *fasthttp.Client
if len(clients) != 0 {
// Set local client
cli = clients[0]
} else {
// Set global client
cli = client
}
req := c.Request()
res := c.Response()
originalURL := utils.CopyString(c.OriginalURL())
Expand All @@ -134,7 +152,7 @@ func Do(c *fiber.Ctx, addr string) error {
}

req.Header.Del(fiber.HeaderConnection)
if err := client.Do(req, res); err != nil {
if err := cli.Do(req, res); err != nil {
return err
}
res.Header.Del(fiber.HeaderConnection)
Expand Down
84 changes: 84 additions & 0 deletions middleware/proxy/proxy_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/tlstest"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)

func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) {
Expand Down Expand Up @@ -364,6 +365,7 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
utils.AssertEqual(t, nil, err2)
}

// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
t.Parallel()

Expand All @@ -390,3 +392,85 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "hello world", string(s))
}

// go test -race -run Test_Proxy_Forward_Global_Client
func Test_Proxy_Forward_Global_Client(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
WithClient(&fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
})
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/test_global_client", func(c *fiber.Ctx) error {
return c.SendString("test_global_client")
})

addr := ln.Addr().String()
app.Use(Forward("http://" + addr + "/test_global_client"))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()

code, body, errs := fiber.Get("http://" + addr).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_global_client", body)
}

// go test -race -run Test_Proxy_Forward_Local_Client
func Test_Proxy_Forward_Local_Client(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/test_local_client", func(c *fiber.Ctx) error {
return c.SendString("test_local_client")
})

addr := ln.Addr().String()
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Dial: func(addr string) (net.Conn, error) {
return fasthttp.Dial(addr)
},
}))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()

code, body, errs := fiber.Get("http://" + addr).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_local_client", body)
}

// go test -run Test_ProxyBalancer_Custom_Client
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
t.Parallel()

target, addr := createProxyTestServer(
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
)

resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)

app := fiber.New(fiber.Config{DisableStartupMessage: true})

app.Use(Balancer(Config{Client: &fasthttp.LBClient{
Clients: []fasthttp.BalancingClient{
&fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: addr,
},
},
Timeout: time.Second,
}}))

req := httptest.NewRequest("GET", "/", nil)
req.Host = addr
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}

0 comments on commit 9860d34

Please sign in to comment.