From 5ce39f696e32617a4da6f3a3085f8dabb4302f97 Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Wed, 21 Dec 2022 17:58:17 +0900 Subject: [PATCH] Fix maximum packet count handling Accoring to RFC3711 section 9.2, SRTP/SRTCP session must not wrap SRTP ROC and SRTCP index without changing the master key. Also fix Context.SetROC() with ROC>0xffff. --- context.go | 10 +-- errors.go | 1 + option.go | 4 +- srtcp.go | 11 ++- srtcp_test.go | 228 ++++++++++++++++++++++++++++++++++---------------- srtp.go | 15 +++- srtp_test.go | 166 ++++++++++++++++++++++++++++++------ 7 files changed, 326 insertions(+), 109 deletions(-) diff --git a/context.go b/context.go index bf871b2..78deaee 100644 --- a/context.go +++ b/context.go @@ -16,6 +16,7 @@ const ( labelSRTCPSalt = 0x05 maxSequenceNumber = 65535 + maxROC = (1 << 32) - 1 seqNumMedian = 1 << 15 seqNumMax = 1 << 16 @@ -60,8 +61,7 @@ type Context struct { // Passing multiple options which set the same parameter let the last one valid. // Following example create SRTP Context with replay protection with window size of 256. // -// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) -// +// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { keyLen, err := profile.keyLen() if err != nil { @@ -112,7 +112,7 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts } // https://tools.ietf.org/html/rfc3550#appendix-A.1 -func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, int32) { +func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) { seq := int32(sequenceNumber) localRoc := uint32(s.index >> 16) localSeq := int32(s.index & (seqNumMax - 1)) @@ -147,7 +147,7 @@ func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, int32) } } - return guessRoc, difference + return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC) } func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) { @@ -201,7 +201,7 @@ func (c *Context) ROC(ssrc uint32) (uint32, bool) { // SetROC sets SRTP rollover counter value of specified SSRC. func (c *Context) SetROC(ssrc uint32, roc uint32) { s := c.getSRTPSSRCState(ssrc) - s.index = uint64(roc<<16) | (s.index & (seqNumMax - 1)) + s.index = uint64(roc)<<16 | (s.index & (seqNumMax - 1)) } // Index returns SRTCP index value of specified SSRC. diff --git a/errors.go b/errors.go index 55a67bc..db5b7db 100644 --- a/errors.go +++ b/errors.go @@ -19,6 +19,7 @@ var ( errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") errBadIVLength = errors.New("bad iv length in xorBytesCTR") + errExceededMaxPackets = errors.New("exceeded the maximum number of packets") errStreamNotInited = errors.New("stream has not been inited, unable to close") errStreamAlreadyClosed = errors.New("stream is already closed") diff --git a/option.go b/option.go index 86ecd8e..25636b6 100644 --- a/option.go +++ b/option.go @@ -11,7 +11,7 @@ type ContextOption func(*Context) error func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive return func(c *Context) error { c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { - return replaydetector.WithWrap(windowSize, maxSequenceNumber) + return replaydetector.New(windowSize, maxROC<<16|maxSequenceNumber) } return nil } @@ -21,7 +21,7 @@ func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive func SRTCPReplayProtection(windowSize uint) ContextOption { return func(c *Context) error { c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector { - return replaydetector.WithWrap(windowSize, maxSRTCPIndex) + return replaydetector.New(windowSize, maxSRTCPIndex) } return nil } diff --git a/srtcp.go b/srtcp.go index d3e387b..2812851 100644 --- a/srtcp.go +++ b/srtcp.go @@ -63,11 +63,16 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { ssrc := binary.BigEndian.Uint32(decrypted[4:]) s := c.getSRTCPSSRCState(ssrc) + if s.srtcpIndex >= maxSRTCPIndex { + // ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key + // (whichever occurs before), the key management MUST be called to provide new master key(s) + // (previously stored and used keys MUST NOT be used again), or the session MUST be terminated. + // https://www.rfc-editor.org/rfc/rfc3711#section-9.2 + return nil, errExceededMaxPackets + } + // We roll over early because MSB is used for marking as encrypted s.srtcpIndex++ - if s.srtcpIndex > maxSRTCPIndex { - s.srtcpIndex = 0 - } return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc) } diff --git a/srtcp_test.go b/srtcp_test.go index f2e870d..96133ab 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -25,7 +25,7 @@ type rtcpTestCase struct { packets []rtcpTestPacket } -func rtcpTestCasesSingle() map[string]rtcpTestCase { +func rtcpTestCases() map[string]rtcpTestCase { return map[string]rtcpTestCase{ "AEAD_AES_128_GCM": { algo: ProtectionProfileAeadAes128Gcm, @@ -112,72 +112,6 @@ func rtcpTestCasesSingle() map[string]rtcpTestCase { } } -func rtcpTestCases() map[string]rtcpTestCase { - single := rtcpTestCasesSingle() - return map[string]rtcpTestCase{ - "AEAD_AES_128_GCM": single["AEAD_AES_128_GCM"], - "AES_128_CM_HMAC_SHA1_80": { - algo: ProtectionProfileAes128CmHmacSha1_80, - masterKey: single["AES_128_CM_HMAC_SHA1_80"].masterKey, - masterSalt: single["AES_128_CM_HMAC_SHA1_80"].masterSalt, - packets: []rtcpTestPacket{ - single["AES_128_CM_HMAC_SHA1_80"].packets[0], - single["AES_128_CM_HMAC_SHA1_80"].packets[1], - { - ssrc: 0x11111111, - index: 0x7ffffffe, // Upper boundary of index - pktType: rtcp.TypeSenderReport, - encrypted: []byte{ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, - 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5, - 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, - 0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a, - 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, - 0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, - 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad, - 0xff, 0xff, 0xff, 0xff, 0x5a, 0x99, 0xce, 0xed, - 0x9f, 0x2e, 0x4d, 0x9d, 0xfa, 0x97, - }, - decrypted: []byte{ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, - 0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde, - 0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75, - 0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b, - 0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e, - 0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58, - 0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d, - }, - }, - { - ssrc: 0x11111111, - index: 0x7fffffff, // Will be wrapped to 0 - pktType: rtcp.TypeSenderReport, - encrypted: []byte{ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, - 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5, - 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, - 0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a, - 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, - 0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, - 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad, - 0x80, 0x00, 0x00, 0x00, 0x7d, 0x51, 0xf8, 0x0e, - 0x56, 0x40, 0x72, 0x7b, 0x9e, 0x02, - }, - decrypted: []byte{ - 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, - 0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed, - 0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40, - 0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4, - 0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74, - 0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22, - 0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28, - }, - }, - }, - }, - } -} - func TestRTCPLifecycle(t *testing.T) { options := map[string][]ContextOption{ "Default": {}, @@ -371,7 +305,7 @@ func TestRTCPInvalidAuthTag(t *testing.T) { } func TestRTCPReplayDetectorSeparation(t *testing.T) { - for caseName, testCase := range rtcpTestCasesSingle() { + for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) @@ -409,7 +343,7 @@ func getRTCPIndex(encrypted []byte, authTagLen int) uint32 { } func TestEncryptRTCPSeparation(t *testing.T) { - for caseName, testCase := range rtcpTestCasesSingle() { + for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) @@ -462,7 +396,7 @@ func TestEncryptRTCPSeparation(t *testing.T) { } func TestRTCPDecryptShortenedPacket(t *testing.T) { - for caseName, testCase := range rtcpTestCasesSingle() { + for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { pkt := testCase.packets[0] @@ -479,3 +413,157 @@ func TestRTCPDecryptShortenedPacket(t *testing.T) { }) } } + +func TestRTCPMaxPackets(t *testing.T) { + const ssrc = 0x11111111 + testCases := map[string]rtcpTestCase{ + "AEAD_AES_128_GCM": { + algo: ProtectionProfileAeadAes128Gcm, + masterKey: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}, + masterSalt: []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab}, + packets: []rtcpTestPacket{ + { + pktType: rtcp.TypeSenderReport, + encrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0x02, 0xb6, 0xc1, 0x47, 0x92, 0xbe, 0xf0, 0xae, + 0xd9, 0x40, 0xa5, 0x1c, 0xbe, 0xec, 0xaf, 0xfc, + 0x7d, 0x86, 0x3b, 0xbb, 0x93, 0x0c, 0xb0, 0xd4, + 0xea, 0x4a, 0x3c, 0x5b, 0xd1, 0xd5, 0x47, 0xb1, + 0x1a, 0x61, 0xae, 0xa6, 0x1a, 0x0c, 0xb9, 0x14, + 0xa5, 0x16, 0x08, 0xe4, 0xfb, 0x0d, 0x15, 0xba, + 0x7f, 0x70, 0x2b, 0xb8, 0x99, 0x97, 0x91, 0xfd, + 0x53, 0x03, 0xcd, 0x57, 0xbb, 0x8f, 0x93, 0xbe, + 0xff, 0xff, 0xff, 0xff, + }, + decrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde, + 0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75, + 0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b, + 0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e, + 0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58, + 0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d, + }, + }, + { + pktType: rtcp.TypeSenderReport, + encrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0x77, 0x47, 0x0c, 0x21, 0xc2, 0xcd, 0x33, 0xa7, + 0x5a, 0x81, 0xb5, 0xb5, 0x8f, 0xe2, 0x34, 0x28, + 0x11, 0xa8, 0xa3, 0x34, 0xf8, 0x9d, 0xfc, 0xd8, + 0xcb, 0x87, 0xe2, 0x51, 0x8e, 0xae, 0xdb, 0xfd, + 0x9d, 0xf1, 0xfa, 0x18, 0xe2, 0xdc, 0x0a, 0xd4, + 0xe3, 0x06, 0x18, 0xff, 0xf7, 0x27, 0x92, 0x1f, + 0x28, 0xcd, 0x3c, 0xf8, 0xa4, 0x0a, 0x2b, 0xbb, + 0x5b, 0x1f, 0x4d, 0x1f, 0xef, 0x0e, 0xc4, 0x91, + 0x80, 0x00, 0x00, 0x01, + }, + decrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed, + 0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40, + 0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4, + 0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74, + 0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22, + 0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28, + }, + }, + }, + }, + "AES_128_CM_HMAC_SHA1_80": { + algo: ProtectionProfileAes128CmHmacSha1_80, + masterKey: []byte{0xfd, 0xa6, 0x25, 0x95, 0xd7, 0xf6, 0x92, 0x6f, 0x7d, 0x9c, 0x02, 0x4c, 0xc9, 0x20, 0x9f, 0x34}, + masterSalt: []byte{0xa9, 0x65, 0x19, 0x85, 0x54, 0x0b, 0x47, 0xbe, 0x2f, 0x27, 0xa8, 0xb8, 0x81, 0x23}, + packets: []rtcpTestPacket{ + { + encrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5, + 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, + 0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a, + 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, + 0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, + 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad, + 0xff, 0xff, 0xff, 0xff, 0x5a, 0x99, 0xce, 0xed, + 0x9f, 0x2e, 0x4d, 0x9d, 0xfa, 0x97, + }, + decrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde, + 0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75, + 0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b, + 0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e, + 0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58, + 0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d, + }, + }, + { + encrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0x12, 0x71, 0x75, 0x7a, 0xb0, 0xfd, 0x80, 0xcb, + 0x26, 0xbb, 0x54, 0x5a, 0x1c, 0x0e, 0x98, 0x09, + 0xbe, 0x60, 0x23, 0xd8, 0xe6, 0x6e, 0x68, 0xe8, + 0x6e, 0x9c, 0xb2, 0x7e, 0x02, 0xa7, 0xab, 0xfe, + 0xb3, 0xf4, 0x4c, 0x13, 0xc3, 0xac, 0x97, 0x2c, + 0x35, 0x91, 0xbb, 0x37, 0x9c, 0x86, 0x28, 0x85, + 0x80, 0x00, 0x00, 0x01, 0x89, 0x76, 0x07, 0xca, + 0xd9, 0xc4, 0xcb, 0xca, 0x66, 0xab, + }, + decrypted: []byte{ + 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, + 0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed, + 0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40, + 0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4, + 0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74, + 0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22, + 0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28, + }, + }, + }, + }, + } + + for caseName, testCase := range testCases { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + assert := assert.New(t) + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, SRTCPReplayProtection(10)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + // Upper boundary of index + encryptContext.SetIndex(ssrc, 0x7ffffffe) + + decryptResult, err := decryptContext.DecryptRTCP(nil, testCase.packets[0].encrypted, nil) + if err != nil { + t.Error(err) + } + assert.Equal(testCase.packets[0].decrypted, decryptResult, "RTCP failed to decrypt") + + encryptResult, err := encryptContext.EncryptRTCP(nil, testCase.packets[0].decrypted, nil) + if err != nil { + t.Error(err) + } + assert.Equal(testCase.packets[0].encrypted, encryptResult, "RTCP failed to encrypt") + + // Next packet will exceeds the maximum packet count + _, err = decryptContext.DecryptRTCP(nil, testCase.packets[1].encrypted, nil) + if !errors.Is(err, errDuplicated) { + t.Errorf("Expected error: '%v', got: '%v'", errDuplicated, err) + } + + _, err = encryptContext.EncryptRTCP(nil, testCase.packets[1].decrypted, nil) + if !errors.Is(err, errExceededMaxPackets) { + t.Errorf("Expected error: '%v', got: '%v'", errExceededMaxPackets, err) + } + }) + } +} diff --git a/srtp.go b/srtp.go index 0feaf0f..c8ad8a5 100644 --- a/srtp.go +++ b/srtp.go @@ -8,7 +8,10 @@ import ( func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) { s := c.getSRTPSSRCState(header.SSRC) - markAsValid, ok := s.replayDetector.Check(uint64(header.SequenceNumber)) + roc, diff, _ := s.nextRolloverCount(header.SequenceNumber) + markAsValid, ok := s.replayDetector.Check( + (uint64(roc) << 16) | uint64(header.SequenceNumber), + ) if !ok { return nil, &duplicatedError{ Proto: "srtp", SSRC: header.SSRC, Index: uint32(header.SequenceNumber), @@ -20,7 +23,6 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL return nil, err } dst = growBufferSize(dst, len(ciphertext)-authTagLen) - roc, diff := s.nextRolloverCount(header.SequenceNumber) dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) if err != nil { @@ -67,7 +69,14 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ( // Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload. func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) { s := c.getSRTPSSRCState(header.SSRC) - roc, diff := s.nextRolloverCount(header.SequenceNumber) + roc, diff, ovf := s.nextRolloverCount(header.SequenceNumber) + if ovf { + // ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key + // (whichever occurs before), the key management MUST be called to provide new master key(s) + // (previously stored and used keys MUST NOT be used again), or the session MUST be terminated. + // https://www.rfc-editor.org/rfc/rfc3711#section-9.2 + return nil, errExceededMaxPackets + } s.updateRolloverCount(header.SequenceNumber, diff) return c.cipher.encryptRTP(dst, header, payload, roc) diff --git a/srtp_test.go b/srtp_test.go index 8333b48..91091c5 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -76,68 +76,104 @@ func TestRolloverCount(t *testing.T) { s := &srtpSSRCState{ssrc: defaultSsrc} // Set initial seqnum - roc, diff := s.nextRolloverCount(65530) + roc, diff, ovf := s.nextRolloverCount(65530) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(65530, diff) // Invalid packets never update ROC - _, _ = s.nextRolloverCount(0) - _, _ = s.nextRolloverCount(0x4000) - _, _ = s.nextRolloverCount(0x8000) - _, _ = s.nextRolloverCount(0xFFFF) - _, _ = s.nextRolloverCount(0) + s.nextRolloverCount(0) + s.nextRolloverCount(0x4000) + s.nextRolloverCount(0x8000) + s.nextRolloverCount(0xFFFF) + s.nextRolloverCount(0) // We rolled over to 0 - roc, diff = s.nextRolloverCount(0) + roc, diff, ovf = s.nextRolloverCount(0) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(0, diff) - roc, diff = s.nextRolloverCount(65530) + roc, diff, ovf = s.nextRolloverCount(65530) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(65530, diff) - roc, diff = s.nextRolloverCount(5) + roc, diff, ovf = s.nextRolloverCount(5) if roc != 1 { t.Errorf("rolloverCounter was not updated when it rolled over initial, to handle out of order") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(5, diff) - _, diff = s.nextRolloverCount(6) + _, diff, _ = s.nextRolloverCount(6) s.updateRolloverCount(6, diff) - _, diff = s.nextRolloverCount(7) + _, diff, _ = s.nextRolloverCount(7) s.updateRolloverCount(7, diff) - roc, diff = s.nextRolloverCount(8) + roc, diff, _ = s.nextRolloverCount(8) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } s.updateRolloverCount(8, diff) // valid packets never update ROC - roc, diff = s.nextRolloverCount(0x4000) + roc, diff, ovf = s.nextRolloverCount(0x4000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(0x4000, diff) - roc, diff = s.nextRolloverCount(0x8000) + roc, diff, ovf = s.nextRolloverCount(0x8000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(0x8000, diff) - roc, diff = s.nextRolloverCount(0xFFFF) + roc, diff, ovf = s.nextRolloverCount(0xFFFF) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(0xFFFF, diff) - roc, _ = s.nextRolloverCount(0) + roc, _, ovf = s.nextRolloverCount(0) if roc != 2 { t.Errorf("rolloverCounter must be incremented after wrapping, got %d", roc) } + if ovf { + t.Error("Should not overflow") + } +} + +func TestRolloverCountOverflow(t *testing.T) { + s := &srtpSSRCState{ + ssrc: defaultSsrc, + index: maxROC << 16, + } + s.updateRolloverCount(0xFFFF, 0) + _, _, ovf := s.nextRolloverCount(0) + if !ovf { + t.Error("Should overflow") + } } func buildTestContext(profile ProtectionProfile, opts ...ContextOption) (*Context, error) { @@ -541,57 +577,87 @@ func BenchmarkDecryptRTP(b *testing.B) { func TestRolloverCount2(t *testing.T) { s := &srtpSSRCState{ssrc: defaultSsrc} - roc, diff := s.nextRolloverCount(30123) + roc, diff, ovf := s.nextRolloverCount(30123) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(30123, diff) - roc, diff = s.nextRolloverCount(62892) // 30123 + (1 << 15) + 1 + roc, diff, ovf = s.nextRolloverCount(62892) // 30123 + (1 << 15) + 1 if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(62892, diff) - roc, diff = s.nextRolloverCount(204) + roc, diff, ovf = s.nextRolloverCount(204) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(62892, diff) - roc, diff = s.nextRolloverCount(64535) + roc, diff, ovf = s.nextRolloverCount(64535) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(64535, diff) - roc, diff = s.nextRolloverCount(205) + roc, diff, ovf = s.nextRolloverCount(205) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(205, diff) - roc, diff = s.nextRolloverCount(1) + roc, diff, ovf = s.nextRolloverCount(1) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(1, diff) - roc, diff = s.nextRolloverCount(64532) + roc, diff, ovf = s.nextRolloverCount(64532) if roc != 0 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(64532, diff) - roc, diff = s.nextRolloverCount(65534) + roc, diff, ovf = s.nextRolloverCount(65534) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(65534, diff) - roc, diff = s.nextRolloverCount(64532) + roc, diff, ovf = s.nextRolloverCount(64532) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(65532, diff) - roc, diff = s.nextRolloverCount(205) + roc, diff, ovf = s.nextRolloverCount(205) if roc != 1 { t.Errorf("index was not updated after it crossed 0") } + if ovf { + t.Error("Should not overflow") + } s.updateRolloverCount(65532, diff) } @@ -660,3 +726,51 @@ func TestRTPDecryptShotenedPacket(t *testing.T) { }) } } + +func TestRTPMaxPackets(t *testing.T) { + profiles := map[string]ProtectionProfile{ + "CTR": profileCTR, + "GCM": profileGCM, + } + for name, profile := range profiles { + profile := profile + t.Run(name, func(t *testing.T) { + context, err := buildTestContext(profile) + if err != nil { + t.Fatal(err) + } + + context.SetROC(1, (1<<32)-1) + + pkt0 := &rtp.Packet{ + Header: rtp.Header{ + SSRC: 1, + SequenceNumber: 0xffff, + }, + Payload: []byte{0, 1}, + } + raw0, err0 := pkt0.Marshal() + if err0 != nil { + t.Fatal(err0) + } + if _, errEnc := context.EncryptRTP(nil, raw0, nil); errEnc != nil { + t.Fatal(errEnc) + } + + pkt1 := &rtp.Packet{ + Header: rtp.Header{ + SSRC: 1, + SequenceNumber: 0x0, + }, + Payload: []byte{0, 1}, + } + raw1, err1 := pkt1.Marshal() + if err1 != nil { + t.Fatal(err1) + } + if _, errEnc := context.EncryptRTP(nil, raw1, nil); !errors.Is(errEnc, errExceededMaxPackets) { + t.Fatalf("Expected error '%v', got '%v'", errExceededMaxPackets, errEnc) + } + }) + } +}