Skip to content

Commit 78da5a2

Browse files
committedJan 28, 2024
Add VLA extention header parser
Added VLA parser, builder and unit tests.
1 parent 314bd8e commit 78da5a2

File tree

4 files changed

+941
-0
lines changed

4 files changed

+941
-0
lines changed
 

‎codecs/av1/obu/leb128.go

+16
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,19 @@ func ReadLeb128(in []byte) (uint, uint, error) {
6767

6868
return 0, 0, ErrFailedToReadLEB128
6969
}
70+
71+
// WriteToLeb128 writes a uint to a LEB128 encoded byte slice.
72+
func WriteToLeb128(in uint) []byte {
73+
b := make([]byte, 10)
74+
75+
for i := 0; i < len(b); i++ {
76+
b[i] = byte(in & 0x7f)
77+
in >>= 7
78+
if in == 0 {
79+
return b[:i+1]
80+
}
81+
b[i] |= 0x80
82+
}
83+
84+
return b // unreachable
85+
}

‎codecs/av1/obu/leb128_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
package obu
55

66
import (
7+
"encoding/hex"
78
"errors"
9+
"fmt"
10+
"math"
811
"testing"
912
)
1013

@@ -40,3 +43,33 @@ func TestReadLeb128(t *testing.T) {
4043
t.Fatal("ReadLeb128 on a buffer with all MSB set should fail")
4144
}
4245
}
46+
47+
func TestWriteToLeb128(t *testing.T) {
48+
type testVector struct {
49+
value uint
50+
leb128 string
51+
}
52+
testVectors := []testVector{
53+
{150, "9601"},
54+
{240, "f001"},
55+
{400, "9003"},
56+
{720, "d005"},
57+
{1200, "b009"},
58+
{999999, "bf843d"},
59+
{0, "00"},
60+
{math.MaxUint32, "ffffffff0f"},
61+
}
62+
63+
runTest := func(t *testing.T, v testVector) {
64+
b := WriteToLeb128(v.value)
65+
if v.leb128 != hex.EncodeToString(b) {
66+
t.Errorf("Expected %s, got %s", v.leb128, hex.EncodeToString(b))
67+
}
68+
}
69+
70+
for _, v := range testVectors {
71+
t.Run(fmt.Sprintf("encode %d", v.value), func(t *testing.T) {
72+
runTest(t, v)
73+
})
74+
}
75+
}

‎vlaextension.go

+360
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
package rtp
5+
6+
import (
7+
"encoding/binary"
8+
"errors"
9+
"fmt"
10+
"strings"
11+
12+
"github.com/pion/rtp/codecs/av1/obu"
13+
)
14+
15+
var (
16+
ErrVLATooShort = errors.New("VLA payload too short") // ErrVLATooShort is returned when payload is too short
17+
ErrVLAInvalidStreamCount = errors.New("invalid RTP stream count in VLA") // ErrVLAInvalidStreamCount is returned when RTP stream count is invalid
18+
ErrVLAInvalidStreamID = errors.New("invalid RTP stream ID in VLA") // ErrVLAInvalidStreamID is returned when RTP stream ID is invalid
19+
ErrVLAInvalidSpatialID = errors.New("invalid spatial ID in VLA") // ErrVLAInvalidSpatialID is returned when spatial ID is invalid
20+
ErrVLADuplicateSpatialID = errors.New("duplicate spatial ID in VLA") // ErrVLADuplicateSpatialID is returned when spatial ID is invalid
21+
ErrVLAInvalidTemporalLayer = errors.New("invalid temporal layer in VLA") // ErrVLAInvalidTemporalLayer is returned when temporal layer is invalid
22+
)
23+
24+
// SpatialLayer is a spatial layer in VLA.
25+
type SpatialLayer struct {
26+
RTPStreamID int
27+
SpatialID int
28+
TargetBitrates []int // target bitrates per temporal layer
29+
30+
// Following members are valid only when HasResolutionAndFramerate is true
31+
Width int
32+
Height int
33+
Framerate int
34+
}
35+
36+
// VLA is a Video Layer Allocation (VLA) extension.
37+
// See https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00
38+
type VLA struct {
39+
RTPStreamID int // 0-origin RTP stream ID (RID) this allocation is sent on (0..3)
40+
RTPStreamCount int // Number of RTP streams (1..4)
41+
ActiveSpatialLayer []SpatialLayer
42+
HasResolutionAndFramerate bool
43+
}
44+
45+
type vlaMarshalingContext struct {
46+
slMBs [4]uint8
47+
sls [4][4]*SpatialLayer
48+
commonSLBM uint8
49+
encodedTargetBitrates [][]byte
50+
requiredLen int
51+
}
52+
53+
func (v VLA) preprocessForMashaling(ctx *vlaMarshalingContext) error {
54+
for i := 0; i < len(v.ActiveSpatialLayer); i++ {
55+
sl := v.ActiveSpatialLayer[i]
56+
if sl.RTPStreamID < 0 || sl.RTPStreamID >= v.RTPStreamCount {
57+
return fmt.Errorf("invalid RTP streamID %d:%w", sl.RTPStreamID, ErrVLAInvalidStreamID)
58+
}
59+
if sl.SpatialID < 0 || sl.SpatialID >= 4 {
60+
return fmt.Errorf("invalid spatial ID %d: %w", sl.SpatialID, ErrVLAInvalidSpatialID)
61+
}
62+
if len(sl.TargetBitrates) == 0 || len(sl.TargetBitrates) > 4 {
63+
return fmt.Errorf("invalid temporal layer count %d: %w", len(sl.TargetBitrates), ErrVLAInvalidTemporalLayer)
64+
}
65+
ctx.slMBs[sl.RTPStreamID] |= 1 << sl.SpatialID
66+
if ctx.sls[sl.RTPStreamID][sl.SpatialID] != nil {
67+
return fmt.Errorf("duplicate spatial layer: %w", ErrVLADuplicateSpatialID)
68+
}
69+
ctx.sls[sl.RTPStreamID][sl.SpatialID] = &sl
70+
}
71+
return nil
72+
}
73+
74+
func (v VLA) encodeTargetBitrates(ctx *vlaMarshalingContext) {
75+
for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ {
76+
for spatialID := 0; spatialID < 4; spatialID++ {
77+
if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil {
78+
for _, kbps := range sl.TargetBitrates {
79+
leb128 := obu.WriteToLeb128(uint(kbps))
80+
ctx.encodedTargetBitrates = append(ctx.encodedTargetBitrates, leb128)
81+
ctx.requiredLen += len(leb128)
82+
}
83+
}
84+
}
85+
}
86+
}
87+
88+
func (v VLA) analyzeVLAForMarshaling() (*vlaMarshalingContext, error) {
89+
// Validate RTPStreamCount
90+
if v.RTPStreamCount <= 0 || v.RTPStreamCount > 4 {
91+
return nil, ErrVLAInvalidStreamCount
92+
}
93+
// Validate RTPStreamID
94+
if v.RTPStreamID < 0 || v.RTPStreamID >= v.RTPStreamCount {
95+
return nil, ErrVLAInvalidStreamID
96+
}
97+
98+
ctx := &vlaMarshalingContext{}
99+
err := v.preprocessForMashaling(ctx)
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
ctx.commonSLBM = commonSLBMValues(ctx.slMBs[:])
105+
106+
// RID, NS, sl_bm fields
107+
if ctx.commonSLBM != 0 {
108+
ctx.requiredLen = 1
109+
} else {
110+
ctx.requiredLen = 3
111+
}
112+
113+
// #tl fields
114+
ctx.requiredLen += (len(v.ActiveSpatialLayer)-1)/4 + 1
115+
116+
v.encodeTargetBitrates(ctx)
117+
118+
if v.HasResolutionAndFramerate {
119+
ctx.requiredLen += len(v.ActiveSpatialLayer) * 5
120+
}
121+
122+
return ctx, nil
123+
}
124+
125+
// Marshal encodes VLA into a byte slice.
126+
func (v VLA) Marshal() ([]byte, error) {
127+
ctx, err := v.analyzeVLAForMarshaling()
128+
if err != nil {
129+
return nil, err
130+
}
131+
132+
payload := make([]byte, ctx.requiredLen)
133+
offset := 0
134+
135+
// RID, NS, sl_bm fields
136+
payload[offset] = byte(v.RTPStreamID<<6) | byte(v.RTPStreamCount-1)<<4 | ctx.commonSLBM
137+
138+
if ctx.commonSLBM == 0 {
139+
offset++
140+
for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
141+
if streamID%2 == 0 {
142+
payload[offset+streamID/2] |= ctx.slMBs[streamID] << 4
143+
} else {
144+
payload[offset+streamID/2] |= ctx.slMBs[streamID]
145+
}
146+
}
147+
offset += (v.RTPStreamCount - 1) / 2
148+
}
149+
150+
// #tl fields
151+
offset++
152+
var temporalLayerIndex int
153+
for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ {
154+
for spatialID := 0; spatialID < 4; spatialID++ {
155+
if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil {
156+
if temporalLayerIndex >= 4 {
157+
temporalLayerIndex = 0
158+
offset++
159+
}
160+
payload[offset] |= byte(len(sl.TargetBitrates)-1) << (2 * (3 - temporalLayerIndex))
161+
temporalLayerIndex++
162+
}
163+
}
164+
}
165+
166+
// Target bitrate fields
167+
offset++
168+
for _, encodedKbps := range ctx.encodedTargetBitrates {
169+
encodedSize := len(encodedKbps)
170+
copy(payload[offset:], encodedKbps)
171+
offset += encodedSize
172+
}
173+
174+
// Resolution & framerate fields
175+
if v.HasResolutionAndFramerate {
176+
for _, sl := range v.ActiveSpatialLayer {
177+
binary.BigEndian.PutUint16(payload[offset+0:], uint16(sl.Width-1))
178+
binary.BigEndian.PutUint16(payload[offset+2:], uint16(sl.Height-1))
179+
payload[offset+4] = byte(sl.Framerate)
180+
offset += 5
181+
}
182+
}
183+
184+
return payload, nil
185+
}
186+
187+
func commonSLBMValues(slMBs []uint8) uint8 {
188+
var common uint8
189+
for i := 0; i < len(slMBs); i++ {
190+
if slMBs[i] == 0 {
191+
continue
192+
}
193+
if common == 0 {
194+
common = slMBs[i]
195+
continue
196+
}
197+
if slMBs[i] != common {
198+
return 0
199+
}
200+
}
201+
return common
202+
}
203+
204+
type vlaUnmarshalingContext struct {
205+
payload []byte
206+
offset int
207+
slBMField uint8
208+
slBMs [4]uint8
209+
}
210+
211+
func (ctx *vlaUnmarshalingContext) checkRemainingLen(requiredLen int) bool {
212+
return len(ctx.payload)-ctx.offset >= requiredLen
213+
}
214+
215+
func (v *VLA) unmarshalSpatialLayers(ctx *vlaUnmarshalingContext) error {
216+
if !ctx.checkRemainingLen(1) {
217+
return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
218+
}
219+
v.RTPStreamID = int(ctx.payload[ctx.offset] >> 6 & 0b11)
220+
v.RTPStreamCount = int(ctx.payload[ctx.offset]>>4&0b11) + 1
221+
222+
// sl_bm fields
223+
ctx.slBMField = ctx.payload[ctx.offset] & 0b1111
224+
ctx.offset++
225+
226+
if ctx.slBMField != 0 {
227+
for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
228+
ctx.slBMs[streamID] = ctx.slBMField
229+
}
230+
} else {
231+
if !ctx.checkRemainingLen((v.RTPStreamCount-1)/2 + 1) {
232+
return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
233+
}
234+
// slX_bm fields
235+
for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
236+
var bm uint8
237+
if streamID%2 == 0 {
238+
bm = ctx.payload[ctx.offset+streamID/2] >> 4 & 0b1111
239+
} else {
240+
bm = ctx.payload[ctx.offset+streamID/2] & 0b1111
241+
}
242+
ctx.slBMs[streamID] = bm
243+
}
244+
ctx.offset += 1 + (v.RTPStreamCount-1)/2
245+
}
246+
247+
return nil
248+
}
249+
250+
func (v *VLA) unmarshalTemporalLayers(ctx *vlaUnmarshalingContext) error {
251+
if !ctx.checkRemainingLen(1) {
252+
return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
253+
}
254+
255+
var temporalLayerIndex int
256+
for streamID := 0; streamID < v.RTPStreamCount; streamID++ {
257+
for spatialID := 0; spatialID < 4; spatialID++ {
258+
if ctx.slBMs[streamID]&(1<<spatialID) == 0 {
259+
continue
260+
}
261+
if temporalLayerIndex >= 4 {
262+
temporalLayerIndex = 0
263+
ctx.offset++
264+
if !ctx.checkRemainingLen(1) {
265+
return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
266+
}
267+
}
268+
tlCount := int(ctx.payload[ctx.offset]>>(2*(3-temporalLayerIndex))&0b11) + 1
269+
temporalLayerIndex++
270+
sl := SpatialLayer{
271+
RTPStreamID: streamID,
272+
SpatialID: spatialID,
273+
TargetBitrates: make([]int, tlCount),
274+
}
275+
v.ActiveSpatialLayer = append(v.ActiveSpatialLayer, sl)
276+
}
277+
}
278+
ctx.offset++
279+
280+
// target bitrates
281+
for i, sl := range v.ActiveSpatialLayer {
282+
for j := range sl.TargetBitrates {
283+
kbps, n, err := obu.ReadLeb128(ctx.payload[ctx.offset:])
284+
if err != nil {
285+
return err
286+
}
287+
if !ctx.checkRemainingLen(int(n)) {
288+
return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
289+
}
290+
v.ActiveSpatialLayer[i].TargetBitrates[j] = int(kbps)
291+
ctx.offset += int(n)
292+
}
293+
}
294+
295+
return nil
296+
}
297+
298+
func (v *VLA) unmarshalResolutionAndFramerate(ctx *vlaUnmarshalingContext) error {
299+
if !ctx.checkRemainingLen(len(v.ActiveSpatialLayer) * 5) {
300+
return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort)
301+
}
302+
303+
v.HasResolutionAndFramerate = true
304+
305+
for i := range v.ActiveSpatialLayer {
306+
v.ActiveSpatialLayer[i].Width = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+0:])) + 1
307+
v.ActiveSpatialLayer[i].Height = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+2:])) + 1
308+
v.ActiveSpatialLayer[i].Framerate = int(ctx.payload[ctx.offset+4])
309+
ctx.offset += 5
310+
}
311+
312+
return nil
313+
}
314+
315+
// Unmarshal decodes VLA from a byte slice.
316+
func (v *VLA) Unmarshal(payload []byte) (int, error) {
317+
ctx := &vlaUnmarshalingContext{
318+
payload: payload,
319+
}
320+
321+
err := v.unmarshalSpatialLayers(ctx)
322+
if err != nil {
323+
return ctx.offset, err
324+
}
325+
326+
// #tl fields (build the list ActiveSpatialLayer at the same time)
327+
err = v.unmarshalTemporalLayers(ctx)
328+
if err != nil {
329+
return ctx.offset, err
330+
}
331+
332+
if len(ctx.payload) == ctx.offset {
333+
return ctx.offset, nil
334+
}
335+
336+
// resolution & framerate (optional)
337+
err = v.unmarshalResolutionAndFramerate(ctx)
338+
if err != nil {
339+
return ctx.offset, err
340+
}
341+
342+
return ctx.offset, nil
343+
}
344+
345+
// String makes VLA printable.
346+
func (v VLA) String() string {
347+
out := fmt.Sprintf("RID:%d,RTPStreamCount:%d", v.RTPStreamID, v.RTPStreamCount)
348+
var slOut []string
349+
for _, sl := range v.ActiveSpatialLayer {
350+
out2 := fmt.Sprintf("RTPStreamID:%d", sl.RTPStreamID)
351+
out2 += fmt.Sprintf(",TargetBitrates:%v", sl.TargetBitrates)
352+
if v.HasResolutionAndFramerate {
353+
out2 += fmt.Sprintf(",Resolution:(%d,%d)", sl.Width, sl.Height)
354+
out2 += fmt.Sprintf(",Framerate:%d", sl.Framerate)
355+
}
356+
slOut = append(slOut, out2)
357+
}
358+
out += fmt.Sprintf(",ActiveSpatialLayers:{%s}", strings.Join(slOut, ","))
359+
return out
360+
}

‎vlaextension_test.go

+532
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,532 @@
1+
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
package rtp
5+
6+
import (
7+
"bytes"
8+
"encoding/hex"
9+
"errors"
10+
"reflect"
11+
"testing"
12+
)
13+
14+
func TestVLAMarshal(t *testing.T) {
15+
requireNoError := func(t *testing.T, err error) {
16+
if err != nil {
17+
t.Fatal(err)
18+
}
19+
}
20+
21+
t.Run("3 streams no resolution and framerate", func(t *testing.T) {
22+
vla := &VLA{
23+
RTPStreamID: 0,
24+
RTPStreamCount: 3,
25+
ActiveSpatialLayer: []SpatialLayer{
26+
{
27+
RTPStreamID: 0,
28+
SpatialID: 0,
29+
TargetBitrates: []int{150},
30+
},
31+
{
32+
RTPStreamID: 1,
33+
SpatialID: 0,
34+
TargetBitrates: []int{240, 400},
35+
},
36+
{
37+
RTPStreamID: 2,
38+
SpatialID: 0,
39+
TargetBitrates: []int{720, 1200},
40+
},
41+
},
42+
}
43+
44+
bytesActual, err := vla.Marshal()
45+
requireNoError(t, err)
46+
bytesExpected, err := hex.DecodeString("21149601f0019003d005b009")
47+
requireNoError(t, err)
48+
if !bytes.Equal(bytesExpected, bytesActual) {
49+
t.Fatalf("expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual))
50+
}
51+
})
52+
53+
t.Run("3 streams with resolution and framerate", func(t *testing.T) {
54+
vla := &VLA{
55+
RTPStreamID: 2,
56+
RTPStreamCount: 3,
57+
ActiveSpatialLayer: []SpatialLayer{
58+
{
59+
RTPStreamID: 0,
60+
SpatialID: 0,
61+
TargetBitrates: []int{150},
62+
Width: 320,
63+
Height: 180,
64+
Framerate: 30,
65+
},
66+
{
67+
RTPStreamID: 1,
68+
SpatialID: 0,
69+
TargetBitrates: []int{240, 400},
70+
Width: 640,
71+
Height: 360,
72+
Framerate: 30,
73+
},
74+
{
75+
RTPStreamID: 2,
76+
SpatialID: 0,
77+
TargetBitrates: []int{720, 1200},
78+
Width: 1280,
79+
Height: 720,
80+
Framerate: 30,
81+
},
82+
},
83+
HasResolutionAndFramerate: true,
84+
}
85+
86+
bytesActual, err := vla.Marshal()
87+
requireNoError(t, err)
88+
bytesExpected, err := hex.DecodeString("a1149601f0019003d005b009013f00b31e027f01671e04ff02cf1e")
89+
requireNoError(t, err)
90+
if !bytes.Equal(bytesExpected, bytesActual) {
91+
t.Fatalf("expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual))
92+
}
93+
})
94+
95+
t.Run("Negative RTPStreamCount", func(t *testing.T) {
96+
vla := &VLA{
97+
RTPStreamID: 0,
98+
RTPStreamCount: -1,
99+
ActiveSpatialLayer: []SpatialLayer{},
100+
}
101+
_, err := vla.Marshal()
102+
if !errors.Is(err, ErrVLAInvalidStreamCount) {
103+
t.Fatal("expected ErrVLAInvalidRTPStreamCount")
104+
}
105+
})
106+
107+
t.Run("RTPStreamCount too large", func(t *testing.T) {
108+
vla := &VLA{
109+
RTPStreamID: 0,
110+
RTPStreamCount: 5,
111+
ActiveSpatialLayer: []SpatialLayer{{}, {}, {}, {}, {}},
112+
}
113+
_, err := vla.Marshal()
114+
if !errors.Is(err, ErrVLAInvalidStreamCount) {
115+
t.Fatal("expected ErrVLAInvalidRTPStreamCount")
116+
}
117+
})
118+
119+
t.Run("Negative RTPStreamID", func(t *testing.T) {
120+
vla := &VLA{
121+
RTPStreamID: -1,
122+
RTPStreamCount: 1,
123+
ActiveSpatialLayer: []SpatialLayer{{}},
124+
}
125+
_, err := vla.Marshal()
126+
if !errors.Is(err, ErrVLAInvalidStreamID) {
127+
t.Fatalf("expected ErrVLAInvalidRTPStreamID, actual %v", err)
128+
}
129+
})
130+
131+
t.Run("RTPStreamID to large", func(t *testing.T) {
132+
vla := &VLA{
133+
RTPStreamID: 1,
134+
RTPStreamCount: 1,
135+
ActiveSpatialLayer: []SpatialLayer{{}},
136+
}
137+
_, err := vla.Marshal()
138+
if !errors.Is(err, ErrVLAInvalidStreamID) {
139+
t.Fatalf("expected ErrVLAInvalidRTPStreamID: %v", err)
140+
}
141+
})
142+
143+
t.Run("Invalid stream ID in the spatial layer", func(t *testing.T) {
144+
vla := &VLA{
145+
RTPStreamID: 0,
146+
RTPStreamCount: 1,
147+
ActiveSpatialLayer: []SpatialLayer{{
148+
RTPStreamID: -1,
149+
}},
150+
}
151+
_, err := vla.Marshal()
152+
if !errors.Is(err, ErrVLAInvalidStreamID) {
153+
t.Fatalf("expected ErrVLAInvalidStreamID: %v", err)
154+
}
155+
vla = &VLA{
156+
RTPStreamID: 0,
157+
RTPStreamCount: 1,
158+
ActiveSpatialLayer: []SpatialLayer{{
159+
RTPStreamID: 1,
160+
}},
161+
}
162+
_, err = vla.Marshal()
163+
if !errors.Is(err, ErrVLAInvalidStreamID) {
164+
t.Fatalf("expected ErrVLAInvalidStreamID: %v", err)
165+
}
166+
})
167+
168+
t.Run("Invalid spatial ID in the spatial layer", func(t *testing.T) {
169+
vla := &VLA{
170+
RTPStreamID: 0,
171+
RTPStreamCount: 1,
172+
ActiveSpatialLayer: []SpatialLayer{{
173+
RTPStreamID: 0,
174+
SpatialID: -1,
175+
}},
176+
}
177+
_, err := vla.Marshal()
178+
if !errors.Is(err, ErrVLAInvalidSpatialID) {
179+
t.Fatalf("expected ErrVLAInvalidSpatialID: %v", err)
180+
}
181+
vla = &VLA{
182+
RTPStreamID: 0,
183+
RTPStreamCount: 1,
184+
ActiveSpatialLayer: []SpatialLayer{{
185+
RTPStreamID: 0,
186+
SpatialID: 5,
187+
}},
188+
}
189+
_, err = vla.Marshal()
190+
if !errors.Is(err, ErrVLAInvalidSpatialID) {
191+
t.Fatalf("expected ErrVLAInvalidSpatialID: %v", err)
192+
}
193+
})
194+
195+
t.Run("Invalid temporal layer in the spatial layer", func(t *testing.T) {
196+
vla := &VLA{
197+
RTPStreamID: 0,
198+
RTPStreamCount: 1,
199+
ActiveSpatialLayer: []SpatialLayer{{
200+
RTPStreamID: 0,
201+
SpatialID: 0,
202+
TargetBitrates: []int{},
203+
}},
204+
}
205+
_, err := vla.Marshal()
206+
if !errors.Is(err, ErrVLAInvalidTemporalLayer) {
207+
t.Fatalf("expected ErrVLAInvalidTemporalLayer: %v", err)
208+
}
209+
vla = &VLA{
210+
RTPStreamID: 0,
211+
RTPStreamCount: 1,
212+
ActiveSpatialLayer: []SpatialLayer{{
213+
RTPStreamID: 0,
214+
SpatialID: 0,
215+
TargetBitrates: []int{100, 200, 300, 400, 500},
216+
}},
217+
}
218+
_, err = vla.Marshal()
219+
if !errors.Is(err, ErrVLAInvalidTemporalLayer) {
220+
t.Fatalf("expected ErrVLAInvalidTemporalLayer: %v", err)
221+
}
222+
})
223+
224+
t.Run("Duplicate spatial ID in the spatial layer", func(t *testing.T) {
225+
vla := &VLA{
226+
RTPStreamID: 0,
227+
RTPStreamCount: 1,
228+
ActiveSpatialLayer: []SpatialLayer{{
229+
RTPStreamID: 0,
230+
SpatialID: 0,
231+
TargetBitrates: []int{100},
232+
}, {
233+
RTPStreamID: 0,
234+
SpatialID: 0,
235+
TargetBitrates: []int{200},
236+
}},
237+
}
238+
_, err := vla.Marshal()
239+
if !errors.Is(err, ErrVLADuplicateSpatialID) {
240+
t.Fatalf("expected ErrVLADuplicateSpatialID: %v", err)
241+
}
242+
})
243+
}
244+
245+
func TestVLAUnmarshal(t *testing.T) {
246+
requireEqualInt := func(t *testing.T, expected, actual int) {
247+
if expected != actual {
248+
t.Fatalf("expected %d, actual %d", expected, actual)
249+
}
250+
}
251+
requireNoError := func(t *testing.T, err error) {
252+
if err != nil {
253+
t.Fatal(err)
254+
}
255+
}
256+
requireTrue := func(t *testing.T, val bool) {
257+
if !val {
258+
t.Fatal("expected true")
259+
}
260+
}
261+
requireFalse := func(t *testing.T, val bool) {
262+
if val {
263+
t.Fatal("expected false")
264+
}
265+
}
266+
267+
t.Run("3 streams no resolution and framerate", func(t *testing.T) {
268+
// two layer ("low", "high")
269+
b, err := hex.DecodeString("21149601f0019003d005b009")
270+
requireNoError(t, err)
271+
if err != nil {
272+
t.Fatal("failed to decode input data")
273+
}
274+
275+
vla := &VLA{}
276+
n, err := vla.Unmarshal(b)
277+
requireNoError(t, err)
278+
requireEqualInt(t, len(b), n)
279+
280+
requireEqualInt(t, 0, vla.RTPStreamID)
281+
requireEqualInt(t, 3, vla.RTPStreamCount)
282+
requireEqualInt(t, 3, len(vla.ActiveSpatialLayer))
283+
284+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID)
285+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID)
286+
requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates))
287+
requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0])
288+
289+
requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID)
290+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID)
291+
requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates))
292+
requireEqualInt(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0])
293+
requireEqualInt(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1])
294+
295+
requireFalse(t, vla.HasResolutionAndFramerate)
296+
297+
requireEqualInt(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID)
298+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[2].SpatialID)
299+
requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates))
300+
requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0])
301+
requireEqualInt(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1])
302+
})
303+
304+
t.Run("3 streams with resolution and framerate", func(t *testing.T) {
305+
b, err := hex.DecodeString("a1149601f0019003d005b009013f00b31e027f01671e04ff02cf1e")
306+
requireNoError(t, err)
307+
308+
vla := &VLA{}
309+
n, err := vla.Unmarshal(b)
310+
requireNoError(t, err)
311+
requireEqualInt(t, len(b), n)
312+
313+
requireEqualInt(t, 2, vla.RTPStreamID)
314+
requireEqualInt(t, 3, vla.RTPStreamCount)
315+
316+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID)
317+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID)
318+
requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates))
319+
requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0])
320+
321+
requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID)
322+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID)
323+
requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates))
324+
requireEqualInt(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0])
325+
requireEqualInt(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1])
326+
327+
requireEqualInt(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID)
328+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[2].SpatialID)
329+
requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates))
330+
requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0])
331+
requireEqualInt(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1])
332+
333+
requireTrue(t, vla.HasResolutionAndFramerate)
334+
335+
requireEqualInt(t, 320, vla.ActiveSpatialLayer[0].Width)
336+
requireEqualInt(t, 180, vla.ActiveSpatialLayer[0].Height)
337+
requireEqualInt(t, 30, vla.ActiveSpatialLayer[0].Framerate)
338+
requireEqualInt(t, 640, vla.ActiveSpatialLayer[1].Width)
339+
requireEqualInt(t, 360, vla.ActiveSpatialLayer[1].Height)
340+
requireEqualInt(t, 30, vla.ActiveSpatialLayer[1].Framerate)
341+
requireEqualInt(t, 1280, vla.ActiveSpatialLayer[2].Width)
342+
requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].Height)
343+
requireEqualInt(t, 30, vla.ActiveSpatialLayer[2].Framerate)
344+
})
345+
346+
t.Run("2 streams", func(t *testing.T) {
347+
// two layer ("low", "high")
348+
b, err := hex.DecodeString("1110c801d005b009")
349+
requireNoError(t, err)
350+
351+
vla := &VLA{}
352+
n, err := vla.Unmarshal(b)
353+
requireNoError(t, err)
354+
requireEqualInt(t, len(b), n)
355+
356+
requireEqualInt(t, 0, vla.RTPStreamID)
357+
requireEqualInt(t, 2, vla.RTPStreamCount)
358+
requireEqualInt(t, 2, len(vla.ActiveSpatialLayer))
359+
360+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID)
361+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID)
362+
requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates))
363+
requireEqualInt(t, 200, vla.ActiveSpatialLayer[0].TargetBitrates[0])
364+
365+
requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID)
366+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID)
367+
requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates))
368+
requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0])
369+
requireEqualInt(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1])
370+
371+
requireFalse(t, vla.HasResolutionAndFramerate)
372+
})
373+
374+
t.Run("3 streams mid paused with resolution and framerate", func(t *testing.T) {
375+
b, err := hex.DecodeString("601010109601d005b009013f00b31e04ff02cf1e")
376+
requireNoError(t, err)
377+
378+
vla := &VLA{}
379+
n, err := vla.Unmarshal(b)
380+
requireNoError(t, err)
381+
requireEqualInt(t, len(b), n)
382+
383+
requireEqualInt(t, 1, vla.RTPStreamID)
384+
requireEqualInt(t, 3, vla.RTPStreamCount)
385+
386+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID)
387+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID)
388+
requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates))
389+
requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0])
390+
391+
requireEqualInt(t, 2, vla.ActiveSpatialLayer[1].RTPStreamID)
392+
requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID)
393+
requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates))
394+
requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0])
395+
requireEqualInt(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1])
396+
397+
requireTrue(t, vla.HasResolutionAndFramerate)
398+
399+
requireEqualInt(t, 320, vla.ActiveSpatialLayer[0].Width)
400+
requireEqualInt(t, 180, vla.ActiveSpatialLayer[0].Height)
401+
requireEqualInt(t, 30, vla.ActiveSpatialLayer[0].Framerate)
402+
requireEqualInt(t, 1280, vla.ActiveSpatialLayer[1].Width)
403+
requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].Height)
404+
requireEqualInt(t, 30, vla.ActiveSpatialLayer[1].Framerate)
405+
})
406+
407+
t.Run("extra 1", func(t *testing.T) {
408+
b, err := hex.DecodeString("a0001040ac02f403")
409+
requireNoError(t, err)
410+
411+
vla := &VLA{}
412+
n, err := vla.Unmarshal(b)
413+
requireNoError(t, err)
414+
requireEqualInt(t, len(b), n)
415+
})
416+
417+
t.Run("extra 2", func(t *testing.T) {
418+
b, err := hex.DecodeString("a00010409405cc08")
419+
requireNoError(t, err)
420+
421+
vla := &VLA{}
422+
n, err := vla.Unmarshal(b)
423+
requireNoError(t, err)
424+
requireEqualInt(t, len(b), n)
425+
})
426+
}
427+
428+
func TestVLAMarshalThenUnmarshal(t *testing.T) {
429+
requireEqualInt := func(t *testing.T, expected, actual int) {
430+
if expected != actual {
431+
t.Fatalf("expected %d, actual %d", expected, actual)
432+
}
433+
}
434+
requireNoError := func(t *testing.T, err error) {
435+
if err != nil {
436+
t.Fatal(err)
437+
}
438+
}
439+
440+
t.Run("multiple spatial layers", func(t *testing.T) {
441+
var spatialLayers []SpatialLayer
442+
for streamID := 0; streamID < 3; streamID++ {
443+
for spatialID := 0; spatialID < 4; spatialID++ {
444+
spatialLayers = append(spatialLayers, SpatialLayer{
445+
RTPStreamID: streamID,
446+
SpatialID: spatialID,
447+
TargetBitrates: []int{150, 200},
448+
Width: 320,
449+
Height: 180,
450+
Framerate: 30,
451+
})
452+
}
453+
}
454+
455+
vla0 := &VLA{
456+
RTPStreamID: 2,
457+
RTPStreamCount: 3,
458+
ActiveSpatialLayer: spatialLayers,
459+
HasResolutionAndFramerate: true,
460+
}
461+
462+
b, err := vla0.Marshal()
463+
requireNoError(t, err)
464+
465+
vla1 := &VLA{}
466+
n, err := vla1.Unmarshal(b)
467+
requireNoError(t, err)
468+
requireEqualInt(t, len(b), n)
469+
470+
if !reflect.DeepEqual(vla0, vla1) {
471+
t.Fatalf("expected %v, actual %v", vla0, vla1)
472+
}
473+
})
474+
475+
t.Run("different spatial layer bitmasks", func(t *testing.T) {
476+
var spatialLayers []SpatialLayer
477+
for streamID := 0; streamID < 4; streamID++ {
478+
for spatialID := 0; spatialID < streamID+1; spatialID++ {
479+
spatialLayers = append(spatialLayers, SpatialLayer{
480+
RTPStreamID: streamID,
481+
SpatialID: spatialID,
482+
TargetBitrates: []int{150, 200},
483+
Width: 320,
484+
Height: 180,
485+
Framerate: 30,
486+
})
487+
}
488+
}
489+
490+
vla0 := &VLA{
491+
RTPStreamID: 0,
492+
RTPStreamCount: 4,
493+
ActiveSpatialLayer: spatialLayers,
494+
HasResolutionAndFramerate: true,
495+
}
496+
497+
b, err := vla0.Marshal()
498+
requireNoError(t, err)
499+
if b[0]&0x0f != 0 {
500+
t.Error("expects sl_bm to be 0")
501+
}
502+
if b[1] != 0x13 {
503+
t.Error("expects sl0_bm,sl1_bm to be b0001,b0011")
504+
}
505+
if b[2] != 0x7f {
506+
t.Error("expects sl1_bm,sl2_bm to be b0111,b1111")
507+
}
508+
t.Logf("b: %s", hex.EncodeToString(b))
509+
510+
vla1 := &VLA{}
511+
n, err := vla1.Unmarshal(b)
512+
requireNoError(t, err)
513+
requireEqualInt(t, len(b), n)
514+
515+
if !reflect.DeepEqual(vla0, vla1) {
516+
t.Fatalf("expected %v, actual %v", vla0, vla1)
517+
}
518+
})
519+
}
520+
521+
func FuzzVLAUnmarshal(f *testing.F) {
522+
f.Add([]byte{0})
523+
f.Add([]byte("70"))
524+
525+
f.Fuzz(func(t *testing.T, data []byte) {
526+
vla := &VLA{}
527+
_, err := vla.Unmarshal(data)
528+
if err != nil {
529+
t.Skip() // If the function returns an error, we skip the test case
530+
}
531+
})
532+
}

0 commit comments

Comments
 (0)
Please sign in to comment.