Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/gorilla/websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio VS committed Feb 18, 2022
2 parents fa3c588 + 69d0eb9 commit 2f8e79d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
4 changes: 2 additions & 2 deletions server.go
Expand Up @@ -154,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}

challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
if !isValidChallengeKey(challengeKey) {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
}

subprotocol := u.selectSubprotocol(r, responseHeader)
Expand Down
15 changes: 15 additions & 0 deletions util.go
Expand Up @@ -299,6 +299,21 @@ headers:
return result
}

// isValidChallengeKey checks if the argument meets RFC6455 specification.
func isValidChallengeKey(s string) bool {
// From RFC6455:
//
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
// length.

if s == "" {
return false
}
decoded, err := base64.StdEncoding.DecodeString(s)
return err == nil && len(decoded) == 16
}

// parseDataHeader returns a list with values if header value is comma-separated
func parseDataHeader(headerValue []byte) [][]byte {
h := bytes.TrimSpace(headerValue)
Expand Down
19 changes: 19 additions & 0 deletions util_test.go
Expand Up @@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
}
}

var isValidChallengeKeyTests = []struct {
key string
ok bool
}{
{"dGhlIHNhbXBsZSBub25jZQ==", true},
{"", false},
{"InvalidKey", false},
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
}

func TestIsValidChallengeKey(t *testing.T) {
for _, tt := range isValidChallengeKeyTests {
ok := isValidChallengeKey(tt.key)
if ok != tt.ok {
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
}
}
}

var parseExtensionTests = []struct {
value string
extensions []map[string]string
Expand Down

0 comments on commit 2f8e79d

Please sign in to comment.