Skip to content

Commit

Permalink
Add check for Sec-WebSocket-Key header (#752)
Browse files Browse the repository at this point in the history
* add Sec-WebSocket-Key header verification

* add testcase to Sec-WebSocket-Key header verification
  • Loading branch information
hirasawayuki committed Feb 16, 2022
1 parent 9111bb8 commit 69d0eb9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
4 changes: 2 additions & 2 deletions server.go
Expand Up @@ -154,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}

challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
if !isValidChallengeKey(challengeKey) {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
}

subprotocol := u.selectSubprotocol(r, responseHeader)
Expand Down
15 changes: 15 additions & 0 deletions util.go
Expand Up @@ -281,3 +281,18 @@ headers:
}
return result
}

// isValidChallengeKey checks if the argument meets RFC6455 specification.
func isValidChallengeKey(s string) bool {
// From RFC6455:
//
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
// length.

if s == "" {
return false
}
decoded, err := base64.StdEncoding.DecodeString(s)
return err == nil && len(decoded) == 16
}
19 changes: 19 additions & 0 deletions util_test.go
Expand Up @@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
}
}

var isValidChallengeKeyTests = []struct {
key string
ok bool
}{
{"dGhlIHNhbXBsZSBub25jZQ==", true},
{"", false},
{"InvalidKey", false},
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
}

func TestIsValidChallengeKey(t *testing.T) {
for _, tt := range isValidChallengeKeyTests {
ok := isValidChallengeKey(tt.key)
if ok != tt.ok {
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
}
}
}

var parseExtensionTests = []struct {
value string
extensions []map[string]string
Expand Down

0 comments on commit 69d0eb9

Please sign in to comment.