Skip to content

Commit

Permalink
Implement DTLS/SRTP/SCTP restart
Browse files Browse the repository at this point in the history
Fixes #1636
  • Loading branch information
Antonito committed Jun 29, 2021
1 parent 7948437 commit 0d1e50c
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 7 deletions.
6 changes: 3 additions & 3 deletions datachannel.go
Expand Up @@ -69,7 +69,7 @@ func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelPara
return nil, err
}

err = d.open(transport)
err = d.open(transport, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -103,14 +103,14 @@ func (api *API) newDataChannel(params *DataChannelParameters, log logging.Levele
}

// open opens the datachannel over the sctp transport
func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
func (d *DataChannel) open(sctpTransport *SCTPTransport, restart bool) error {
association := sctpTransport.association()
if association == nil {
return errSCTPNotEstablished
}

d.mu.Lock()
if d.sctpTransport != nil { // already open
if d.sctpTransport != nil && !restart { // already open & not restarting
d.mu.Unlock()
return nil
}
Expand Down
25 changes: 25 additions & 0 deletions dtlstransport.go
Expand Up @@ -213,6 +213,31 @@ func (t *DTLSTransport) startSRTP() error {
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
}

isAlreadyRunning := func() bool {
select {
case <-t.srtpReady:
return true
default:
return false
}
}()

if isAlreadyRunning {
if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

return nil
}

srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
if err != nil {
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
Expand Down
56 changes: 54 additions & 2 deletions peerconnection.go
Expand Up @@ -1108,7 +1108,59 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
pc.ops.Enqueue(func() {
pc.startRTP(true, &desc, currentTransceivers)
})
} else if pc.dtlsTransport.State() != DTLSTransportStateNew {
fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed)
if fErr != nil {
return fErr
}

fingerPrintDidChange := true

for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints {
if fingerprint == fp.Value && fingerprintHash == fp.Algorithm {
fingerPrintDidChange = false
break
}
}

if fingerPrintDidChange {
pc.ops.Enqueue(func() {
// SCTP uses DTLS, so prevent any use, by locking, while
// DTLS is restarting.
pc.sctpTransport.lock.Lock()
defer pc.sctpTransport.lock.Unlock()

if dErr := pc.dtlsTransport.Stop(); dErr != nil {
pc.log.Warnf("Failed to stop DTLS: %s", dErr)
}

// libwebrtc switches the connection back to `new`.
pc.dtlsTransport.lock.Lock()
pc.dtlsTransport.onStateChange(DTLSTransportStateNew)
pc.dtlsTransport.lock.Unlock()

// Restart the dtls transport with updated fingerprints
err = pc.dtlsTransport.Start(DTLSParameters{
Role: dtlsRoleFromRemoteSDP(desc.parsed),
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
})
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
if err != nil {
pc.log.Warnf("Failed to restart DTLS: %s", err)
return
}

// If SCTP was enabled, restart it with the new DTLS transport.
if pc.sctpTransport.isStarted {
if dErr := pc.sctpTransport.restart(pc.dtlsTransport.conn); dErr != nil {
pc.log.Warnf("Failed to restart SCTP: %s", dErr)
return
}
}
})
}
}

return nil
}

Expand Down Expand Up @@ -1317,7 +1369,7 @@ func (pc *PeerConnection) startSCTP() {
var openedDCCount uint32
for _, d := range dataChannels {
if d.ReadyState() == DataChannelStateConnecting {
err := d.open(pc.sctpTransport)
err := d.open(pc.sctpTransport, false)
if err != nil {
pc.log.Warnf("failed to open data channel: %s", err)
continue
Expand Down Expand Up @@ -1775,7 +1827,7 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn

// If SCTP already connected open all the channels
if pc.sctpTransport.State() == SCTPTransportStateConnected {
if err = d.open(pc.sctpTransport); err != nil {
if err = d.open(pc.sctpTransport, false); err != nil {
return nil, err
}
}
Expand Down

0 comments on commit 0d1e50c

Please sign in to comment.