Skip to content

Commit

Permalink
Implement DTLS restart
Browse files Browse the repository at this point in the history
Fixes #1636
  • Loading branch information
Antonito committed Jun 26, 2021
1 parent 7948437 commit b0d56c6
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 1 deletion.
27 changes: 26 additions & 1 deletion dtlstransport.go
Expand Up @@ -213,6 +213,31 @@ func (t *DTLSTransport) startSRTP() error {
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
}

isAlreadyRunning := func() bool {
select {
case <-t.srtpReady:
return true
default:
return false
}
}()

if isAlreadyRunning {
if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

return nil
}

srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
if err != nil {
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
Expand Down Expand Up @@ -283,7 +308,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
return DTLSRole(0), nil, err
}

if t.state != DTLSTransportStateNew {
if t.state != DTLSTransportStateNew && t.state != DTLSTransportStateClosed {
return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
}

Expand Down
34 changes: 34 additions & 0 deletions peerconnection.go
Expand Up @@ -1108,7 +1108,41 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
pc.ops.Enqueue(func() {
pc.startRTP(true, &desc, currentTransceivers)
})
} else if pc.dtlsTransport.State() != DTLSTransportStateNew {
fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed)
if fErr != nil {
return fErr
}

fingerPrintDidChange := true

for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints {
if fingerprint == fp.Value && fingerprintHash == fp.Algorithm {
fingerPrintDidChange = false
break
}
}

if fingerPrintDidChange {
pc.ops.Enqueue(func() {
if dErr := pc.dtlsTransport.Stop(); dErr != nil {
pc.log.Warnf("Failed to stop DTLS: %s", dErr)
}

// Restart the dtls transport with updated fingerprints
err = pc.dtlsTransport.Start(DTLSParameters{
Role: dtlsRoleFromRemoteSDP(desc.parsed),
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
})
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
if err != nil {
pc.log.Warnf("Failed to restart DTLS: %s", err)
return
}
})
}
}

return nil
}

Expand Down
185 changes: 185 additions & 0 deletions peerconnection_media_test.go
Expand Up @@ -14,10 +14,12 @@ import (
"testing"
"time"

"github.com/pion/logging"
"github.com/pion/randutil"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/transport/test"
"github.com/pion/transport/vnet"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1052,3 +1054,186 @@ func TestPeerConnection_RaceReplaceTrack(t *testing.T) {

assert.NoError(t, pc.Close())
}

// Issue #1636
func TestPeerConnection_DTLS_Restart(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

// First prepare network configuration

router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "0.0.0.0/0",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
assert.NoError(t, err)

networkA1 := vnet.NewNet(&vnet.NetConfig{
NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
})

networkA2 := vnet.NewNet(&vnet.NetConfig{
NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
})

networkB := vnet.NewNet(&vnet.NetConfig{
NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
})

assert.NoError(t, router.AddNet(networkA1))
assert.NoError(t, router.AddNet(networkA2))
assert.NoError(t, router.AddNet(networkB))

assert.NoError(t, router.Start())
defer func() { _ = router.Stop() }()

// ... then the clients

makeClient := func(network *vnet.Net) (*PeerConnection, *TrackLocalStaticSample) {
m := &MediaEngine{}
assert.NoError(t, m.RegisterDefaultCodecs())

s := SettingEngine{}
s.SetVNet(network)
s.SetICETimeouts(2*time.Second, 5*time.Second, 1*time.Second)

api := NewAPI(WithSettingEngine(s), WithMediaEngine(m))
pc, cliErr := api.NewPeerConnection(Configuration{})
assert.NoError(t, cliErr)

track, cliErr := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "test-client")
assert.NoError(t, cliErr)

_, cliErr = pc.AddTrack(track)
assert.NoError(t, cliErr)

return pc, track
}

clientA1, _ := makeClient(networkA1)
defer func() { _ = clientA1.Close() }()

clientB, localClientBTrack := makeClient(networkB)
defer func() { _ = clientB.Close() }()

// ... clientB starts publishing media
publishClientBCtx, publishCancel := context.WithCancel(context.Background())
go func() {
ticker := time.NewTicker(20 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-publishClientBCtx.Done():
return
case <-ticker.C:
_ = localClientBTrack.WriteSample(media.Sample{
Data: []byte{0xbb},
Timestamp: time.Now(),
Duration: 20 * time.Millisecond,
})
}
}
}()
defer publishCancel()

clientA1Tracks := make(chan *TrackRemote, 1)
clientA1.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) {
clientA1Tracks <- remote
})

// ClientA1 connects to ClientB

gatherCompletePromiseA1 := GatheringCompletePromise(clientA1)
offerA1, err := clientA1.CreateOffer(nil)
assert.NoError(t, err)
assert.NoError(t, clientA1.SetLocalDescription(offerA1))
<-gatherCompletePromiseA1

assert.NoError(t, clientB.SetRemoteDescription(*clientA1.LocalDescription()))

gatherCompletePromiseB := GatheringCompletePromise(clientB)
answerB, err := clientB.CreateAnswer(nil)
assert.NoError(t, err)
assert.NoError(t, clientB.SetLocalDescription(answerB))
<-gatherCompletePromiseB

clientA1Connected := make(chan struct{}, 1)
clientA1Disconnected := make(chan struct{}, 1)
clientA1.OnICEConnectionStateChange(func(s ICEConnectionState) {
if s == ICEConnectionStateConnected {
clientA1Connected <- struct{}{}
} else if s == ICEConnectionStateDisconnected {
clientA1Disconnected <- struct{}{}
}
})

assert.NoError(t, clientA1.SetRemoteDescription(answerB))

// Wait for connection
<-clientA1Connected

// At this point, clientA1 should have received a track, and some media
clientA1RemoteTrack := <-clientA1Tracks
pkt, _, err := clientA1RemoteTrack.ReadRTP()
assert.NotNil(t, pkt)
assert.NoError(t, err)

networkA1.SetNetworkConditioner(vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetFullLoss))

<-clientA1Disconnected

// ClientA1 has been disconnected – in a mobile app context, this could be a switch to the background
// or a killed app.
//
// In these scenarios, the client will reconnect with a different PeerConnection – here ClientA2.

clientA2, _ := makeClient(networkA2)
defer func() { _ = clientA2.Close() }()

clientA2Connected := make(chan struct{}, 1)
clientA2.OnICEConnectionStateChange(func(s ICEConnectionState) {
if s == ICEConnectionStateConnected {
clientA2Connected <- struct{}{}
} else if s == ICEConnectionStateFailed {
assert.FailNow(t, "should not fail")
}
})

clientA2Tracks := make(chan *TrackRemote, 1)
clientA2.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) {
clientA2Tracks <- remote
})

// ClientA2 connects to ClientB

gatherCompletePromiseA2 := GatheringCompletePromise(clientA2)
// We can't do an ICE Restart here, since it's a different PeerConnection
offerA2, err := clientA2.CreateOffer(nil)
assert.NoError(t, err)
assert.NoError(t, clientA2.SetLocalDescription(offerA2))
<-gatherCompletePromiseA2

assert.NoError(t, clientB.SetRemoteDescription(*clientA2.LocalDescription()))

gatherCompletePromiseB = GatheringCompletePromise(clientB)
answerB, err = clientB.CreateAnswer(nil)
assert.NoError(t, err)
assert.NoError(t, clientB.SetLocalDescription(answerB))
<-gatherCompletePromiseB

assert.NoError(t, clientA2.SetRemoteDescription(answerB))

// Wait for connection
<-clientA2Connected

// At this point, clientA2 should have received a track, and some media
clientA2RemoteTrack := <-clientA2Tracks

// Read a bunch of RTPs
for ndx := 0; ndx < 10; ndx++ {
pkt, _, err = clientA2RemoteTrack.ReadRTP()
assert.NotNil(t, pkt)
assert.NoError(t, err)
}
}

0 comments on commit b0d56c6

Please sign in to comment.