From 69d0eb9187b6dead8fe84b2423518475e5cc535c Mon Sep 17 00:00:00 2001 From: Yuki Hirasawa <48427044+hirasawayuki@users.noreply.github.com> Date: Wed, 16 Feb 2022 10:15:20 +0900 Subject: [PATCH] Add check for Sec-WebSocket-Key header (#752) * add Sec-WebSocket-Key header verification * add testcase to Sec-WebSocket-Key header verification --- server.go | 4 ++-- util.go | 15 +++++++++++++++ util_test.go | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index 24d53b3..bb33597 100644 --- a/server.go +++ b/server.go @@ -154,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } challengeKey := r.Header.Get("Sec-Websocket-Key") - if challengeKey == "" { - return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") + if !isValidChallengeKey(challengeKey) { + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") } subprotocol := u.selectSubprotocol(r, responseHeader) diff --git a/util.go b/util.go index 7bf2f66..31a5dee 100644 --- a/util.go +++ b/util.go @@ -281,3 +281,18 @@ headers: } return result } + +// isValidChallengeKey checks if the argument meets RFC6455 specification. +func isValidChallengeKey(s string) bool { + // From RFC6455: + // + // A |Sec-WebSocket-Key| header field with a base64-encoded (see + // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in + // length. + + if s == "" { + return false + } + decoded, err := base64.StdEncoding.DecodeString(s) + return err == nil && len(decoded) == 16 +} diff --git a/util_test.go b/util_test.go index af710ba..f14d69a 100644 --- a/util_test.go +++ b/util_test.go @@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) { } } +var isValidChallengeKeyTests = []struct { + key string + ok bool +}{ + {"dGhlIHNhbXBsZSBub25jZQ==", true}, + {"", false}, + {"InvalidKey", false}, + {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false}, +} + +func TestIsValidChallengeKey(t *testing.T) { + for _, tt := range isValidChallengeKeyTests { + ok := isValidChallengeKey(tt.key) + if ok != tt.ok { + t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok) + } + } +} + var parseExtensionTests = []struct { value string extensions []map[string]string