diff --git a/api.go b/api.go index 60ac72809d..baee204c71 100644 --- a/api.go +++ b/api.go @@ -7,19 +7,22 @@ import ( "github.com/pion/logging" ) -// API bundles the global functions of the WebRTC and ORTC API. -// Some of these functions are also exported globally using the -// defaultAPI object. Note that the global version of the API -// may be phased out in the future. +// API allows configuration of a PeerConnection +// with APIs that are available in the standard. This +// lets you set custom behavior via the SettingEngine, configure +// codecs via the MediaEngine and define custom media behaviors via +// Interceptors. type API struct { - settingEngine *SettingEngine - mediaEngine *MediaEngine - interceptor interceptor.Interceptor + settingEngine *SettingEngine + mediaEngine *MediaEngine + interceptorRegistry *interceptor.Registry + + interceptor interceptor.Interceptor // Generated per PeerConnection } // NewAPI Creates a new API object for keeping semi-global settings to WebRTC objects func NewAPI(options ...func(*API)) *API { - a := &API{} + a := &API{interceptor: &interceptor.NoOp{}} for _, o := range options { o(a) @@ -37,8 +40,8 @@ func NewAPI(options ...func(*API)) *API { a.mediaEngine = &MediaEngine{} } - if a.interceptor == nil { - a.interceptor = &interceptor.NoOp{} + if a.interceptorRegistry == nil { + a.interceptorRegistry = &interceptor.Registry{} } return a @@ -68,6 +71,6 @@ func WithSettingEngine(s SettingEngine) func(a *API) { // Settings should not be changed after passing the registry to an API. func WithInterceptorRegistry(interceptorRegistry *interceptor.Registry) func(a *API) { return func(a *API) { - a.interceptor = interceptorRegistry.Build() + a.interceptorRegistry = interceptorRegistry } } diff --git a/go.mod b/go.mod index 794185cd3e..ce7fcf763b 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/pion/datachannel v1.4.21 github.com/pion/dtls/v2 v2.0.9 github.com/pion/ice/v2 v2.1.12 - github.com/pion/interceptor v0.0.19 + github.com/pion/interceptor v0.1.0 github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 github.com/pion/rtcp v1.2.8 diff --git a/go.sum b/go.sum index c9349fdc69..770901c19e 100644 --- a/go.sum +++ b/go.sum @@ -43,8 +43,8 @@ github.com/pion/dtls/v2 v2.0.9 h1:7Ow+V++YSZQMYzggI0P9vLJz/hUFcffsfGMfT/Qy+u8= github.com/pion/dtls/v2 v2.0.9/go.mod h1:O0Wr7si/Zj5/EBFlDzDd6UtVxx25CE1r7XM7BQKYQho= github.com/pion/ice/v2 v2.1.12 h1:ZDBuZz+fEI7iDifZCYFVzI4p0Foy0YhdSSZ87ZtRcRE= github.com/pion/ice/v2 v2.1.12/go.mod h1:ovgYHUmwYLlRvcCLI67PnQ5YGe+upXZbGgllBDG/ktU= -github.com/pion/interceptor v0.0.19 h1:NkxrKHVH7ulrkVHTcZRJubgsF1oJeLQUvMsX1Kqm8to= -github.com/pion/interceptor v0.0.19/go.mod h1:mv0Q0oPHxjRY8xz5v85G6aIqb1Tb0G0mxrZOaewHiVo= +github.com/pion/interceptor v0.1.0 h1:SlXKaDlEvSl7cr4j8fJykzVz4UdH+7UDtcvx+u01wLU= +github.com/pion/interceptor v0.1.0/go.mod h1:j5NIl3tJJPB3u8+Z2Xz8MZs/VV6rc+If9mXEKNuFmEM= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw= @@ -52,7 +52,6 @@ github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01 github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.6/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0= -github.com/pion/rtcp v1.2.7/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo= github.com/pion/rtcp v1.2.8 h1:Cys8X6r0xxU65ESTmXkqr8eU1Q1Wx+lNkoZCUH4zD7E= github.com/pion/rtcp v1.2.8/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo= github.com/pion/rtp v1.7.0/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= diff --git a/interceptor_test.go b/interceptor_test.go index 4915752cb7..7bca7fc135 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -33,26 +33,30 @@ func TestPeerConnection_Interceptor(t *testing.T) { assert.NoError(t, m.RegisterDefaultCodecs()) ir := &interceptor.Registry{} - ir.Add(&mock_interceptor.Interceptor{ - BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { - return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { - // set extension on outgoing packet - header.Extension = true - header.ExtensionProfile = 0xBEDE - assert.NoError(t, header.SetExtension(2, []byte("foo"))) - - return writer.Write(header, payload, attributes) - }) - }, - BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { - return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { - if a == nil { - a = interceptor.Attributes{} - } - - a.Set("attribute", "value") - return reader.Read(b, a) - }) + ir.Add(&mock_interceptor.Factory{ + NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { + return &mock_interceptor.Interceptor{ + BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + // set extension on outgoing packet + header.Extension = true + header.ExtensionProfile = 0xBEDE + assert.NoError(t, header.SetExtension(2, []byte("foo"))) + + return writer.Write(header, payload, attributes) + }) + }, + BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + if a == nil { + a = interceptor.Attributes{} + } + + a.Set("attribute", "value") + return reader.Read(b, a) + }) + }, + }, nil }, }) @@ -148,7 +152,9 @@ func Test_Interceptor_BindUnbind(t *testing.T) { }, } ir := &interceptor.Registry{} - ir.Add(mockInterceptor) + ir.Add(&mock_interceptor.Factory{ + NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { return mockInterceptor, nil }, + }) sender, receiver, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).newPair(Configuration{}) assert.NoError(t, err) @@ -209,3 +215,24 @@ func Test_Interceptor_BindUnbind(t *testing.T) { t.Errorf("CloseFn is expected to be called twice, but called %d times", cnt) } } + +func Test_InterceptorRegistry_Build(t *testing.T) { + registryBuildCount := 0 + + ir := &interceptor.Registry{} + ir.Add(&mock_interceptor.Factory{ + NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { + registryBuildCount++ + return &interceptor.NoOp{}, nil + }, + }) + + peerConnectionA, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + assert.Equal(t, 2, registryBuildCount) + closePairNow(t, peerConnectionA, peerConnectionB) +} diff --git a/peerconnection.go b/peerconnection.go index 67f0c2a036..92db7d3bba 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -134,15 +134,22 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection, pc.iceConnectionState.Store(ICEConnectionStateNew) pc.connectionState.Store(PeerConnectionStateNew) - if !api.settingEngine.disableMediaEngineCopy { - pc.api = &API{ - settingEngine: api.settingEngine, - mediaEngine: api.mediaEngine.copy(), - interceptor: api.interceptor, - } + i, err := api.interceptorRegistry.Build("") + if err != nil { + return nil, err + } + + pc.api = &API{ + settingEngine: api.settingEngine, + interceptor: i, + } + + if api.settingEngine.disableMediaEngineCopy { + pc.api.mediaEngine = api.mediaEngine + } else { + pc.api.mediaEngine = api.mediaEngine.copy() } - var err error if err = pc.initConfiguration(configuration); err != nil { return nil, err } @@ -176,7 +183,7 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection, } }) - pc.interceptorRTCPWriter = api.interceptor.BindRTCPWriter(interceptor.RTCPWriterFunc(pc.writeRTCP)) + pc.interceptorRTCPWriter = pc.api.interceptor.BindRTCPWriter(interceptor.RTCPWriterFunc(pc.writeRTCP)) return pc, nil }