Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

middleware: basic auth middleware can extract and check multiple auth… #2539

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 26 additions & 18 deletions middleware/basic_auth.go
@@ -1,6 +1,7 @@
package middleware

import (
"bytes"
"encoding/base64"
"net/http"
"strconv"
Expand All @@ -15,7 +16,8 @@ type (
// Skipper defines a function to skip middleware.
Skipper Skipper

// Validator is a function to validate BasicAuth credentials.
// Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic
// auth headers this function would be called once for each header until first valid result is returned
// Required.
Validator BasicAuthValidator

Expand Down Expand Up @@ -71,30 +73,36 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c)
}

auth := c.Request().Header.Get(echo.HeaderAuthorization)
var lastError error
l := len(basic)
for i, auth := range c.Request().Header[echo.HeaderAuthorization] {
if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
continue
}

if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
// Invalid base64 shouldn't be treated as error
// instead should be treated as invalid client input
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err)
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if errDecode != nil {
lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode)
continue
}

cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
valid, err := config.Validator(cred[:i], cred[i+1:], c)
if err != nil {
return err
} else if valid {
return next(c)
}
break
idx := bytes.IndexByte(b, ':')
if idx >= 0 {
valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c)
if errValidate != nil {
lastError = errValidate
} else if valid {
return next(c)
}
}
if i >= headerCountLimit { // guard against attacker maliciously sending huge amount of invalid headers
break
}
}

if lastError != nil {
return lastError
}

realm := defaultRealm
Expand Down
182 changes: 136 additions & 46 deletions middleware/basic_auth_test.go
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -11,11 +12,139 @@ import (
"github.com/stretchr/testify/assert"
)

func TestBasicAuthWithConfig(t *testing.T) {
validatorFunc := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
if u == "error" {
return false, errors.New(p)
}
return false, nil
}
defaultConfig := BasicAuthConfig{Validator: validatorFunc}

// we can not add OK value here because ranging over map returns random order. We just try to trigger break
tooManyAuths := make([]string, 0)
for i := 0; i < extractorLimit+2; i++ {
tooManyAuths = append(tooManyAuths, basic+" "+base64.StdEncoding.EncodeToString([]byte("nope:nope")))
}

var testCases = []struct {
name string
givenConfig BasicAuthConfig
whenAuth []string
expectHeader string
expectErr string
}{
aldas marked this conversation as resolved.
Show resolved Hide resolved
{
name: "ok",
givenConfig: defaultConfig,
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, from multiple auth headers one is ok",
givenConfig: defaultConfig,
whenAuth: []string{
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), // different type
basic + " NOT_BASE64", // invalid basic auth
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), // OK
},
},
{
name: "nok, invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",
givenConfig: defaultConfig,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, too many invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: tooManyAuths,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "ok, realm",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, realm, case-insensitive header scheme",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "nok, realm, invalid Authorization header",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm="someRealm"`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, validator func returns an error",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
expectErr: "my_error",
},
{
name: "ok, skipped",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
return true
}},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()

mw := BasicAuthWithConfig(tc.givenConfig)

h := mw(func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()

if len(tc.whenAuth) != 0 {
for _, a := range tc.whenAuth {
req.Header.Add(echo.HeaderAuthorization, a)
}
}
err := h(e.NewContext(req, res))

if tc.expectErr != "" {
assert.Equal(t, http.StatusOK, res.Code)
assert.EqualError(t, err, tc.expectErr)
aldas marked this conversation as resolved.
Show resolved Hide resolved
} else {
assert.Equal(t, http.StatusTeapot, res.Code)
assert.NoError(t, err)
}
if tc.expectHeader != "" {
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
}
})
}
}

func TestBasicAuth(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
Expand All @@ -26,50 +155,11 @@ func TestBasicAuth(t *testing.T) {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)

// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))

// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)

// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
}
6 changes: 6 additions & 0 deletions middleware/middleware.go
Expand Up @@ -9,6 +9,12 @@ import (
"github.com/labstack/echo/v4"
)

const (
// headerCountLimit is arbitrary number to limit number of headers processed. this limits possible resource exhaustion
// attack vector
headerCountLimit = 20
)

type (
// Skipper defines a function to skip middleware. Returning true skips processing
// the middleware.
Expand Down