Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add digest authentication #583

Merged
merged 3 commits into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ var (
hdrContentLengthKey = http.CanonicalHeaderKey("Content-Length")
hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding")
hdrLocationKey = http.CanonicalHeaderKey("Location")
hdrAuthorizationKey = http.CanonicalHeaderKey("Authorization")
hdrWwwAuthenticateKey = http.CanonicalHeaderKey("WWW-Authenticate")

plainTextType = "text/plain; charset=utf-8"
jsonContentType = "application/json"
Expand Down Expand Up @@ -399,6 +401,34 @@ func (c *Client) SetAuthScheme(scheme string) *Client {
return c
}

// SetDigestAuth method sets the Digest Access auth scheme for the client. If a server responds with 401 and sends
// a Digest challenge in the WWW-Authenticate Header, requests will be resent with the appropriate Authorization Header.
//
// For Example: To set the Digest scheme with user "Mufasa" and password "Circle Of Life"
//
// client.SetDigestAuth("Mufasa", "Circle Of Life")
//
// Information about Digest Access Authentication can be found in RFC7616:
//
// https://datatracker.ietf.org/doc/html/rfc7616
//
// See `Request.SetDigestAuth`.
func (c *Client) SetDigestAuth(username, password string) *Client {
oldTransport := c.httpClient.Transport
c.OnBeforeRequest(func(c *Client, _ *Request) error {
c.httpClient.Transport = &digestTransport{
digestCredentials: digestCredentials{username, password},
transport: oldTransport,
}
return nil
})
c.OnAfterResponse(func(c *Client, _ *Response) error {
c.httpClient.Transport = oldTransport
return nil
})
return c
}

// R method creates a new request instance, its used for Get, Post, Put, Delete, Patch, Head, Options, etc.
func (c *Client) R() *Request {
r := &Request{
Expand Down
70 changes: 70 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,76 @@ func TestClientAuthScheme(t *testing.T) {

}

func TestClientDigestAuth(t *testing.T) {
conf := defaultDigestServerConf()
ts := createDigestServer(t, conf)
defer ts.Close()

c := dc().
SetBaseURL(ts.URL+"/").
SetDigestAuth(conf.username, conf.password)

resp, err := c.R().
SetResult(&AuthSuccess{}).
Get(conf.uri)
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())

t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))
logResponse(t, resp)
}

func TestClientDigestSession(t *testing.T) {
conf := defaultDigestServerConf()
conf.algo = "MD5-sess"
ts := createDigestServer(t, conf)
defer ts.Close()

c := dc().
SetBaseURL(ts.URL+"/").
SetDigestAuth(conf.username, conf.password)

resp, err := c.R().
SetResult(&AuthSuccess{}).
Get(conf.uri)
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())

t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))
logResponse(t, resp)
}

func TestClientDigestErrors(t *testing.T) {
type test struct {
mutateConf func(*digestServerConfig)
expect error
}
tests := []test{
{mutateConf: func(c *digestServerConfig) { c.algo = "BAD_ALGO" }, expect: ErrDigestAlgNotSupported},
{mutateConf: func(c *digestServerConfig) { c.qop = "bad-qop" }, expect: ErrDigestQopNotSupported},
{mutateConf: func(c *digestServerConfig) { c.qop = "" }, expect: ErrDigestNoQop},
{mutateConf: func(c *digestServerConfig) { c.charset = "utf-16" }, expect: ErrDigestCharset},
{mutateConf: func(c *digestServerConfig) { c.uri = "/bad" }, expect: ErrDigestBadChallenge},
{mutateConf: func(c *digestServerConfig) { c.uri = "/unknown_param" }, expect: ErrDigestBadChallenge},
{mutateConf: func(c *digestServerConfig) { c.uri = "/no_challenge" }, expect: ErrDigestBadChallenge},
{mutateConf: func(c *digestServerConfig) { c.uri = "/status_500" }, expect: nil},
}

for _, tc := range tests {
conf := defaultDigestServerConf()
tc.mutateConf(conf)
ts := createDigestServer(t, conf)

c := dc().
SetBaseURL(ts.URL+"/").
SetDigestAuth(conf.username, conf.password)

_, err := c.R().Get(conf.uri)
assertErrorIs(t, tc.expect, err)
ts.Close()
}
}

func TestOnAfterMiddleware(t *testing.T) {
ts := createGenServer(t)
defer ts.Close()
Expand Down
272 changes: 272 additions & 0 deletions digest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
package resty
jeevatkm marked this conversation as resolved.
Show resolved Hide resolved

import (
"crypto/md5"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
"io"
"net/http"
"strings"
)

var (
ErrDigestBadChallenge = errors.New("digest: challenge is bad")
ErrDigestCharset = errors.New("digest: unsupported charset")
ErrDigestAlgNotSupported = errors.New("digest: algorithm is not supported")
ErrDigestQopNotSupported = errors.New("digest: no supported qop in list")
ErrDigestNoQop = errors.New("digest: qop must be specified")
)
segevda marked this conversation as resolved.
Show resolved Hide resolved

var hashFuncs = map[string]func() hash.Hash{
"": md5.New,
"MD5": md5.New,
"MD5-sess": md5.New,
"SHA-256": sha256.New,
"SHA-256-sess": sha256.New,
"SHA-512-256": sha512.New,
"SHA-512-256-sess": sha512.New,
}

type digestCredentials struct {
username, password string
}

type digestTransport struct {
digestCredentials
transport http.RoundTripper
}

func (dt *digestTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Copy the request, so we don't modify the input.
req2 := new(http.Request)
*req2 = *req
req2.Header = make(http.Header)
for k, s := range req.Header {
req2.Header[k] = s
}

// Make a request to get the 401 that contains the challenge.
resp, err := dt.transport.RoundTrip(req)
if err != nil || resp.StatusCode != http.StatusUnauthorized {
return resp, err
}
chal := resp.Header.Get(hdrWwwAuthenticateKey)
if chal == "" {
return resp, ErrDigestBadChallenge
}

c, err := parseChallenge(chal)
if err != nil {
return resp, err
}

// Form credentials based on the challenge
cr := dt.newCredentials(req2, c)
auth, err := cr.authorize()
if err != nil {
return resp, err
}
err = resp.Body.Close()
if err != nil {
return nil, err
}

// Make authenticated request
req2.Header.Set(hdrAuthorizationKey, auth)
return dt.transport.RoundTrip(req2)
}

func (dt *digestTransport) newCredentials(req *http.Request, c *challenge) *credentials {
return &credentials{
username: dt.username,
userhash: c.userhash,
realm: c.realm,
nonce: c.nonce,
digestURI: req.URL.RequestURI(),
algorithm: c.algorithm,
sessionAlg: strings.HasSuffix(c.algorithm, "-sess"),
opaque: c.opaque,
messageQop: c.qop,
nc: 0,
method: req.Method,
password: dt.password,
}
}

type challenge struct {
realm string
domain string
nonce string
opaque string
stale string
algorithm string
qop string
userhash string
}

func parseChallenge(input string) (*challenge, error) {
const ws = " \n\r\t"
const qs = `"`
s := strings.Trim(input, ws)
if !strings.HasPrefix(s, "Digest ") {
return nil, ErrDigestBadChallenge
}
s = strings.Trim(s[7:], ws)
sl := strings.Split(s, ", ")
c := &challenge{}
var r []string
for i := range sl {
r = strings.SplitN(sl[i], "=", 2)
switch r[0] {
case "realm":
c.realm = strings.Trim(r[1], qs)
case "domain":
c.domain = strings.Trim(r[1], qs)
case "nonce":
c.nonce = strings.Trim(r[1], qs)
case "opaque":
c.opaque = strings.Trim(r[1], qs)
case "stale":
c.stale = r[1]
case "algorithm":
c.algorithm = r[1]
case "qop":
c.qop = strings.Trim(r[1], qs)
case "charset":
if strings.ToUpper(strings.Trim(r[1], qs)) != "UTF-8" {
return nil, ErrDigestCharset
}
case "userhash":
c.userhash = strings.Trim(r[1], qs)
default:
return nil, ErrDigestBadChallenge
}
}
return c, nil
}

type credentials struct {
username string
userhash string
realm string
nonce string
digestURI string
algorithm string
sessionAlg bool
cNonce string
opaque string
messageQop string
nc int
method string
password string
}

func (c *credentials) authorize() (string, error) {
if _, ok := hashFuncs[c.algorithm]; !ok {
return "", ErrDigestAlgNotSupported
}

if err := c.validateQop(); err != nil {
return "", err
}

resp, err := c.resp()
if err != nil {
return "", err
}

sl := make([]string, 0, 10)
if c.userhash == "true" {
// RFC 7616 3.4.4
c.username = c.h(fmt.Sprintf("%s:%s", c.username, c.realm))
sl = append(sl, fmt.Sprintf(`userhash=%s`, c.userhash))
}
sl = append(sl, fmt.Sprintf(`username="%s"`, c.username))
sl = append(sl, fmt.Sprintf(`realm="%s"`, c.realm))
sl = append(sl, fmt.Sprintf(`nonce="%s"`, c.nonce))
sl = append(sl, fmt.Sprintf(`uri="%s"`, c.digestURI))
sl = append(sl, fmt.Sprintf(`response="%s"`, resp))
sl = append(sl, fmt.Sprintf(`algorithm=%s`, c.algorithm))
if c.opaque != "" {
sl = append(sl, fmt.Sprintf(`opaque="%s"`, c.opaque))
}
if c.messageQop != "" {
sl = append(sl, fmt.Sprintf("qop=%s", c.messageQop))
sl = append(sl, fmt.Sprintf("nc=%08x", c.nc))
sl = append(sl, fmt.Sprintf(`cnonce="%s"`, c.cNonce))
}

return fmt.Sprintf("Digest %s", strings.Join(sl, ", ")), nil
}

func (c *credentials) validateQop() error {
// Currently only supporting auth quality of protection. TODO: add auth-int support
// NOTE: cURL support auth-int qop for requests other than POST and PUT (i.e. w/o body) by hashing an empty string
jeevatkm marked this conversation as resolved.
Show resolved Hide resolved
// is this applicable for resty? see: https://github.com/curl/curl/blob/307b7543ea1e73ab04e062bdbe4b5bb409eaba3a/lib/vauth/digest.c#L774
if c.messageQop == "" {
return ErrDigestNoQop
}
possibleQops := strings.Split(c.messageQop, ", ")
var authSupport bool
for _, qop := range possibleQops {
if qop == "auth" {
authSupport = true
break
}
}
if !authSupport {
return ErrDigestQopNotSupported
}

c.messageQop = "auth"

return nil
}

func (c *credentials) h(data string) string {
hfCtor := hashFuncs[c.algorithm]
hf := hfCtor()
_, _ = hf.Write([]byte(data)) // Hash.Write never returns an error
return fmt.Sprintf("%x", hf.Sum(nil))
}

func (c *credentials) resp() (string, error) {
c.nc++

b := make([]byte, 16)
_, err := io.ReadFull(rand.Reader, b)
if err != nil {
return "", err
}
c.cNonce = fmt.Sprintf("%x", b)[:32]

ha1 := c.ha1()
ha2 := c.ha2()

return c.kd(ha1, fmt.Sprintf("%s:%08x:%s:%s:%s",
c.nonce, c.nc, c.cNonce, c.messageQop, ha2)), nil
}

func (c *credentials) kd(secret, data string) string {
return c.h(fmt.Sprintf("%s:%s", secret, data))
}

// RFC 7616 3.4.2
func (c *credentials) ha1() string {
ret := c.h(fmt.Sprintf("%s:%s:%s", c.username, c.realm, c.password))
if c.sessionAlg {
return c.h(fmt.Sprintf("%s:%s:%s", ret, c.nonce, c.cNonce))
}

return ret
}

// RFC 7616 3.4.3
func (c *credentials) ha2() string {
// currently no auth-int support
return c.h(fmt.Sprintf("%s:%s", c.method, c.digestURI))
}