diff --git a/README.md b/README.md index 007c340d..6385ffcb 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu * [Sean DuBois](https://github.com/sean-der) - *Original Author* * [Atsushi Watanabe](https://github.com/at-wat) * [Alessandro Ros](https://github.com/aler9) +* [Mathis Engelbart](https://github.com/mengelbart) ### License MIT License - see [LICENSE](LICENSE) for full text diff --git a/pkg/twcc/header_extension.go b/pkg/twcc/header_extension.go new file mode 100644 index 00000000..d1254b1c --- /dev/null +++ b/pkg/twcc/header_extension.go @@ -0,0 +1,49 @@ +package twcc + +import ( + "sync/atomic" + + "github.com/pion/interceptor" + "github.com/pion/rtp" +) + +// HeaderExtensionInterceptor adds transport wide sequence numbers as header extension to each RTP packet +type HeaderExtensionInterceptor struct { + interceptor.NoOp + nextSequenceNr uint32 +} + +// NewHeaderExtensionInterceptor returns a HeaderExtensionInterceptor +func NewHeaderExtensionInterceptor() (*HeaderExtensionInterceptor, error) { + return &HeaderExtensionInterceptor{}, nil +} + +const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + +// BindLocalStream returns a writer that adds a rtp.TransportCCExtension +// header with increasing sequence numbers to each outgoing packet. +func (h *HeaderExtensionInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + var hdrExtID uint8 + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + hdrExtID = uint8(e.ID) + break + } + } + if hdrExtID == 0 { // Don't add header extension if ID is 0, because 0 is an invalid extension ID + return writer + } + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + seqNr := atomic.AddUint32(&h.nextSequenceNr, 1) - 1 + + tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(seqNr)}).Marshal() + if err != nil { + return 0, err + } + err = header.SetExtension(hdrExtID, tcc) + if err != nil { + return 0, err + } + return writer.Write(header, payload, attributes) + }) +} diff --git a/pkg/twcc/header_extension_test.go b/pkg/twcc/header_extension_test.go new file mode 100644 index 00000000..aea0a531 --- /dev/null +++ b/pkg/twcc/header_extension_test.go @@ -0,0 +1,63 @@ +package twcc + +import ( + "sync" + "testing" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/test" + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +func TestHeaderExtensionInterceptor(t *testing.T) { + t.Run("add transport wide cc to each packet", func(t *testing.T) { + inter, err := NewHeaderExtensionInterceptor() + assert.NoError(t, err) + + pChan := make(chan *rtp.Packet, 10*5) + go func() { + // start some parallel streams using the same interceptor to test for race conditions + var wg sync.WaitGroup + num := 10 + wg.Add(num) + for i := 0; i < num; i++ { + go func(ch chan *rtp.Packet, id uint16) { + stream := test.NewMockStream(&interceptor.StreamInfo{RTPHeaderExtensions: []interceptor.RTPHeaderExtension{ + { + URI: transportCCURI, + ID: 1, + }, + }}, inter) + defer func() { + wg.Done() + assert.NoError(t, stream.Close()) + }() + + for _, seqNum := range []uint16{id * 1, id * 2, id * 3, id * 4, id * 5} { + assert.NoError(t, stream.WriteRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum}})) + select { + case p := <-stream.WrittenRTP(): + assert.Equal(t, seqNum, p.SequenceNumber) + ch <- p + case <-time.After(10 * time.Millisecond): + panic("written rtp packet not found") + } + } + }(pChan, uint16(i+1)) + } + wg.Wait() + close(pChan) + }() + + for p := range pChan { + // Can't check for increasing transport cc sequence number, since we can't ensure ordering between the streams + // on pChan is same as in the interceptor, but at least make sure each packet has a seq nr. + extensionHeader := p.GetExtension(1) + twcc := &rtp.TransportCCExtension{} + err = twcc.Unmarshal(extensionHeader) + assert.NoError(t, err) + } + }) +} diff --git a/pkg/twcc/sender_interceptor.go b/pkg/twcc/sender_interceptor.go new file mode 100644 index 00000000..cb2ab0b7 --- /dev/null +++ b/pkg/twcc/sender_interceptor.go @@ -0,0 +1,174 @@ +package twcc + +import ( + "math/rand" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtp" +) + +// SenderInterceptor sends transport wide congestion control reports as specified in: +// https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 +type SenderInterceptor struct { + interceptor.NoOp + + log logging.LeveledLogger + + m sync.Mutex + wg sync.WaitGroup + close chan struct{} + + interval time.Duration + + recorder *Recorder + packetChan chan packet +} + +// An Option is a function that can be used to configure a SenderInterceptor +type Option func(*SenderInterceptor) error + +// SendInterval sets the interval at which the interceptor +// will send new feedback reports. +func SendInterval(interval time.Duration) Option { + return func(s *SenderInterceptor) error { + s.interval = interval + return nil + } +} + +// NewSenderInterceptor returns a new SenderInterceptor configured with the given options. +func NewSenderInterceptor(opts ...Option) (*SenderInterceptor, error) { + i := &SenderInterceptor{ + log: logging.NewDefaultLoggerFactory().NewLogger("twcc_sender_interceptor"), + packetChan: make(chan packet), + close: make(chan struct{}), + interval: 100 * time.Millisecond, + } + + for _, opt := range opts { + err := opt(i) + if err != nil { + return nil, err + } + } + + return i, nil +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (s *SenderInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + s.m.Lock() + defer s.m.Unlock() + + s.recorder = NewRecorder(rand.Uint32()) // #nosec + + if s.isClosed() { + return writer + } + + s.wg.Add(1) + + go s.loop(writer) + + return writer +} + +type packet struct { + hdr *rtp.Header + seqNr uint16 + ts int64 + ssrc uint32 +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (s *SenderInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + var hdrExtID uint8 + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + hdrExtID = uint8(e.ID) + break + } + } + if hdrExtID == 0 { // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID + return reader + } + return interceptor.RTPReaderFunc(func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(buf, attributes) + if err != nil { + return 0, nil, err + } + p := rtp.Packet{} + err = p.Unmarshal(buf[:i]) + if err != nil { + return 0, nil, err + } + var tccExt rtp.TransportCCExtension + if ext := p.GetExtension(hdrExtID); ext != nil { + err = tccExt.Unmarshal(ext) + if err != nil { + return 0, nil, err + } + + s.packetChan <- packet{ + hdr: &p.Header, + seqNr: tccExt.TransportSequence, + ts: time.Now().UnixNano(), + ssrc: info.SSRC, + } + } + + return i, attr, nil + }) +} + +// Close closes the interceptor. +func (s *SenderInterceptor) Close() error { + defer s.wg.Wait() + s.m.Lock() + defer s.m.Unlock() + + if !s.isClosed() { + close(s.close) + } + + return nil +} + +func (s *SenderInterceptor) isClosed() bool { + select { + case <-s.close: + return true + default: + return false + } +} + +func (s *SenderInterceptor) loop(w interceptor.RTCPWriter) { + defer s.wg.Done() + + ticker := time.NewTicker(s.interval) + + for { + select { + case <-s.close: + return + case p := <-s.packetChan: + s.recorder.Record(p.ssrc, p.seqNr, p.ts/1e6) // ns -> ms: divide by 1e6 + + case <-ticker.C: + // build and send twcc + if s.recorder != nil { + pkts := s.recorder.BuildFeedbackPacket() + _, err := w.Write(pkts, nil) + if err != nil { + s.log.Error(err.Error()) + } + } + } + } +} diff --git a/pkg/twcc/sender_interceptor_test.go b/pkg/twcc/sender_interceptor_test.go new file mode 100644 index 00000000..61a024f4 --- /dev/null +++ b/pkg/twcc/sender_interceptor_test.go @@ -0,0 +1,263 @@ +package twcc + +import ( + "testing" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/test" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +func TestSenderInterceptor(t *testing.T) { + t.Run("before any packets", func(t *testing.T) { + i, err := NewSenderInterceptor() + assert.NoError(t, err) + + stream := test.NewMockStream(&interceptor.StreamInfo{SSRC: 1, RTPHeaderExtensions: []interceptor.RTPHeaderExtension{ + { + URI: transportCCURI, + ID: 1, + }, + }}, i) + defer func() { + assert.NoError(t, stream.Close()) + }() + + pkts := <-stream.WrittenRTCP() + assert.Equal(t, 1, len(pkts)) + tlcc, ok := pkts[0].(*rtcp.TransportLayerCC) + assert.True(t, ok) + assert.Equal(t, uint16(0), tlcc.PacketStatusCount) + assert.Equal(t, uint8(0), tlcc.FbPktCount) + assert.Equal(t, uint16(0), tlcc.BaseSequenceNumber) + assert.Equal(t, uint32(0), tlcc.MediaSSRC) + assert.Equal(t, uint32(0), tlcc.ReferenceTime) + assert.Equal(t, 0, len(tlcc.RecvDeltas)) + assert.Equal(t, 0, len(tlcc.PacketChunks)) + }) + + t.Run("after RTP packets", func(t *testing.T) { + i, err := NewSenderInterceptor() + assert.NoError(t, err) + + stream := test.NewMockStream(&interceptor.StreamInfo{SSRC: 1, RTPHeaderExtensions: []interceptor.RTPHeaderExtension{ + { + URI: transportCCURI, + ID: 1, + }, + }}, i) + defer func() { + assert.NoError(t, stream.Close()) + }() + + for i := 0; i < 10; i++ { + hdr := rtp.Header{} + tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(i)}).Marshal() + assert.NoError(t, err) + err = hdr.SetExtension(1, tcc) + assert.NoError(t, err) + stream.ReceiveRTP(&rtp.Packet{Header: hdr}) + } + + pkts := <-stream.WrittenRTCP() + assert.Equal(t, 1, len(pkts)) + cc, ok := pkts[0].(*rtcp.TransportLayerCC) + assert.True(t, ok) + assert.Equal(t, uint32(1), cc.MediaSSRC) + assert.Equal(t, uint16(0), cc.BaseSequenceNumber) + assert.Equal(t, []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 10, + }, + }, cc.PacketChunks) + }) + + t.Run("different delays between RTP packets", func(t *testing.T) { + i, err := NewSenderInterceptor( + SendInterval(500 * time.Millisecond), + ) + assert.NoError(t, err) + + stream := test.NewMockStream(&interceptor.StreamInfo{RTPHeaderExtensions: []interceptor.RTPHeaderExtension{ + { + URI: transportCCURI, + ID: 1, + }, + }}, i) + defer func() { + assert.NoError(t, stream.Close()) + }() + + delays := []int{0, 10, 100, 200} + for i, d := range delays { + time.Sleep(time.Duration(d) * time.Millisecond) + + hdr := rtp.Header{} + tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(i)}).Marshal() + assert.NoError(t, err) + err = hdr.SetExtension(1, tcc) + assert.NoError(t, err) + stream.ReceiveRTP(&rtp.Packet{Header: hdr}) + } + + pkts := <-stream.WrittenRTCP() + assert.Equal(t, 1, len(pkts)) + cc, ok := pkts[0].(*rtcp.TransportLayerCC) + assert.True(t, ok) + assert.Equal(t, uint16(0), cc.BaseSequenceNumber) + assert.Equal(t, []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + }, + }, + }, cc.PacketChunks) + }) + + t.Run("packet loss", func(t *testing.T) { + i, err := NewSenderInterceptor( + SendInterval(2 * time.Second), + ) + assert.NoError(t, err) + + stream := test.NewMockStream(&interceptor.StreamInfo{RTPHeaderExtensions: []interceptor.RTPHeaderExtension{ + { + URI: transportCCURI, + ID: 1, + }, + }}, i) + defer func() { + assert.NoError(t, stream.Close()) + }() + + seqNrToDelay := map[int]int{ + 0: 0, + 1: 10, + 4: 100, + 8: 200, + 9: 20, + 10: 20, + 30: 300, + } + for _, i := range []int{0, 1, 4, 8, 9, 10, 30} { + d := seqNrToDelay[i] + time.Sleep(time.Duration(d) * time.Millisecond) + + hdr := rtp.Header{} + tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(i)}).Marshal() + assert.NoError(t, err) + err = hdr.SetExtension(1, tcc) + assert.NoError(t, err) + stream.ReceiveRTP(&rtp.Packet{Header: hdr}) + } + + pkts := <-stream.WrittenRTCP() + assert.Equal(t, 1, len(pkts)) + cc, ok := pkts[0].(*rtcp.TransportLayerCC) + assert.True(t, ok) + assert.Equal(t, uint16(0), cc.BaseSequenceNumber) + assert.Equal(t, []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketNotReceived, + RunLength: 16, + }, + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedLargeDelta, + RunLength: 1, + }, + }, cc.PacketChunks) + }) + + t.Run("overflow", func(t *testing.T) { + i, err := NewSenderInterceptor( + SendInterval(2 * time.Second), + ) + assert.NoError(t, err) + + stream := test.NewMockStream(&interceptor.StreamInfo{RTPHeaderExtensions: []interceptor.RTPHeaderExtension{ + { + URI: transportCCURI, + ID: 1, + }, + }}, i) + defer func() { + assert.NoError(t, stream.Close()) + }() + + for _, i := range []int{65530, 65534, 65535, 1, 2, 10} { + hdr := rtp.Header{} + tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(i)}).Marshal() + assert.NoError(t, err) + err = hdr.SetExtension(1, tcc) + assert.NoError(t, err) + stream.ReceiveRTP(&rtp.Packet{Header: hdr}) + } + + pkts := <-stream.WrittenRTCP() + assert.Equal(t, 1, len(pkts)) + cc, ok := pkts[0].(*rtcp.TransportLayerCC) + assert.True(t, ok) + assert.Equal(t, uint16(65530), cc.BaseSequenceNumber) + assert.Equal(t, []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + }, + }, + }, cc.PacketChunks) + }) +} diff --git a/pkg/twcc/twcc.go b/pkg/twcc/twcc.go new file mode 100644 index 00000000..37b8e2ac --- /dev/null +++ b/pkg/twcc/twcc.go @@ -0,0 +1,258 @@ +// Package twcc provides interceptors to implement transport wide congestion control. +package twcc + +import ( + "sort" + + "github.com/pion/rtcp" +) + +type pktInfo struct { + seqNr uint32 + timestamp int64 +} + +// Recorder records incoming RTP packets and their delays and creates +// transport wide congestion control feedback reports as specified in +// https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 +type Recorder struct { + receivedPackets []pktInfo + + cycles uint32 + lastSeqNr uint16 + + senderSSRC uint32 + mediaSSRC uint32 + fbPktCnt uint8 +} + +// NewRecorder creates a new Recorder which uses the given senderSSRC in the created +// feedback packets. +func NewRecorder(senderSSRC uint32) *Recorder { + return &Recorder{ + receivedPackets: []pktInfo{}, + senderSSRC: senderSSRC, + } +} + +// Record marks a packet with mediaSSRC and a transport wide sequence number seqNr as received at arrivalTime. +func (r *Recorder) Record(mediaSSRC uint32, seqNr uint16, arrivalTimeMS int64) { + r.mediaSSRC = mediaSSRC + if seqNr < 0x0fff && (r.lastSeqNr&0xffff) > 0xf000 { + r.cycles += 1 << 16 + } + r.receivedPackets = append(r.receivedPackets, pktInfo{ + seqNr: r.cycles | uint32(seqNr), + timestamp: arrivalTimeMS, + }) + r.lastSeqNr = seqNr +} + +// BuildFeedbackPacket creates a new RTCP packet containing a TWCC feedback report. +func (r *Recorder) BuildFeedbackPacket() []rtcp.Packet { + tlcc := newFeedback(r.senderSSRC, r.mediaSSRC, r.fbPktCnt) + if len(r.receivedPackets) == 0 { + return []rtcp.Packet{tlcc.getRTCP()} + } + + sort.Slice(r.receivedPackets, func(i, j int) bool { + return r.receivedPackets[i].seqNr < r.receivedPackets[j].seqNr + }) + tlcc.setBase(uint16(r.receivedPackets[0].seqNr&0xffff), r.receivedPackets[0].timestamp*1000) + + var pkts []rtcp.Packet + for _, pkt := range r.receivedPackets { + built := tlcc.addReceived(uint16(pkt.seqNr&0xffff), pkt.timestamp*1000) + if !built { + pkts = append(pkts, tlcc.getRTCP()) + r.fbPktCnt++ + tlcc = newFeedback(r.senderSSRC, r.mediaSSRC, r.fbPktCnt) + tlcc.addReceived(uint16(pkt.seqNr&0xffff), pkt.timestamp*1000) + } + } + r.receivedPackets = []pktInfo{} + pkts = append(pkts, tlcc.getRTCP()) + + r.fbPktCnt++ + + return pkts +} + +type feedback struct { + rtcp *rtcp.TransportLayerCC + baseSeqNr uint16 + refTimestamp64MS int64 + lastTimestampUS int64 + nextSeqNr uint16 + seqNrCount uint16 + len int + lastChunk chunk + chunks []rtcp.PacketStatusChunk + deltas []*rtcp.RecvDelta +} + +func newFeedback(senderSSRC, mediaSSRC uint32, count uint8) *feedback { + return &feedback{ + rtcp: &rtcp.TransportLayerCC{ + SenderSSRC: senderSSRC, + MediaSSRC: mediaSSRC, + FbPktCount: count, + }, + } +} + +func (f *feedback) setBase(seqNr uint16, timeUS int64) { + f.baseSeqNr = seqNr + f.nextSeqNr = f.baseSeqNr + f.refTimestamp64MS = timeUS / 64e3 + f.lastTimestampUS = f.refTimestamp64MS * 64e3 +} + +func (f *feedback) getRTCP() *rtcp.TransportLayerCC { + f.rtcp.PacketStatusCount = f.seqNrCount + f.rtcp.ReferenceTime = uint32(f.refTimestamp64MS) + f.rtcp.BaseSequenceNumber = f.baseSeqNr + if len(f.lastChunk.deltas) > 0 { + f.chunks = append(f.chunks, f.lastChunk.encode()) + f.rtcp.PacketChunks = append(f.rtcp.PacketChunks, f.chunks...) + } + f.rtcp.RecvDeltas = f.deltas + + padLen := 20 + len(f.rtcp.PacketChunks)*2 + f.len // 4 bytes header + 16 bytes twcc header + 2 bytes for each chunk + length of deltas + padding := padLen%4 != 0 + for padLen%4 != 0 { + padLen++ + } + f.rtcp.Header = rtcp.Header{ + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Padding: padding, + Length: uint16((padLen / 4) - 1), + } + + return f.rtcp +} + +func (f *feedback) addReceived(seqNr uint16, timestampUS int64) bool { + deltaUS := timestampUS - f.lastTimestampUS + delta250US := deltaUS / 250 + delta16 := uint16(delta250US) + if int64(delta16) != delta250US { // delta doesn't fit into 16 bit, need to create new packet + return false + } + + for ; f.nextSeqNr != seqNr; f.nextSeqNr++ { + if !f.lastChunk.canAdd(rtcp.TypeTCCPacketNotReceived) { + f.chunks = append(f.chunks, f.lastChunk.encode()) + } + f.lastChunk.add(rtcp.TypeTCCPacketNotReceived) + f.seqNrCount++ + } + + var recvDelta uint16 + switch { + case delta250US >= 0 && delta250US <= 0xff: + f.len++ + recvDelta = rtcp.TypeTCCPacketReceivedSmallDelta + default: + f.len += 2 + recvDelta = rtcp.TypeTCCPacketReceivedLargeDelta + } + + if !f.lastChunk.canAdd(recvDelta) { + f.chunks = append(f.chunks, f.lastChunk.encode()) + } + f.lastChunk.add(recvDelta) + f.deltas = append(f.deltas, &rtcp.RecvDelta{ + Type: recvDelta, + Delta: delta250US, + }) + f.lastTimestampUS = timestampUS + f.seqNrCount++ + f.nextSeqNr++ + return true +} + +const ( + maxRunLengthCap = 0x1fff // 13 bits + maxOneBitCap = 14 // bits + maxTwoBitCap = 7 // bits +) + +type chunk struct { + hasLargeDelta bool + hasDifferentTypes bool + deltas []uint16 +} + +func (c *chunk) canAdd(delta uint16) bool { + if len(c.deltas) < maxTwoBitCap { + return true + } + if len(c.deltas) < maxOneBitCap && !c.hasLargeDelta && delta != rtcp.TypeTCCPacketReceivedLargeDelta { + return true + } + if len(c.deltas) < maxRunLengthCap && !c.hasDifferentTypes && delta == c.deltas[0] { + return true + } + return false +} + +func (c *chunk) add(delta uint16) { + c.deltas = append(c.deltas, delta) + c.hasLargeDelta = c.hasLargeDelta || delta == rtcp.TypeTCCPacketReceivedLargeDelta + c.hasDifferentTypes = c.hasDifferentTypes || delta != c.deltas[0] +} + +func (c *chunk) encode() rtcp.PacketStatusChunk { + if !c.hasDifferentTypes { + defer c.reset() + return &rtcp.RunLengthChunk{ + PacketStatusSymbol: c.deltas[0], + RunLength: uint16(len(c.deltas)), + } + } + if len(c.deltas) == maxOneBitCap { + defer c.reset() + return &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: c.deltas, + } + } + + minCap := min(maxTwoBitCap, len(c.deltas)) + svc := &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: c.deltas[:minCap], + } + c.deltas = c.deltas[minCap:] + c.hasDifferentTypes = false + c.hasLargeDelta = false + + if len(c.deltas) > 0 { + tmp := c.deltas[0] + for _, d := range c.deltas { + if tmp != d { + c.hasDifferentTypes = true + } + if d == rtcp.TypeTCCPacketReceivedLargeDelta { + c.hasLargeDelta = true + } + } + } + + return svc +} + +func (c *chunk) reset() { + c.deltas = []uint16{} + c.hasLargeDelta = false + c.hasDifferentTypes = false +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/twcc/twcc_test.go b/pkg/twcc/twcc_test.go new file mode 100644 index 00000000..91961be4 --- /dev/null +++ b/pkg/twcc/twcc_test.go @@ -0,0 +1,339 @@ +package twcc + +import ( + "testing" + + "github.com/pion/rtcp" + "github.com/stretchr/testify/assert" +) + +func Test_chunk_add(t *testing.T) { + t.Run("fill with not received", func(t *testing.T) { + c := &chunk{} + + for i := 0; i < maxRunLengthCap; i++ { + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived), i) + c.add(rtcp.TypeTCCPacketNotReceived) + } + assert.Equal(t, make([]uint16, maxRunLengthCap), c.deltas) + assert.False(t, c.hasDifferentTypes) + + assert.False(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + assert.False(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta)) + assert.False(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + + statusChunk := c.encode() + assert.IsType(t, &rtcp.RunLengthChunk{}, statusChunk) + + buf, err := statusChunk.Marshal() + assert.NoError(t, err) + assert.Equal(t, []byte{0x1f, 0xff}, buf) + }) + + t.Run("fill with small delta", func(t *testing.T) { + c := &chunk{} + + for i := 0; i < maxOneBitCap; i++ { + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta), i) + c.add(rtcp.TypeTCCPacketReceivedSmallDelta) + } + + assert.Equal(t, c.deltas, []uint16{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) + assert.False(t, c.hasDifferentTypes) + + assert.False(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + assert.False(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + + statusChunk := c.encode() + assert.IsType(t, &rtcp.RunLengthChunk{}, statusChunk) + + buf, err := statusChunk.Marshal() + assert.NoError(t, err) + assert.Equal(t, []byte{0x20, 0xe}, buf) + }) + + t.Run("fill with large delta", func(t *testing.T) { + c := &chunk{} + + for i := 0; i < maxTwoBitCap; i++ { + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta), i) + c.add(rtcp.TypeTCCPacketReceivedLargeDelta) + } + + assert.Equal(t, c.deltas, []uint16{2, 2, 2, 2, 2, 2, 2}) + assert.True(t, c.hasLargeDelta) + assert.False(t, c.hasDifferentTypes) + + assert.False(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta)) + assert.False(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + + statusChunk := c.encode() + assert.IsType(t, &rtcp.RunLengthChunk{}, statusChunk) + + buf, err := statusChunk.Marshal() + assert.NoError(t, err) + assert.Equal(t, []byte{0x40, 0x7}, buf) + }) + + t.Run("fill with different types", func(t *testing.T) { + c := &chunk{} + + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta)) + c.add(rtcp.TypeTCCPacketReceivedSmallDelta) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta)) + c.add(rtcp.TypeTCCPacketReceivedSmallDelta) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta)) + c.add(rtcp.TypeTCCPacketReceivedSmallDelta) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta)) + c.add(rtcp.TypeTCCPacketReceivedSmallDelta) + + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + c.add(rtcp.TypeTCCPacketReceivedLargeDelta) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + c.add(rtcp.TypeTCCPacketReceivedLargeDelta) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + c.add(rtcp.TypeTCCPacketReceivedLargeDelta) + + assert.False(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + + statusChunk := c.encode() + assert.IsType(t, &rtcp.StatusVectorChunk{}, statusChunk) + + buf, err := statusChunk.Marshal() + assert.NoError(t, err) + assert.Equal(t, []byte{0xd5, 0x6a}, buf) + }) + + t.Run("overfill and encode", func(t *testing.T) { + c := chunk{} + + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedSmallDelta)) + c.add(rtcp.TypeTCCPacketReceivedSmallDelta) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + c.add(rtcp.TypeTCCPacketNotReceived) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + c.add(rtcp.TypeTCCPacketNotReceived) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + c.add(rtcp.TypeTCCPacketNotReceived) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + c.add(rtcp.TypeTCCPacketNotReceived) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + c.add(rtcp.TypeTCCPacketNotReceived) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + c.add(rtcp.TypeTCCPacketNotReceived) + assert.True(t, c.canAdd(rtcp.TypeTCCPacketNotReceived)) + c.add(rtcp.TypeTCCPacketNotReceived) + + assert.False(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + + statusChunk1 := c.encode() + assert.IsType(t, &rtcp.StatusVectorChunk{}, statusChunk1) + assert.Equal(t, 1, len(c.deltas)) + + assert.True(t, c.canAdd(rtcp.TypeTCCPacketReceivedLargeDelta)) + c.add(rtcp.TypeTCCPacketReceivedLargeDelta) + + statusChunk2 := c.encode() + assert.IsType(t, &rtcp.StatusVectorChunk{}, statusChunk2) + + assert.Equal(t, 0, len(c.deltas)) + + assert.Equal(t, &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{rtcp.TypeTCCPacketNotReceived, rtcp.TypeTCCPacketReceivedLargeDelta}, + }, statusChunk2) + }) +} + +func Test_feedback(t *testing.T) { + t.Run("add simple", func(t *testing.T) { + f := feedback{} + + got := f.addReceived(0, 10) + + assert.True(t, got) + }) + + t.Run("add too large", func(t *testing.T) { + f := feedback{} + + assert.False(t, f.addReceived(12, 8200*1000*250)) + }) + + t.Run("add received 1", func(t *testing.T) { + f := &feedback{} + f.setBase(1, 1000*1000) + + got := f.addReceived(1, 1023*1000) + + assert.True(t, got) + assert.Equal(t, uint16(2), f.nextSeqNr) + assert.Equal(t, int64(15), f.refTimestamp64MS) + + got = f.addReceived(4, 1086*1000) + assert.True(t, got) + assert.Equal(t, uint16(5), f.nextSeqNr) + assert.Equal(t, int64(15), f.refTimestamp64MS) + + assert.True(t, f.lastChunk.hasDifferentTypes) + assert.Equal(t, 4, len(f.lastChunk.deltas)) + assert.NotContains(t, f.lastChunk.deltas, rtcp.TypeTCCPacketReceivedLargeDelta) + }) + + t.Run("add received 2", func(t *testing.T) { + f := newFeedback(0, 0, 0) + f.setBase(5, 320*1000) + + got := f.addReceived(5, 320*1000) + assert.True(t, got) + got = f.addReceived(7, 448*1000) + assert.True(t, got) + got = f.addReceived(8, 512*1000) + assert.True(t, got) + got = f.addReceived(11, 768*1000) + assert.True(t, got) + + pkt := f.getRTCP() + + assert.True(t, pkt.Header.Padding) + assert.Equal(t, uint16(7), pkt.Header.Length) + assert.Equal(t, uint16(5), pkt.BaseSequenceNumber) + assert.Equal(t, uint16(7), pkt.PacketStatusCount) + assert.Equal(t, uint32(5), pkt.ReferenceTime) + assert.Equal(t, uint8(0), pkt.FbPktCount) + assert.Equal(t, 1, len(pkt.PacketChunks)) + + assert.Equal(t, []rtcp.PacketStatusChunk{&rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedLargeDelta, + }, + }}, pkt.PacketChunks) + + expectedDeltas := []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 0, + }, + { + Type: rtcp.TypeTCCPacketReceivedLargeDelta, + Delta: 0x0200, + }, + { + Type: rtcp.TypeTCCPacketReceivedLargeDelta, + Delta: 0x0100, + }, + { + Type: rtcp.TypeTCCPacketReceivedLargeDelta, + Delta: 0x0400, + }, + } + assert.Equal(t, len(expectedDeltas), len(pkt.RecvDeltas)) + for i, d := range expectedDeltas { + assert.Equal(t, d, pkt.RecvDeltas[i]) + } + }) + + t.Run("add received wrapped sequence number", func(t *testing.T) { + f := newFeedback(0, 0, 0) + f.setBase(65535, 320*1000) + + got := f.addReceived(65535, 320*1000) + assert.True(t, got) + got = f.addReceived(7, 448*1000) + assert.True(t, got) + got = f.addReceived(8, 512*1000) + assert.True(t, got) + got = f.addReceived(11, 768*1000) + assert.True(t, got) + + pkt := f.getRTCP() + + assert.True(t, pkt.Header.Padding) + assert.Equal(t, uint16(7), pkt.Header.Length) + assert.Equal(t, uint16(65535), pkt.BaseSequenceNumber) + assert.Equal(t, uint16(13), pkt.PacketStatusCount) + assert.Equal(t, uint32(5), pkt.ReferenceTime) + assert.Equal(t, uint8(0), pkt.FbPktCount) + assert.Equal(t, 2, len(pkt.PacketChunks)) + + assert.Equal(t, []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedLargeDelta, + }, + }, + }, pkt.PacketChunks) + + expectedDeltas := []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 0, + }, + { + Type: rtcp.TypeTCCPacketReceivedLargeDelta, + Delta: 0x0200, + }, + { + Type: rtcp.TypeTCCPacketReceivedLargeDelta, + Delta: 0x0100, + }, + { + Type: rtcp.TypeTCCPacketReceivedLargeDelta, + Delta: 0x0400, + }, + } + assert.Equal(t, len(expectedDeltas), len(pkt.RecvDeltas)) + for i, d := range expectedDeltas { + assert.Equal(t, d, pkt.RecvDeltas[i]) + } + }) + + t.Run("get RTCP", func(t *testing.T) { + testcases := []struct { + arrivalTS int64 + seqNr uint16 + wantRefTime uint32 + wantBaseSeqNr uint16 + }{ + {320, 1, 5, 1}, + {1000, 2, 15, 2}, + } + for _, tt := range testcases { + tt := tt + + t.Run("set correct base seq and time", func(t *testing.T) { + f := newFeedback(0, 0, 0) + f.setBase(tt.seqNr, tt.arrivalTS*1000) + + got := f.getRTCP() + assert.Equal(t, tt.wantRefTime, got.ReferenceTime) + assert.Equal(t, tt.wantBaseSeqNr, got.BaseSequenceNumber) + }) + } + }) +}