Skip to content

Commit

Permalink
Add option to include loopback candidate
Browse files Browse the repository at this point in the history
Add option to include loopback candidate
  • Loading branch information
cnderrauber committed Nov 22, 2022
1 parent 7f13fd1 commit e90a58e
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 14 deletions.
3 changes: 3 additions & 0 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ type Agent struct {

interfaceFilter func(string) bool
ipFilter func(net.IP) bool
includeLoopback bool

insecureSkipVerify bool

Expand Down Expand Up @@ -317,6 +318,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
ipFilter: config.IPFilter,

insecureSkipVerify: config.InsecureSkipVerify,

includeLoopback: config.IncludeLoopback,
}

a.tcpMux = config.TCPMux
Expand Down
5 changes: 4 additions & 1 deletion agent_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,11 @@ type AgentConfig struct {
// dial interface in order to support corporate proxies
ProxyDialer proxy.Dialer

// Accept aggressive nomination in RFC 5245 for compatible with chrome and other browsers
// Deprecated: AcceptAggressiveNomination always enabled.
AcceptAggressiveNomination bool

// Include loopback addresses in the candidate list.
IncludeLoopback bool
}

// initWithDefaults populates an agent and falls back to defaults if fields are unset
Expand Down
2 changes: 1 addition & 1 deletion gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
delete(networks, udp)
}

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
if err != nil {
a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err)
return
Expand Down
85 changes: 84 additions & 1 deletion gather_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sort"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -31,7 +32,7 @@ func TestListenUDP(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
assert.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test")
assert.NoError(t, err)

Expand Down Expand Up @@ -86,6 +87,88 @@ func TestListenUDP(t *testing.T) {
assert.NoError(t, a.Close())
}

func TestLoopbackCandidate(t *testing.T) {
report := test.CheckRoutines(t)
defer report()

lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
type testCase struct {
name string
agentConfig *AgentConfig
loExpected bool
}
mux, err := NewMultiUDPMuxFromPort(12500)
assert.NoError(t, err)
muxWithLo, errlo := NewMultiUDPMuxFromPort(12501, UDPMuxFromPortWithLoopback())
assert.NoError(t, errlo)
testCases := []testCase{
{
name: "mux should not have loopback candidate",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
UDPMux: mux,
},
loExpected: false,
},
{
name: "mux with loopback should not have loopback candidate",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
UDPMux: muxWithLo,
},
loExpected: true,
},
{
name: "includeloopback enabled",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: true,
},
loExpected: true,
},
{
name: "includeloopback disabled",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: false,
},
loExpected: false,
},
}

for _, tc := range testCases {
tcase := tc
t.Run(tcase.name, func(t *testing.T) {
a, err := NewAgent(tc.agentConfig)
assert.NoError(t, err)

candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
var loopback int32
assert.NoError(t, a.OnCandidate(func(c Candidate) {
if c != nil {
if net.ParseIP(c.Address()).IsLoopback() {
atomic.StoreInt32(&loopback, 1)
}
} else {
candidateGatheredFunc()
return
}
t.Log(c.NetworkType(), c.Priority(), c)
}))
assert.NoError(t, a.GatherCandidates())

<-candidateGathered.Done()

assert.NoError(t, a.Close())
assert.Equal(t, tcase.loExpected, atomic.LoadInt32(&loopback) == 1)
})
}

assert.NoError(t, mux.Close())
assert.NoError(t, muxWithLo.Close())
}

// Assert that STUN gathering is done concurrently
func TestSTUNConcurrency(t *testing.T) {
report := test.CheckRoutines(t)
Expand Down
12 changes: 6 additions & 6 deletions gather_vnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestVNetGather(t *testing.T) {
})
assert.NoError(t, err)

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) > 0 {
t.Fatal("should return no local IP")
} else if err != nil {
Expand Down Expand Up @@ -69,7 +69,7 @@ func TestVNetGather(t *testing.T) {
})
assert.NoError(t, err)

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) == 0 {
t.Fatal("should have one local IP")
} else if err != nil {
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to create agent: %s", err)
}

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) == 0 {
t.Fatal("localInterfaces found no interfaces, unable to test")
} else if err != nil {
Expand Down Expand Up @@ -385,7 +385,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
assert.NoError(t, err)

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil {
t.Fatal(err)
} else if len(localIPs) != 0 {
Expand All @@ -405,7 +405,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
assert.NoError(t, err)

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil {
t.Fatal(err)
} else if len(localIPs) != 0 {
Expand All @@ -425,7 +425,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
assert.NoError(t, err)

localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil {
t.Fatal(err)
} else if len(localIPs) == 0 {
Expand Down
2 changes: 1 addition & 1 deletion udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
}
if len(networks) > 0 {
muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, nil, nil, networks)
ips, err := localInterfaces(muxNet, nil, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
Expand Down
12 changes: 11 additions & 1 deletion udp_mux_multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
opt.apply(&params)
}
muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks)
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -130,6 +130,7 @@ type multiUDPMuxFromPortParam struct {
readBufferSize int
writeBufferSize int
logger logging.LeveledLogger
includeLoopback bool
}

type udpMuxFromPortOption struct {
Expand Down Expand Up @@ -193,3 +194,12 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
},
}
}

// UDPMuxFromPortWithLoopback set loopback interface should be included
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.includeLoopback = true
},
}
}
6 changes: 3 additions & 3 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func stunRequest(read func([]byte) (int, error), write func([]byte) (int, error)
return res, nil
}

func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType) ([]net.IP, error) { //nolint:gocognit
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{}
ifaces, err := vnet.Interfaces()
if err != nil {
Expand All @@ -154,7 +154,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface
}

Expand All @@ -175,7 +175,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
case *net.IPAddr:
ip = addr.IP
}
if ip == nil || ip.IsLoopback() {
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue
}

Expand Down

0 comments on commit e90a58e

Please sign in to comment.