From a603db1d0eb411c96606843033eb27a9e98b8be6 Mon Sep 17 00:00:00 2001 From: Tim Ramlot <42113979+inteon@users.noreply.github.com> Date: Tue, 23 Apr 2024 10:19:51 +0200 Subject: [PATCH] only allow starting a source once Signed-off-by: Tim Ramlot <42113979+inteon@users.noreply.github.com> --- pkg/controller/controller_test.go | 2 +- pkg/internal/controller/controller_test.go | 6 +- pkg/internal/source/kind.go | 31 +- pkg/source/example_test.go | 2 +- pkg/source/source.go | 323 ++++++++++++++++----- pkg/source/source_test.go | 16 +- 6 files changed, 288 insertions(+), 92 deletions(-) diff --git a/pkg/controller/controller_test.go b/pkg/controller/controller_test.go index 0454cb4b90..237565fa07 100644 --- a/pkg/controller/controller_test.go +++ b/pkg/controller/controller_test.go @@ -79,7 +79,7 @@ var _ = Describe("controller.Controller", func() { ctx, cancel := context.WithCancel(context.Background()) watchChan := make(chan event.GenericEvent, 1) - watch := source.Channel(watchChan, &handler.EnqueueRequestForObject{}) + watch := source.Channel(source.NewChannelBroadcaster(watchChan), &handler.EnqueueRequestForObject{}) watchChan <- event.GenericEvent{Object: &corev1.Pod{}} reconcileStarted := make(chan struct{}) diff --git a/pkg/internal/controller/controller_test.go b/pkg/internal/controller/controller_test.go index 2e1842d907..d47cf5e249 100644 --- a/pkg/internal/controller/controller_test.go +++ b/pkg/internal/controller/controller_test.go @@ -227,7 +227,7 @@ var _ = Describe("controller", func() { } ins := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ GenericFunc: func(ctx context.Context, evt event.GenericEvent, q workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -248,7 +248,7 @@ var _ = Describe("controller", func() { <-processed }) - It("should error when channel source is not specified", func() { + It("should error when ChannelBroadcaster is not specified", func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -257,7 +257,7 @@ var _ = Describe("controller", func() { e := ctrl.Start(ctx) Expect(e).To(HaveOccurred()) - Expect(e.Error()).To(ContainSubstring("must specify Channel.Source")) + Expect(e.Error()).To(ContainSubstring("must create Channel with a non-nil broadcaster")) }) It("should call Start on sources with the appropriate EventHandler, Queue, and Predicates", func() { diff --git a/pkg/internal/source/kind.go b/pkg/internal/source/kind.go index 03431d1d24..d0b9c81a1e 100644 --- a/pkg/internal/source/kind.go +++ b/pkg/internal/source/kind.go @@ -1,3 +1,19 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package internal import ( @@ -5,6 +21,7 @@ import ( "errors" "fmt" "reflect" + "sync" "time" "k8s.io/apimachinery/pkg/api/meta" @@ -30,6 +47,9 @@ type Kind[T client.Object] struct { Predicates []predicate.TypedPredicate[T] + mu sync.RWMutex + isStarted bool + // startedErr may contain an error if one was encountered during startup. If its closed and does not // contain an error, startup and syncing finished. startedErr chan error @@ -40,14 +60,21 @@ type Kind[T client.Object] struct { // to enqueue reconcile.Requests. func (ks *Kind[T]) Start(ctx context.Context, queue workqueue.RateLimitingInterface) error { if isNil(ks.Type) { - return fmt.Errorf("must create Kind with a non-nil object") + return fmt.Errorf("must create Kind with a non-nil type") } if isNil(ks.Cache) { return fmt.Errorf("must create Kind with a non-nil cache") } if isNil(ks.Handler) { - return errors.New("must create Kind with non-nil handler") + return errors.New("must create Kind with a non-nil handler") + } + + ks.mu.Lock() + defer ks.mu.Unlock() + if ks.isStarted { + return fmt.Errorf("cannot start an already started Kind source") } + ks.isStarted = true // cache.GetInformer will block until its context is cancelled if the cache was already started and it can not // sync that informer (most commonly due to RBAC issues). diff --git a/pkg/source/example_test.go b/pkg/source/example_test.go index b596ff0a0a..6ba5acacdc 100644 --- a/pkg/source/example_test.go +++ b/pkg/source/example_test.go @@ -44,7 +44,7 @@ func ExampleChannel() { err := ctrl.Watch( source.Channel( - events, + source.NewChannelBroadcaster(events), &handler.EnqueueRequestForObject{}, ), ) diff --git a/pkg/source/source.go b/pkg/source/source.go index 26e53022bf..adb3b819c3 100644 --- a/pkg/source/source.go +++ b/pkg/source/source.go @@ -23,14 +23,13 @@ import ( "sync" "k8s.io/client-go/util/workqueue" - "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/handler" internal "sigs.k8s.io/controller-runtime/pkg/internal/source" + "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/cache" - "sigs.k8s.io/controller-runtime/pkg/predicate" ) // Source is a source of events (e.g. Create, Update, Delete operations on Kubernetes Objects, Webhook callbacks, etc) @@ -54,6 +53,9 @@ type SyncingSource interface { WaitForSync(ctx context.Context) error } +var _ Source = &internal.Kind[client.Object]{} +var _ SyncingSource = &internal.Kind[client.Object]{} + // Kind creates a KindSource with the given cache provider. func Kind[T client.Object](cache cache.Cache, object T, handler handler.TypedEventHandler[T], predicates ...predicate.TypedPredicate[T]) SyncingSource { return &internal.Kind[T]{ @@ -80,17 +82,18 @@ func WithPredicates[T any](p ...predicate.TypedPredicate[T]) ChannelOpt[T] { // default, the buffer size is 1024. func WithBufferSize[T any](bufferSize int) ChannelOpt[T] { return func(c *channel[T]) { - c.bufferSize = &bufferSize + c.bufferSize = bufferSize } } // Channel is used to provide a source of events originating outside the cluster // (e.g. GitHub Webhook callback). Channel requires the user to wire the external // source (e.g. http handler) to write GenericEvents to the underlying channel. -func Channel[T any](source <-chan event.TypedGenericEvent[T], handler handler.TypedEventHandler[T], opts ...ChannelOpt[T]) Source { +func Channel[T any](broadcaster *channelBroadcaster[T], handler handler.TypedEventHandler[T], opts ...ChannelOpt[T]) Source { c := &channel[T]{ - source: source, - handler: handler, + broadcaster: broadcaster, + handler: handler, + bufferSize: 1024, } for _, opt := range opts { opt(c) @@ -99,24 +102,21 @@ func Channel[T any](source <-chan event.TypedGenericEvent[T], handler handler.Ty return c } -type channel[T any] struct { - // once ensures the event distribution goroutine will be performed only once - once sync.Once +var _ Source = &channel[string]{} - // source is the source channel to fetch GenericEvents - source <-chan event.TypedGenericEvent[T] +type channel[T any] struct { + // broadcaster contains the source channel for events. + broadcaster *channelBroadcaster[T] handler handler.TypedEventHandler[T] predicates []predicate.TypedPredicate[T] - bufferSize *int - - // dest is the destination channels of the added event handlers - dest []chan event.TypedGenericEvent[T] + bufferSize int - // destLock is to ensure the destination channels are safely added/removed - destLock sync.Mutex + mu sync.Mutex + // isStarted is true if the source has been started. A source can only be started once. + isStarted bool } func (cs *channel[T]) String() string { @@ -129,113 +129,282 @@ func (cs *channel[T]) Start( queue workqueue.RateLimitingInterface, ) error { // Source should have been specified by the user. - if cs.source == nil { - return fmt.Errorf("must specify Channel.Source") + if cs.broadcaster == nil { + return fmt.Errorf("must create Channel with a non-nil broadcaster") } if cs.handler == nil { - return errors.New("must specify Channel.Handler") + return errors.New("must create Channel with a non-nil handler") + } + if cs.bufferSize == 0 { + return errors.New("must create Channel with a >0 bufferSize") } - if cs.bufferSize == nil { - cs.bufferSize = ptr.To(1024) + cs.mu.Lock() + defer cs.mu.Unlock() + if cs.isStarted { + return fmt.Errorf("cannot start an already started Channel source") } + cs.isStarted = true - dst := make(chan event.TypedGenericEvent[T], *cs.bufferSize) + // Create a destination channel for the event handler + // and add it to the list of destinations + destination := make(chan event.TypedGenericEvent[T], cs.bufferSize) + cs.broadcaster.AddListener(destination) - cs.destLock.Lock() - cs.dest = append(cs.dest, dst) - cs.destLock.Unlock() + go func() { + // Remove the listener and wait for the broadcaster + // to stop sending events to the destination channel. + defer cs.broadcaster.RemoveListener(destination) + + cs.processReceivedEvents( + ctx, + destination, + queue, + cs.handler, + cs.predicates, + ) + }() - cs.once.Do(func() { - // Distribute GenericEvents to all EventHandler / Queue pairs Watching this source - go cs.syncLoop(ctx) - }) + return nil +} - go func() { - for evt := range dst { - shouldHandle := true - for _, p := range cs.predicates { - if !p.Generic(evt) { - shouldHandle = false - break - } +func (cs *channel[T]) processReceivedEvents( + ctx context.Context, + destination <-chan event.TypedGenericEvent[T], + queue workqueue.RateLimitingInterface, + eventHandler handler.TypedEventHandler[T], + predicates []predicate.TypedPredicate[T], +) { +eventloop: + for { + select { + case <-ctx.Done(): + return + case event, stillOpen := <-destination: + if !stillOpen { + return } - if shouldHandle { - func() { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - cs.handler.Generic(ctx, evt, queue) - }() + // Check predicates against the event first + // and continue the outer loop if any of them fail. + for _, p := range predicates { + if !p.Generic(event) { + continue eventloop + } } + + // Call the event handler with the event. + eventHandler.Generic(ctx, event, queue) } - }() + } +} - return nil +// NewChannelBroadcaster creates a new ChannelBroadcaster for the given channel. +// A ChannelBroadcaster is a wrapper around a channel that allows multiple listeners to all +// receive the events from the channel. +func NewChannelBroadcaster[T any](source <-chan event.TypedGenericEvent[T]) *channelBroadcaster[T] { + return &channelBroadcaster[T]{ + source: source, + } } -func (cs *channel[T]) doStop() { - cs.destLock.Lock() - defer cs.destLock.Unlock() +// ChannelBroadcaster is a wrapper around a channel that allows multiple listeners to all +// receive the events from the channel. +type channelBroadcaster[T any] struct { + source <-chan event.TypedGenericEvent[T] - for _, dst := range cs.dest { - close(dst) - } + mu sync.Mutex + rcCount uint + managementCh chan managementMsg[T] + doneCh chan struct{} } -func (cs *channel[T]) distribute(evt event.TypedGenericEvent[T]) { - cs.destLock.Lock() - defer cs.destLock.Unlock() - - for _, dst := range cs.dest { - // We cannot make it under goroutine here, or we'll meet the - // race condition of writing message to closed channels. - // To avoid blocking, the dest channels are expected to be of - // proper buffer size. If we still see it blocked, then - // the controller is thought to be in an abnormal state. - dst <- evt +type managementOperation bool + +const ( + addChannel managementOperation = true + removeChannel managementOperation = false +) + +type managementMsg[T any] struct { + operation managementOperation + ch chan event.TypedGenericEvent[T] +} + +// AddListener adds a new listener to the ChannelBroadcaster. Each listener +// will receive all events from the source channel. All listeners have to be +// removed using RemoveListener before the ChannelBroadcaster can be garbage +// collected. +func (sc *channelBroadcaster[T]) AddListener(ch chan event.TypedGenericEvent[T]) { + var managementCh chan managementMsg[T] + var doneCh chan struct{} + isFirst := false + func() { + sc.mu.Lock() + defer sc.mu.Unlock() + + isFirst = sc.rcCount == 0 + sc.rcCount++ + + if isFirst { + sc.managementCh = make(chan managementMsg[T]) + sc.doneCh = make(chan struct{}) + } + + managementCh = sc.managementCh + doneCh = sc.doneCh + }() + + if isFirst { + go startLoop(sc.Source, managementCh, doneCh) + } + + // If the goroutine is not yet stopped, send a message to add the + // destination channel. The routine might be stopped already because + // the source channel was closed. + select { + case <-doneCh: + default: + managementCh <- managementMsg[T]{ + operation: addChannel, + ch: ch, + } } } -func (cs *channel[T]) syncLoop(ctx context.Context) { +func startLoop[T any]( + source <-chan event.TypedGenericEvent[T], + managementCh chan managementMsg[T], + doneCh chan struct{}, +) { + defer close(doneCh) + + var destinations []chan event.TypedGenericEvent[T] + + // Close all remaining destinations in case the Source channel is closed. + defer func() { + for _, dst := range destinations { + close(dst) + } + }() + + // Wait for the first destination to be added before starting the loop. + for len(destinations) == 0 { + managementMsg := <-managementCh + if managementMsg.operation == addChannel { + destinations = append(destinations, managementMsg.ch) + } + } + for { select { - case <-ctx.Done(): - // Close destination channels - cs.doStop() - return - case evt, stillOpen := <-cs.source: + case msg := <-managementCh: + + switch msg.operation { + case addChannel: + destinations = append(destinations, msg.ch) + case removeChannel: + SearchLoop: + for i, dst := range destinations { + if dst == msg.ch { + destinations = append(destinations[:i], destinations[i+1:]...) + close(dst) + break SearchLoop + } + } + + if len(destinations) == 0 { + return + } + } + + case evt, stillOpen := <-source: if !stillOpen { - // if the source channel is closed, we're never gonna get - // anything more on it, so stop & bail - cs.doStop() return } - cs.distribute(evt) + + for _, dst := range destinations { + // We cannot make it under goroutine here, or we'll meet the + // race condition of writing message to closed channels. + // To avoid blocking, the dest channels are expected to be of + // proper buffer size. If we still see it blocked, then + // the controller is thought to be in an abnormal state. + dst <- evt + } } } } +// RemoveListener removes a listener from the ChannelBroadcaster. The listener +// will no longer receive events from the source channel. If this is the last +// listener, this function will block until the ChannelBroadcaster's is stopped. +func (sc *channelBroadcaster[T]) RemoveListener(ch chan event.TypedGenericEvent[T]) { + var managementCh chan managementMsg[T] + var doneCh chan struct{} + isLast := false + func() { + sc.mu.Lock() + defer sc.mu.Unlock() + + sc.rcCount-- + isLast = sc.rcCount == 0 + + managementCh = sc.managementCh + doneCh = sc.doneCh + }() + + // If the goroutine is not yet stopped, send a message to remove the + // destination channel. The routine might be stopped already because + // the source channel was closed. + select { + case <-doneCh: + default: + managementCh <- managementMsg[T]{ + operation: removeChannel, + ch: ch, + } + } + + // Wait for the doneCh to be closed (in case we are the last one) + if isLast { + <-doneCh + } + + // Wait for the destination channel to be closed. + <-ch +} + +var _ Source = &Informer{} + // Informer is used to provide a source of events originating inside the cluster from Watches (e.g. Pod Create). type Informer struct { // Informer is the controller-runtime Informer Informer cache.Informer Handler handler.EventHandler Predicates []predicate.Predicate -} -var _ Source = &Informer{} + mu sync.Mutex + // isStarted is true if the source has been started. A source can only be started once. + isStarted bool +} // Start is internal and should be called only by the Controller to register an EventHandler with the Informer // to enqueue reconcile.Requests. func (is *Informer) Start(ctx context.Context, queue workqueue.RateLimitingInterface) error { // Informer should have been specified by the user. if is.Informer == nil { - return fmt.Errorf("must specify Informer.Informer") + return fmt.Errorf("must create Informer with a non-nil informer") } if is.Handler == nil { - return errors.New("must specify Informer.Handler") + return errors.New("must create Informer with a non-nil handler") + } + + is.mu.Lock() + defer is.mu.Unlock() + if is.isStarted { + return fmt.Errorf("cannot start an already started Informer source") } + is.isStarted = true _, err := is.Informer.AddEventHandler(internal.NewEventHandler(ctx, queue, is.Handler, is.Predicates).HandlerFuncs()) if err != nil { diff --git a/pkg/source/source_test.go b/pkg/source/source_test.go index d30d5ae5c7..f1522c7f48 100644 --- a/pkg/source/source_test.go +++ b/pkg/source/source_test.go @@ -191,13 +191,13 @@ var _ = Describe("Source", func() { instance := source.Kind[client.Object](ic, nil, nil) err := instance.Start(ctx, nil) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil object")) + Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil type")) }) It("should return an error from Start if a handler was not provided", func() { instance := source.Kind(ic, &corev1.Pod{}, nil) err := instance.Start(ctx, nil) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("must create Kind with non-nil handler")) + Expect(err.Error()).To(ContainSubstring("must create Kind with a non-nil handler")) }) It("should return an error if syncing fails", func() { @@ -295,7 +295,7 @@ var _ = Describe("Source", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") instance := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -337,7 +337,7 @@ var _ = Describe("Source", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") // Add a handler to get distribution blocked instance := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -395,7 +395,7 @@ var _ = Describe("Source", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") // Add a handler to get distribution blocked instance := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -438,7 +438,7 @@ var _ = Describe("Source", func() { processed := make(chan struct{}) defer close(processed) src := source.Channel( - ch, + source.NewChannelBroadcaster(ch), handler.Funcs{ CreateFunc: func(context.Context, event.CreateEvent, workqueue.RateLimitingInterface) { defer GinkgoRecover() @@ -467,11 +467,11 @@ var _ = Describe("Source", func() { Eventually(processed).Should(Receive()) Consistently(processed).ShouldNot(Receive()) }) - It("should get error if no source specified", func() { + It("should get error if no Broadcaster specified", func() { q := workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "test") instance := source.Channel[string](nil, nil /*no source specified*/) err := instance.Start(ctx, q) - Expect(err).To(Equal(fmt.Errorf("must specify Channel.Source"))) + Expect(err).To(Equal(fmt.Errorf("must create Channel with a non-nil broadcaster"))) }) }) })