-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Each side of a QUIC connection chooses the connection IDs used by its peer. In our case, we use 8-byte random IDs. A connection has a list of connection IDs that it may receive packets on, and a list that it may send packets to. Add a minimal data structure for tracking these lists, and handling of the connection IDs tracked across Initial and Handshake packets. This does not yet handle post-handshake connection ID changes made in NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames. RFC 9000, Section 5.1. For golang/go#58547 Change-Id: I3e059393cacafbcea04a1b4131c0c7dc28acad5e Reviewed-on: https://go-review.googlesource.com/c/net/+/506675 Run-TryBot: Damien Neil <dneil@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> TryBot-Result: Gopher Robot <gobot@golang.org>
- Loading branch information
Showing
2 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
// Copyright 2023 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
//go:build go1.21 | ||
|
||
package quic | ||
|
||
import ( | ||
"crypto/rand" | ||
) | ||
|
||
// connIDState is a conn's connection IDs. | ||
type connIDState struct { | ||
// The destination connection IDs of packets we receive are local. | ||
// The destination connection IDs of packets we send are remote. | ||
// | ||
// Local IDs are usually issued by us, and remote IDs by the peer. | ||
// The exception is the transient destination connection ID sent in | ||
// a client's Initial packets, which is chosen by the client. | ||
local []connID | ||
remote []connID | ||
} | ||
|
||
// A connID is a connection ID and associated metadata. | ||
type connID struct { | ||
// cid is the connection ID itself. | ||
cid []byte | ||
|
||
// seq is the connection ID's sequence number: | ||
// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1 | ||
// | ||
// For the transient destination ID in a client's Initial packet, this is -1. | ||
seq int64 | ||
} | ||
|
||
func (s *connIDState) initClient(newID newConnIDFunc) error { | ||
// Client chooses its initial connection ID, and sends it | ||
// in the Source Connection ID field of the first Initial packet. | ||
locid, err := newID() | ||
if err != nil { | ||
return err | ||
} | ||
s.local = append(s.local, connID{ | ||
seq: 0, | ||
cid: locid, | ||
}) | ||
|
||
// Client chooses an initial, transient connection ID for the server, | ||
// and sends it in the Destination Connection ID field of the first Initial packet. | ||
remid, err := newID() | ||
if err != nil { | ||
return err | ||
} | ||
s.remote = append(s.remote, connID{ | ||
seq: -1, | ||
cid: remid, | ||
}) | ||
return nil | ||
} | ||
|
||
func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error { | ||
// Client-chosen, transient connection ID received in the first Initial packet. | ||
// The server will not use this as the Source Connection ID of packets it sends, | ||
// but remembers it because it may receive packets sent to this destination. | ||
s.local = append(s.local, connID{ | ||
seq: -1, | ||
cid: cloneBytes(dstConnID), | ||
}) | ||
|
||
// Server chooses a connection ID, and sends it in the Source Connection ID of | ||
// the response to the clent. | ||
locid, err := newID() | ||
if err != nil { | ||
return err | ||
} | ||
s.local = append(s.local, connID{ | ||
seq: 0, | ||
cid: locid, | ||
}) | ||
return nil | ||
} | ||
|
||
// srcConnID is the Source Connection ID to use in a sent packet. | ||
func (s *connIDState) srcConnID() []byte { | ||
if s.local[0].seq == -1 && len(s.local) > 1 { | ||
// Don't use the transient connection ID if another is available. | ||
return s.local[1].cid | ||
} | ||
return s.local[0].cid | ||
} | ||
|
||
// dstConnID is the Destination Connection ID to use in a sent packet. | ||
func (s *connIDState) dstConnID() []byte { | ||
return s.remote[0].cid | ||
} | ||
|
||
// handlePacket updates the connection ID state during the handshake | ||
// (Initial and Handshake packets). | ||
func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []byte) { | ||
switch { | ||
case ptype == packetTypeInitial && side == clientSide: | ||
if len(s.remote) == 1 && s.remote[0].seq == -1 { | ||
// We're a client connection processing the first Initial packet | ||
// from the server. Replace the transient remote connection ID | ||
// with the Source Connection ID from the packet. | ||
s.remote[0] = connID{ | ||
seq: 0, | ||
cid: cloneBytes(srcConnID), | ||
} | ||
} | ||
case ptype == packetTypeInitial && side == serverSide: | ||
if len(s.remote) == 0 { | ||
// We're a server connection processing the first Initial packet | ||
// from the client. Set the client's connection ID. | ||
s.remote = append(s.remote, connID{ | ||
seq: 0, | ||
cid: cloneBytes(srcConnID), | ||
}) | ||
} | ||
case ptype == packetTypeHandshake && side == serverSide: | ||
if len(s.local) > 0 && s.local[0].seq == -1 { | ||
// We're a server connection processing the first Handshake packet from | ||
// the client. Discard the transient, client-chosen connection ID used | ||
// for Initial packets; the client will never send it again. | ||
s.local = append(s.local[:0], s.local[1:]...) | ||
} | ||
} | ||
} | ||
|
||
func cloneBytes(b []byte) []byte { | ||
n := make([]byte, len(b)) | ||
copy(n, b) | ||
return n | ||
} | ||
|
||
type newConnIDFunc func() ([]byte, error) | ||
|
||
func newRandomConnID() ([]byte, error) { | ||
// It is not necessary for connection IDs to be cryptographically secure, | ||
// but it doesn't hurt. | ||
id := make([]byte, connIDLen) | ||
if _, err := rand.Read(id); err != nil { | ||
return nil, err | ||
} | ||
return id, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// Copyright 2023 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
//go:build go1.21 | ||
|
||
package quic | ||
|
||
import ( | ||
"fmt" | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
func TestConnIDClientHandshake(t *testing.T) { | ||
// On initialization, the client chooses local and remote IDs. | ||
// | ||
// The order in which we allocate the two isn't actually important, | ||
// but test is a lot simpler if we assume. | ||
var s connIDState | ||
s.initClient(newConnIDSequence()) | ||
if got, want := string(s.srcConnID()), "local-1"; got != want { | ||
t.Errorf("after initClient: srcConnID = %q, want %q", got, want) | ||
} | ||
if got, want := string(s.dstConnID()), "local-2"; got != want { | ||
t.Errorf("after initClient: dstConnID = %q, want %q", got, want) | ||
} | ||
|
||
// The server's first Initial packet provides the client with a | ||
// non-transient remote connection ID. | ||
s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1")) | ||
if got, want := string(s.dstConnID()), "remote-1"; got != want { | ||
t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want) | ||
} | ||
|
||
wantLocal := []connID{{ | ||
cid: []byte("local-1"), | ||
seq: 0, | ||
}} | ||
if !reflect.DeepEqual(s.local, wantLocal) { | ||
t.Errorf("local ids: %v, want %v", s.local, wantLocal) | ||
} | ||
wantRemote := []connID{{ | ||
cid: []byte("remote-1"), | ||
seq: 0, | ||
}} | ||
if !reflect.DeepEqual(s.remote, wantRemote) { | ||
t.Errorf("remote ids: %v, want %v", s.remote, wantRemote) | ||
} | ||
} | ||
|
||
func TestConnIDServerHandshake(t *testing.T) { | ||
// On initialization, the server is provided with the client-chosen | ||
// transient connection ID, and allocates an ID of its own. | ||
// The Initial packet sets the remote connection ID. | ||
var s connIDState | ||
s.initServer(newConnIDSequence(), []byte("transient")) | ||
s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1")) | ||
if got, want := string(s.srcConnID()), "local-1"; got != want { | ||
t.Errorf("after initClient: srcConnID = %q, want %q", got, want) | ||
} | ||
if got, want := string(s.dstConnID()), "remote-1"; got != want { | ||
t.Errorf("after initClient: dstConnID = %q, want %q", got, want) | ||
} | ||
|
||
wantLocal := []connID{{ | ||
cid: []byte("transient"), | ||
seq: -1, | ||
}, { | ||
cid: []byte("local-1"), | ||
seq: 0, | ||
}} | ||
if !reflect.DeepEqual(s.local, wantLocal) { | ||
t.Errorf("local ids: %v, want %v", s.local, wantLocal) | ||
} | ||
wantRemote := []connID{{ | ||
cid: []byte("remote-1"), | ||
seq: 0, | ||
}} | ||
if !reflect.DeepEqual(s.remote, wantRemote) { | ||
t.Errorf("remote ids: %v, want %v", s.remote, wantRemote) | ||
} | ||
|
||
// The client's first Handshake packet permits the server to discard the | ||
// transient connection ID. | ||
s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1")) | ||
wantLocal = []connID{{ | ||
cid: []byte("local-1"), | ||
seq: 0, | ||
}} | ||
if !reflect.DeepEqual(s.local, wantLocal) { | ||
t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal) | ||
} | ||
} | ||
|
||
func newConnIDSequence() newConnIDFunc { | ||
var n uint64 | ||
return func() ([]byte, error) { | ||
n++ | ||
return []byte(fmt.Sprintf("local-%v", n)), nil | ||
} | ||
} | ||
|
||
func TestNewRandomConnID(t *testing.T) { | ||
cid, err := newRandomConnID() | ||
if len(cid) != connIDLen || err != nil { | ||
t.Fatalf("newConnID() = %x, %v; want %v bytes", cid, connIDLen, err) | ||
} | ||
} |