Skip to content

Commit

Permalink
Add AcceptStreamTimeout to Config
Browse files Browse the repository at this point in the history
By setting the AcceptStreamTimeout value in
the config parameter when calling the NewSessionSRTP
or NewSessionSRTCP method, the nextConn.SetReadDeadline
is set in the session.start method to prevent the
AcceptStream method from waiting indefinitely.

When AcceptStream is notified via newStream, the
nextConn.SetReadDeadline is released.
  • Loading branch information
lolgopher authored and Sean-Der committed May 3, 2023
1 parent f11f66a commit 6ca78a7
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 33 deletions.
2 changes: 1 addition & 1 deletion keying_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type mockKeyingMaterialExporter struct {
exported []byte
}

func (m *mockKeyingMaterialExporter) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
func (m *mockKeyingMaterialExporter) ExportKeyingMaterial(label string, _ []byte, length int) ([]byte, error) {
if label != labelExtractorDtlsSrtp {
return nil, fmt.Errorf("%w: expected(%s) actual(%s)", errExporterWrongLabel, label, labelExtractorDtlsSrtp)
}
Expand Down
17 changes: 12 additions & 5 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net"
"sync"
"time"

"github.com/pion/logging"
"github.com/pion/transport/v2/packetio"
Expand All @@ -21,7 +22,8 @@ type session struct {
localContext, remoteContext *Context
localOptions, remoteOptions []ContextOption

newStream chan readStream
newStream chan readStream
acceptStreamTimeout time.Time

started chan interface{}
closed chan interface{}
Expand All @@ -41,10 +43,11 @@ type session struct {
// or directly pass the keys themselves.
// After a Config is passed to a session it must not be modified.
type Config struct {
Keys SessionKeys
Profile ProtectionProfile
BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
LoggerFactory logging.LoggerFactory
Keys SessionKeys
Profile ProtectionProfile
BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
LoggerFactory logging.LoggerFactory
AcceptStreamTimeout time.Time

// List of local/remote context options.
// ReplayProtection is enabled on remote context by default.
Expand Down Expand Up @@ -118,6 +121,10 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote
return err
}

if err = s.nextConn.SetReadDeadline(s.acceptStreamTimeout); err != nil {
return err
}

go func() {
defer func() {
close(s.newStream)
Expand Down
22 changes: 13 additions & 9 deletions session_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,16 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n

s := &SessionSRTCP{
session: session{
nextConn: conn,
localOptions: localOpts,
remoteOptions: remoteOpts,
readStreams: map[uint32]readStream{},
newStream: make(chan readStream),
started: make(chan interface{}),
closed: make(chan interface{}),
bufferFactory: config.BufferFactory,
log: loggerFactory.NewLogger("srtp"),
nextConn: conn,
localOptions: localOpts,
remoteOptions: remoteOpts,
readStreams: map[uint32]readStream{},
newStream: make(chan readStream),
acceptStreamTimeout: config.AcceptStreamTimeout,
started: make(chan interface{}),
closed: make(chan interface{}),
bufferFactory: config.BufferFactory,
log: loggerFactory.NewLogger("srtp"),
},
}
s.writeStream = &WriteStreamSRTCP{s}
Expand Down Expand Up @@ -165,6 +166,9 @@ func (s *SessionSRTCP) decrypt(buf []byte) error {
if r == nil {
return nil // Session has been closed
} else if isNew {
if !s.session.acceptStreamTimeout.IsZero() {
_ = s.session.nextConn.SetReadDeadline(time.Time{})
}
s.session.newStream <- r // Notify AcceptStream
}

Expand Down
36 changes: 36 additions & 0 deletions session_srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,42 @@ func TestSessionSRTCPReplayProtection(t *testing.T) {
}
}

// nolint: dupl
func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) {
lim := test.TimeOut(time.Second * 5)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

pipe, _ := net.Pipe()
config := &Config{
Profile: ProtectionProfileAes128CmHmacSha1_80,
Keys: SessionKeys{
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
},
AcceptStreamTimeout: time.Now().Add(3 * time.Second),
}

newSession, err := NewSessionSRTCP(pipe, config)
if err != nil {
t.Fatal(err)
} else if newSession == nil {
t.Fatal("NewSessionSRTCP did not error, but returned nil session")
}

if _, _, err = newSession.AcceptStream(); err == nil || !errors.Is(err, errStreamAlreadyClosed) {
t.Fatal(err)
}

if err = newSession.Close(); err != nil {
t.Fatal(err)
}
}

func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) {
authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.rtcpAuthTagLen()
if err != nil {
Expand Down
22 changes: 13 additions & 9 deletions session_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol

s := &SessionSRTP{
session: session{
nextConn: conn,
localOptions: localOpts,
remoteOptions: remoteOpts,
readStreams: map[uint32]readStream{},
newStream: make(chan readStream),
started: make(chan interface{}),
closed: make(chan interface{}),
bufferFactory: config.BufferFactory,
log: loggerFactory.NewLogger("srtp"),
nextConn: conn,
localOptions: localOpts,
remoteOptions: remoteOpts,
readStreams: map[uint32]readStream{},
newStream: make(chan readStream),
acceptStreamTimeout: config.AcceptStreamTimeout,
started: make(chan interface{}),
closed: make(chan interface{}),
bufferFactory: config.BufferFactory,
log: loggerFactory.NewLogger("srtp"),
},
}
s.writeStream = &WriteStreamSRTP{s}
Expand Down Expand Up @@ -171,6 +172,9 @@ func (s *SessionSRTP) decrypt(buf []byte) error {
if r == nil {
return nil // Session has been closed
} else if isNew {
if !s.session.acceptStreamTimeout.IsZero() {
_ = s.session.nextConn.SetReadDeadline(time.Time{})
}
s.session.newStream <- r // Notify AcceptStream
}

Expand Down
36 changes: 36 additions & 0 deletions session_srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,42 @@ func TestSessionSRTPReplayProtection(t *testing.T) {
}
}

// nolint: dupl
func TestSessionSRTPAcceptStreamTimeout(t *testing.T) {
lim := test.TimeOut(time.Second * 5)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

pipe, _ := net.Pipe()
config := &Config{
Profile: ProtectionProfileAes128CmHmacSha1_80,
Keys: SessionKeys{
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
},
AcceptStreamTimeout: time.Now().Add(3 * time.Second),
}

newSession, err := NewSessionSRTP(pipe, config)
if err != nil {
t.Fatal(err)
} else if newSession == nil {
t.Fatal("NewSessionSRTP did not error, but returned nil session")
}

if _, _, err = newSession.AcceptStream(); err == nil || !errors.Is(err, errStreamAlreadyClosed) {
t.Fatal(err)
}

if err = newSession.Close(); err != nil {
t.Fatal(err)
}
}

func assertPayloadSRTP(t *testing.T, stream *ReadStreamSRTP, headerSize int, expectedPayload []byte) (seq uint16, err error) {
readBuffer := make([]byte, headerSize+len(expectedPayload))
n, hdr, err := stream.ReadRTP(readBuffer)
Expand Down
18 changes: 9 additions & 9 deletions stream_srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ import (

type noopConn struct{ closed chan struct{} }

func newNoopConn() *noopConn { return &noopConn{closed: make(chan struct{})} }
func (c *noopConn) Read(b []byte) (n int, err error) { <-c.closed; return 0, io.EOF }
func (c *noopConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *noopConn) Close() error { close(c.closed); return nil }
func (c *noopConn) LocalAddr() net.Addr { return nil }
func (c *noopConn) RemoteAddr() net.Addr { return nil }
func (c *noopConn) SetDeadline(t time.Time) error { return nil }
func (c *noopConn) SetReadDeadline(t time.Time) error { return nil }
func (c *noopConn) SetWriteDeadline(t time.Time) error { return nil }
func newNoopConn() *noopConn { return &noopConn{closed: make(chan struct{})} }
func (c *noopConn) Read([]byte) (n int, err error) { <-c.closed; return 0, io.EOF }
func (c *noopConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (c *noopConn) Close() error { close(c.closed); return nil }
func (c *noopConn) LocalAddr() net.Addr { return nil }
func (c *noopConn) RemoteAddr() net.Addr { return nil }
func (c *noopConn) SetDeadline(time.Time) error { return nil }
func (c *noopConn) SetReadDeadline(time.Time) error { return nil }
func (c *noopConn) SetWriteDeadline(time.Time) error { return nil }

func TestBufferFactory(t *testing.T) {
wg := sync.WaitGroup{}
Expand Down

0 comments on commit 6ca78a7

Please sign in to comment.