-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
stunner.go
239 lines (210 loc) · 5.84 KB
/
stunner.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stunner
import (
"context"
"fmt"
"net"
"strconv"
"sync"
"time"
"tailscale.com/stun"
)
// Stunner sends a STUN request to several servers and handles a response.
//
// It is designed to be used on a connection owned by other code and so does
// not directly reference a net.Conn of any sort. Instead, the user should
// provide a Send function to send packets, and call Receive when a new
// STUN response is received.
//
// In response, a Stunner will call Endpoint with any endpoints determined
// for the connection. (An endpoint may be reported multiple times if
// multiple servers are provided.)
type Stunner struct {
	// Send sends a packet to the given address.
	// It will typically be a PacketConn.WriteTo method value.
	Send func([]byte, net.Addr) (int, error)

	// Endpoint is called whenever a STUN response is received.
	// The server is the STUN server that replied, endpoint is the ip:port
	// from the STUN response, and d is the duration that the STUN request
	// took on the wire (not including DNS lookup time).
	Endpoint func(server, endpoint string, d time.Duration)

	// Servers lists the STUN servers to contact, each in host:port form.
	// A port of 0 selects the default STUN port, 3478.
	Servers []string

	// Resolver optionally specifies a resolver to use for DNS lookups.
	// If nil, net.DefaultResolver is used.
	Resolver *net.Resolver

	// Logf optionally specifies a log function. If nil, logging is disabled.
	Logf func(format string, args ...interface{})

	// OnlyIPv6 controls whether IPv6 is exclusively used.
	// If false, only IPv4 is used. There is currently no mixed mode.
	OnlyIPv6 bool

	// sessions tracks the state of each server.
	// It's keyed by the STUN server (from the Servers field). Run fills it
	// in before any probe starts; it is treated as read-only afterwards.
	sessions map[string]*session

	// mu guards inFlight, which maps the transaction ID of each request
	// sent but not yet answered to when and where it was sent.
	mu       sync.Mutex
	inFlight map[stun.TxID]request
}
// addTX records that a request with transaction ID tx is being sent to
// server, stamping it with the current time so Receive can later match the
// reply and compute its round-trip duration. The map is created lazily.
func (s *Stunner) addTX(tx stun.TxID, server string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.inFlight == nil {
		s.inFlight = map[stun.TxID]request{}
	}
	s.inFlight[tx] = request{server: server, sent: time.Now()}
}
// removeTX looks up and forgets the pending request with transaction ID tx.
// It reports the request and whether one was found.
func (s *Stunner) removeTX(tx stun.TxID) (r request, ok bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if r, ok = s.inFlight[tx]; ok {
		delete(s.inFlight, tx)
	}
	return r, ok
}
type (
	// request describes one STUN request that has been sent and is
	// awaiting a reply.
	request struct {
		sent   time.Time // when the request was recorded as sent
		server string    // the entry from Stunner.Servers it was sent to
	}

	// session holds the per-server state for one Run.
	session struct {
		// ctx is done once cancel has been called (a matching reply was
		// received) or when the context it was derived from is done.
		ctx    context.Context
		cancel context.CancelFunc
	}
)
// logf forwards a printf-style message to s.Logf, or drops it when no
// logger has been configured.
func (s *Stunner) logf(format string, args ...interface{}) {
	if s.Logf == nil {
		return
	}
	s.Logf(format, args...)
}
// Receive delivers a STUN packet to the stunner.
//
// Packets that are not STUN, fail to parse as a response, or carry a
// transaction ID this Stunner has no record of are logged and dropped.
// For a matched response, Endpoint is called with the replying server, the
// reported ip:port, and the elapsed time since the request was recorded,
// and that server's session is cancelled so it stops being retried.
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
	if !stun.Is(p) {
		s.logf("stunner: received non-STUN packet")
		return
	}
	arrived := time.Now()

	txID, ip, port, err := stun.ParseResponse(p)
	if err != nil {
		s.logf("stunner: received bad STUN response: %v", err)
		return
	}

	req, found := s.removeTX(txID)
	if !found {
		s.logf("stunner: got STUN packet for unknown TxID %x", txID)
		return
	}
	rtt := arrived.Sub(req.sent)

	sess := s.sessions[req.server]
	if sess == nil {
		return
	}
	endpoint := net.JoinHostPort(net.IP(ip).String(), fmt.Sprint(port))
	s.Endpoint(req.server, endpoint, rtt)
	sess.cancel()
}
// resolver returns the DNS resolver to use: s.Resolver when set, otherwise
// net.DefaultResolver.
func (s *Stunner) resolver() *net.Resolver {
	if r := s.Resolver; r != nil {
		return r
	}
	return net.DefaultResolver
}
// Run starts a Stunner and blocks until all servers either respond
// or are tried multiple times and timeout.
//
// Each distinct entry in Servers gets a session whose context is derived
// from ctx. Every session context is cancelled before Run returns, so the
// resources held by sessions for servers that never replied are released.
//
// TODO: this always returns success now. It should return errors
// if certain servers are unavailable probably. Or if all are.
// Or some configured threshold are.
func (s *Stunner) Run(ctx context.Context) error {
	s.sessions = map[string]*session{}
	for _, server := range s.Servers {
		if _, dup := s.sessions[server]; dup {
			// Overwriting would drop the earlier cancel func (a context
			// leak); goroutines for duplicate entries share one session.
			continue
		}
		sctx, cancel := context.WithCancel(ctx)
		s.sessions[server] = &session{
			ctx:    sctx,
			cancel: cancel,
		}
	}
	// Release every session context once probing is over. Receive may
	// still cancel a session on a late reply; repeated cancels are no-ops.
	defer func() {
		for _, sess := range s.sessions {
			sess.cancel()
		}
	}()

	// after this point, the s.sessions map is read-only
	var wg sync.WaitGroup
	for _, server := range s.Servers {
		wg.Add(1)
		go func(server string) {
			defer wg.Done()
			s.runServer(ctx, server)
		}(server)
	}
	wg.Wait()
	return nil
}
// runServer probes one STUN server. It sends a request and waits up to each
// successive duration in retryDurations for a reply. It stops when the
// server's session is done (a reply was matched by Receive), when ctx is
// cancelled, or when the schedule is exhausted.
func (s *Stunner) runServer(ctx context.Context, server string) {
	session := s.sessions[server]
	for i, d := range retryDurations {
		// Once the caller has cancelled, every further attempt would time
		// out immediately yet still call sendSTUN (which may send a packet
		// if resolution doesn't fail on the cancelled context), so stop.
		if ctx.Err() != nil {
			return
		}
		tctx, cancel := context.WithTimeout(ctx, d)
		if err := s.sendSTUN(tctx, server); err != nil {
			s.logf("stunner: %s: %v", server, err)
		}
		select {
		case <-tctx.Done():
			cancel()
		case <-session.ctx.Done():
			cancel()
			// session.ctx is derived from ctx, so it is also done when the
			// caller cancels; only report slowness for a genuine reply.
			if i > 0 && ctx.Err() == nil {
				s.logf("stunner: slow STUN response from %s: %d retries", server, i)
			}
			return
		}
	}
	s.logf("stunner: no STUN response from %s", server)
}
// sendSTUN sends one STUN request to server.
//
// server is parsed as host:port; a port of 0 selects the default STUN port
// 3478. The host is resolved with ctx via s.resolver(). The first resolved
// address of the configured family is used: IPv4 by default, IPv6 when
// OnlyIPv6 is set. The request's transaction ID is registered with addTX
// before sending, so a fast reply cannot outrun the registration. If the
// send fails, the registration is removed again.
func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
	host, port, err := net.SplitHostPort(server)
	if err != nil {
		return err
	}
	addrPort, err := strconv.Atoi(port)
	if err != nil {
		return fmt.Errorf("port: %v", err)
	}
	if addrPort == 0 {
		addrPort = 3478
	}
	addr := &net.UDPAddr{Port: addrPort}

	ipAddrs, err := s.resolver().LookupIPAddr(ctx, host)
	if err != nil {
		return fmt.Errorf("lookup ip addr: %v", err)
	}
	for _, ipAddr := range ipAddrs {
		ip4 := ipAddr.IP.To4()
		isV4 := ip4 != nil
		if isV4 == s.OnlyIPv6 {
			continue // not the address family we want
		}
		if isV4 {
			addr.IP = ip4
		} else {
			addr.IP = ipAddr.IP
			addr.Zone = ipAddr.Zone
		}
		// Take the first match for both families (previously the IPv6
		// path kept going and ended up on the last address).
		break
	}
	if addr.IP == nil {
		if s.OnlyIPv6 {
			return fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs)
		}
		return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
	}

	txID := stun.NewTxID()
	req := stun.Request(txID)
	s.addTX(txID, server)
	if _, err := s.Send(req, addr); err != nil {
		// Nothing will answer a request that was never sent; drop it so
		// inFlight does not accumulate dead entries.
		s.removeTX(txID)
		return fmt.Errorf("send: %v", err)
	}
	return nil
}
// retryDurations is the per-attempt schedule used by runServer. Each entry
// is how long one STUN request may take (resolution plus waiting for a
// reply) before the request is sent again. Three quick 100ms attempts are
// followed by a roughly doubling backoff up to 3.2s, for at most 6.7s in total.
var retryDurations = []time.Duration{
	100 * time.Millisecond, 100 * time.Millisecond, 100 * time.Millisecond,
	200 * time.Millisecond, 200 * time.Millisecond,
	400 * time.Millisecond,
	800 * time.Millisecond,
	1600 * time.Millisecond,
	3200 * time.Millisecond,
}