Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update RtxSSRC for simulcast track remote #2759

Merged
merged 1 commit into from Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
218 changes: 218 additions & 0 deletions peerconnection_media_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"errors"
"fmt"
"io"
"regexp"
"strings"
"sync"
"sync/atomic"
Expand All @@ -26,6 +27,7 @@ import (
"github.com/pion/sdp/v3"
"github.com/pion/transport/v3/test"
"github.com/pion/transport/v3/vnet"
"github.com/pion/webrtc/v4/internal/util"
"github.com/pion/webrtc/v4/pkg/media"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1388,6 +1390,222 @@ func TestPeerConnection_Simulcast(t *testing.T) {
})
}

type simulcastTestTrackLocal struct {
*TrackLocalStaticRTP
}

// don't use ssrc&payload in bindings to let the test write different stream packets.
func (s *simulcastTestTrackLocal) WriteRTP(pkt *rtp.Packet) error {
packet := getPacketAllocationFromPool()

defer resetPacketPoolAllocation(packet)

*packet = *pkt

s.mu.RLock()
defer s.mu.RUnlock()

writeErrs := []error{}

for _, b := range s.bindings {
if _, err := b.writeStream.WriteRTP(&packet.Header, packet.Payload); err != nil {
writeErrs = append(writeErrs, err)
}
}

return util.FlattenErrs(writeErrs)
}

func TestPeerConnection_Simulcast_RTX(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

rids := []string{"a", "b"}
pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err)

vp8WriterAStatic, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[0]))
assert.NoError(t, err)

vp8WriterBStatic, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[1]))
assert.NoError(t, err)

vp8WriterA, vp8WriterB := &simulcastTestTrackLocal{vp8WriterAStatic}, &simulcastTestTrackLocal{vp8WriterBStatic}

sender, err := pcOffer.AddTrack(vp8WriterA)
assert.NoError(t, err)
assert.NotNil(t, sender)

assert.NoError(t, sender.AddEncoding(vp8WriterB))

var ridMapLock sync.RWMutex
ridMap := map[string]int{}

assertRidCorrect := func(t *testing.T) {
ridMapLock.Lock()
defer ridMapLock.Unlock()

for _, rid := range rids {
assert.Equal(t, ridMap[rid], 1)
}
assert.Equal(t, len(ridMap), 2)
}

ridsFullfilled := func() bool {
ridMapLock.Lock()
defer ridMapLock.Unlock()

ridCount := len(ridMap)
return ridCount == 2
}

var rtxPacketRead atomic.Int32
var wg sync.WaitGroup
wg.Add(2)

pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) {
ridMapLock.Lock()
ridMap[trackRemote.RID()] = ridMap[trackRemote.RID()] + 1
ridMapLock.Unlock()

defer wg.Done()

for {
_, attr, rerr := trackRemote.ReadRTP()
if rerr != nil {
break
}
if pt, ok := attr.Get(AttributeRtxPayloadType).(byte); ok {
if pt == 97 {
rtxPacketRead.Add(1)
}
}
}
})

parameters := sender.GetParameters()
assert.Equal(t, "a", parameters.Encodings[0].RID)
assert.Equal(t, "b", parameters.Encodings[1].RID)

var midID, ridID, rsid uint8
for _, extension := range parameters.HeaderExtensions {
switch extension.URI {
case sdp.SDESMidURI:
midID = uint8(extension.ID)
case sdp.SDESRTPStreamIDURI:
ridID = uint8(extension.ID)
case sdesRepairRTPStreamIDURI:
rsid = uint8(extension.ID)
}
}
assert.NotZero(t, midID)
assert.NotZero(t, ridID)
assert.NotZero(t, rsid)

err = signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
res := re.ReplaceAllString(sdp, "")
return res
})
assert.NoError(t, err)

// padding only packets should not affect simulcast probe
var sequenceNumber uint16
for sequenceNumber = 0; sequenceNumber < simulcastProbeCount+10; sequenceNumber++ {
time.Sleep(20 * time.Millisecond)

for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 96,
Padding: true,
SSRC: uint32(i),
},
Payload: []byte{0x00, 0x02},
}

assert.NoError(t, track.WriteRTP(pkt))
}
}
assert.False(t, ridsFullfilled(), "Simulcast probe should not be fulfilled by padding only packets")

for ; !ridsFullfilled(); sequenceNumber++ {
time.Sleep(20 * time.Millisecond)

for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 96,
SSRC: uint32(i),
},
Payload: []byte{0x00},
}
assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID())))

assert.NoError(t, track.WriteRTP(pkt))
}
}

assertRidCorrect(t)

for i := 0; i < simulcastProbeCount+10; i++ {
sequenceNumber++
time.Sleep(10 * time.Millisecond)

for j, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 97,
SSRC: uint32(100 + j),
},
Payload: []byte{0x00, 0x00, 0x00, 0x00, 0x00},
}
assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID())))
assert.NoError(t, pkt.Header.SetExtension(rsid, []byte(track.RID())))

assert.NoError(t, track.WriteRTP(pkt))
}
}

for ; rtxPacketRead.Load() == 0; sequenceNumber++ {
time.Sleep(20 * time.Millisecond)

for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 96,
SSRC: uint32(i),
},
Payload: []byte{0x00},
}
assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID())))

assert.NoError(t, track.WriteRTP(pkt))
}
}

closePairNow(t, pcOffer, pcAnswer)

wg.Wait()

assert.Greater(t, rtxPacketRead.Load(), int32(0), "no rtx packet read")
}

// Everytime we receieve a new SSRC we probe it and try to determine the proper way to handle it.
// In most cases a Track explicitly declares a SSRC and a OnTrack is fired. In two cases we don't
// know the SSRC ahead of time
Expand Down
4 changes: 4 additions & 0 deletions rtpreceiver.go
Expand Up @@ -418,6 +418,10 @@ func (r *RTPReceiver) receiveForRtx(ssrc SSRC, rsid string, streamInfo *intercep
for i := range r.tracks {
if r.tracks[i].track.RID() == rsid {
track = &r.tracks[i]
if track.track.RtxSSRC() == 0 {
track.track.setRtxSSRC(SSRC(streamInfo.SSRC))
}
break
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions track_remote.go
Expand Up @@ -224,3 +224,9 @@ func (t *TrackRemote) HasRTX() bool {
defer t.mu.RUnlock()
return t.rtxSsrc != 0
}

func (t *TrackRemote) setRtxSSRC(ssrc SSRC) {
t.mu.Lock()
defer t.mu.Unlock()
t.rtxSsrc = ssrc
}