Skip to content

Commit

Permalink
Refine TCPMux memory usage
Browse files Browse the repository at this point in the history
Reduce read buf of first stun message.
Add timeout to read first message to clean interrupted
connection earlier.
Add alive duration for gather to access connection created
from stun bind, avoid connection leak from malicious client.
  • Loading branch information
cnderrauber committed Feb 22, 2024
1 parent 94e5867 commit 3f14618
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 16 deletions.
63 changes: 51 additions & 12 deletions tcp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net"
"strings"
"sync"
"time"

"github.com/pion/logging"
"github.com/pion/stun"
Expand Down Expand Up @@ -52,6 +53,16 @@ type TCPMuxParams struct {
// if the write buffer is full, the subsequent write packet will be dropped until it has enough space.
// a default 4MB is recommended.
WriteBufferSize int

// A new established connection will be removed if the first STUN binding request is not received within this timeout,
// avoiding the client with bad network or attacker to create a lot of empty connections.
// Default 30s timeout will be used if not set.
FirstStunBindTimeout time.Duration

// TCPMux will create connection from STUN binding request with an unknown username, if
// the connection is not used in the timeout, it will be removed to avoid resource leak / attack.
// Default 30s timeout will be used if not set.
AliveDurationForConnFromStun time.Duration
}

// NewTCPMuxDefault creates a new instance of TCPMuxDefault.
Expand All @@ -60,6 +71,14 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}

if params.FirstStunBindTimeout == 0 {
params.FirstStunBindTimeout = 30 * time.Second
}

if params.AliveDurationForConnFromStun == 0 {
params.AliveDurationForConnFromStun = 30 * time.Second
}

m := &TCPMuxDefault{
params: &params,

Expand Down Expand Up @@ -110,25 +129,32 @@ func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP)
}

if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
conn.ClearAliveTimer()
return conn, nil
}

return m.createConn(ufrag, isIPv6, local)
return m.createConn(ufrag, isIPv6, local, false)
}

func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP) (*tcpPacketConn, error) {
func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) {
addr, ok := m.LocalAddr().(*net.TCPAddr)
if !ok {
return nil, ErrGetTransportAddress
}
localAddr := *addr
localAddr.IP = local

var alive time.Duration
if fromStun {
alive = m.params.AliveDurationForConnFromStun
}

conn := newTCPPacketConn(tcpPacketParams{
ReadBuffer: m.params.ReadBufferSize,
WriteBuffer: m.params.WriteBufferSize,
LocalAddr: &localAddr,
Logger: m.params.Logger,
ReadBuffer: m.params.ReadBufferSize,
WriteBuffer: m.params.WriteBufferSize,
LocalAddr: &localAddr,
Logger: m.params.Logger,
AliveDuration: alive,
})

var conns map[ipAddr]*tcpPacketConn
Expand Down Expand Up @@ -163,13 +189,26 @@ func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
}

func (m *TCPMuxDefault) handleConn(conn net.Conn) {
buf := make([]byte, receiveMTU)
buf := make([]byte, 512)

if m.params.FirstStunBindTimeout > 0 {
if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil {
m.params.Logger.Warnf("Failed to set read deadline for first STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
}
}
n, err := readStreamingPacket(conn, buf)
if err != nil {
m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr().String(), err)
if errors.Is(err, io.ErrShortBuffer) {
m.params.Logger.Warnf("Buffer too small for first packet from %s: %s", conn.RemoteAddr(), err)
} else {
m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err)
}
m.closeAndLogError(conn)
return
}
if err = conn.SetReadDeadline(time.Time{}); err != nil {
m.params.Logger.Warnf("Failed to reset read deadline from %s: %s", conn.RemoteAddr(), err)
}

buf = buf[:n]

Expand Down Expand Up @@ -204,9 +243,6 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
ufrag := strings.Split(string(attr), ":")[0]
m.params.Logger.Debugf("Ufrag: %s", ufrag)

m.mu.Lock()
defer m.mu.Unlock()

host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
m.closeAndLogError(conn)
Expand All @@ -222,15 +258,18 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
return
}
m.mu.Lock()
packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP)
if !ok {
packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP)
packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true)
if err != nil {
m.mu.Unlock()
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
return
}
}
m.mu.Unlock()

if err := packetConn.AddConn(conn, buf); err != nil {
m.closeAndLogError(conn)
Expand Down
141 changes: 141 additions & 0 deletions tcp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package ice
import (
"io"
"net"
"os"
"testing"
"time"

"github.com/pion/logging"
"github.com/pion/stun"
Expand Down Expand Up @@ -108,6 +110,10 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
ReadBufferSize: 20,
})

defer func() {
_ = tcpMux.Close()
}()

_, err = tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP)
require.NoError(t, err, "error getting conn by ufrag")

Expand All @@ -117,3 +123,138 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
assert.Nil(t, conn, "should receive nil because mux is closed")
assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
}

func TestTCPMux_FirstPacketTimeout(t *testing.T) {
report := test.CheckRoutines(t)
defer report()

loggerFactory := logging.NewDefaultLoggerFactory()

listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 0,
})
require.NoError(t, err, "error starting listener")
defer func() {
_ = listener.Close()
}()

tcpMux := NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
FirstStunBindTimeout: time.Second,
})

require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")

conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
}()

// Don't send any data, the mux should close the connection after the timeout
time.Sleep(1500 * time.Millisecond)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
buf := make([]byte, 1)
_, err = conn.Read(buf)
require.ErrorIs(t, err, io.EOF)
}

func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) {
report := test.CheckRoutines(t)
defer report()

loggerFactory := logging.NewDefaultLoggerFactory()

listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 0,
})
require.NoError(t, err, "error starting listener")
defer func() {
_ = listener.Close()
}()

tcpMux := NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
AliveDurationForConnFromStun: time.Second,
})

defer func() {
_ = tcpMux.Close()
}()

require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")

t.Run("close connection from stun msg after timeout", func(t *testing.T) {
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
}()

msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername("myufrag:otherufrag"),
stun.NewShortTermIntegrity("myufrag"),
stun.Fingerprint,
)
require.NoError(t, err, "error building STUN packet")
msg.Encode()

_, err = writeStreamingPacket(conn, msg.Raw)
require.NoError(t, err, "error writing TCP STUN packet")

time.Sleep(1500 * time.Millisecond)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
buf := make([]byte, 1)
_, err = conn.Read(buf)
require.ErrorIs(t, err, io.EOF)
})

t.Run("connection keep alive if access by user", func(t *testing.T) {
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
}()

msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername("myufrag2:otherufrag2"),
stun.NewShortTermIntegrity("myufrag2"),
stun.Fingerprint,
)
require.NoError(t, err, "error building STUN packet")
msg.Encode()

n, err := writeStreamingPacket(conn, msg.Raw)
require.NoError(t, err, "error writing TCP STUN packet")

// wait for the connection to be created
time.Sleep(100 * time.Millisecond)

pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listener.Addr().(*net.TCPAddr).IP)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
}()

time.Sleep(1500 * time.Millisecond)

// timeout, not closed
buf := make([]byte, 1024)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(100*time.Millisecond)))
_, err = conn.Read(buf)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)

recv := make([]byte, n)
n2, rAddr, err := pktConn.ReadFrom(recv)
require.NoError(t, err, "error receiving data")
assert.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch")
assert.Equal(t, n, n2, "received byte size mismatch")
assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
})
}
28 changes: 24 additions & 4 deletions tcp_packet_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type tcpPacketConn struct {
wg sync.WaitGroup
closedChan chan struct{}
closeOnce sync.Once
aliveTimer *time.Timer
}

type streamingPacket struct {
Expand All @@ -94,10 +95,11 @@ type streamingPacket struct {
}

type tcpPacketParams struct {
ReadBuffer int
LocalAddr net.Addr
Logger logging.LeveledLogger
WriteBuffer int
ReadBuffer int
LocalAddr net.Addr
Logger logging.LeveledLogger
WriteBuffer int
AliveDuration time.Duration
}

func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
Expand All @@ -110,9 +112,24 @@ func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
closedChan: make(chan struct{}),
}

if params.AliveDuration > 0 {
p.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
p.params.Logger.Warn("close tcp packet conn by alive timeout")
_ = p.Close()
})
}

return p
}

func (t *tcpPacketConn) ClearAliveTimer() {
t.mu.Lock()
if t.aliveTimer != nil {
t.aliveTimer.Stop()
}
t.mu.Unlock()
}

func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
t.params.Logger.Infof("Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr())

Expand Down Expand Up @@ -261,6 +278,9 @@ func (t *tcpPacketConn) Close() error {
t.closeOnce.Do(func() {
close(t.closedChan)
shouldCloseRecvChan = true
if t.aliveTimer != nil {
t.aliveTimer.Stop()
}
})

for _, conn := range t.conns {
Expand Down

0 comments on commit 3f14618

Please sign in to comment.