diff --git a/pubsublite/internal/wire/publisher.go b/pubsublite/internal/wire/publisher.go index 215288f08cb..a579ef26709 100644 --- a/pubsublite/internal/wire/publisher.go +++ b/pubsublite/internal/wire/publisher.go @@ -287,16 +287,16 @@ type routingPublisher struct { msgRouter messageRouter publishers []*singlePartitionPublisher - apiClientService + compositeService } func newRoutingPublisher(allClients apiClients, adminClient *vkit.AdminClient, msgRouterFactory *messageRouterFactory, pubFactory *singlePartitionPublisherFactory) *routingPublisher { pub := &routingPublisher{ - apiClientService: apiClientService{clients: allClients}, msgRouterFactory: msgRouterFactory, pubFactory: pubFactory, } pub.init() + pub.toClose = allClients pub.partitionWatcher = newPartitionCountWatcher(pubFactory.ctx, adminClient, pubFactory.settings, pubFactory.topicPath, pub.onPartitionCountChanged) pub.unsafeAddServices(pub.partitionWatcher) return pub diff --git a/pubsublite/internal/wire/publisher_test.go b/pubsublite/internal/wire/publisher_test.go index 47eda71f5be..ca060782e6a 100644 --- a/pubsublite/internal/wire/publisher_test.go +++ b/pubsublite/internal/wire/publisher_test.go @@ -601,8 +601,7 @@ func TestRoutingPublisherStartStop(t *testing.T) { defer mockServer.OnTestEnd() pub := newTestRoutingPublisher(t, topic, testPublishSettings(), 0) - pub.Stop() - barrier.Release() + barrier.ReleaseAfter(func() { pub.Stop() }) if gotErr := pub.WaitStopped(); gotErr != nil { t.Errorf("Stop() got err: (%v)", gotErr) diff --git a/pubsublite/internal/wire/service.go b/pubsublite/internal/wire/service.go index 2ae68d106a8..d6ec5984324 100644 --- a/pubsublite/internal/wire/service.go +++ b/pubsublite/internal/wire/service.go @@ -155,6 +155,21 @@ func (as *abstractService) unsafeUpdateStatus(targetStatus serviceStatus, err er return true } +type closeable interface { + Close() error +} + +type apiClients []closeable + +func (ac apiClients) Close() (retErr error) { + for _, c := range ac { + if err := c.Close(); retErr == nil { + retErr = err + } + } + return +} + var errChildServiceStarted = errors.New("pubsublite: dependent service must not be started") // compositeService can be embedded into other structs to manage child services. @@ -173,6 +188,9 @@ type compositeService struct { // Removed dependencies that are in the process of terminating. removed map[serviceHandle]service + // Dependencies to close when the compositeService has terminated. + toClose closeable + abstractService } @@ -267,6 +285,9 @@ func (cs *compositeService) unsafeUpdateStatus(targetStatus serviceStatus, err e close(cs.waitStarted) } if targetStatus == serviceTerminated { + if cs.toClose != nil { + cs.toClose.Close() + } close(cs.waitTerminated) } } @@ -317,39 +338,3 @@ func (cs *compositeService) onServiceStatusChange(handle serviceHandle, status s cs.unsafeUpdateStatus(serviceActive, err) } } - -type apiClient interface { - Close() error -} - -type apiClients []apiClient - -func (ac apiClients) Close() (retErr error) { - for _, c := range ac { - if err := c.Close(); retErr == nil { - retErr = err - } - } - return -} - -// A compositeService that handles closing API clients on shutdown. -type apiClientService struct { - clients apiClients - - compositeService -} - -func (acs *apiClientService) WaitStarted() error { - err := acs.compositeService.WaitStarted() - if err != nil { - acs.WaitStopped() - } - return err -} - -func (acs *apiClientService) WaitStopped() error { - err := acs.compositeService.WaitStopped() - acs.clients.Close() - return err -} diff --git a/pubsublite/internal/wire/service_test.go b/pubsublite/internal/wire/service_test.go index c6910296950..102ccfd46b8 100644 --- a/pubsublite/internal/wire/service_test.go +++ b/pubsublite/internal/wire/service_test.go @@ -16,6 +16,7 @@ package wire import ( "errors" "fmt" + "sync" "testing" "time" @@ -46,6 +47,7 @@ func (sr *testStatusChangeReceiver) OnStatusChange(handle serviceHandle, status } func (sr *testStatusChangeReceiver) VerifyStatus(t *testing.T, want serviceStatus) { + t.Helper() select { case status := <-sr.statusC: if status <= sr.lastStatus { @@ -61,6 +63,7 @@ func (sr *testStatusChangeReceiver) VerifyStatus(t *testing.T, want serviceStatu } func (sr *testStatusChangeReceiver) VerifyNoStatusChanges(t *testing.T) { + t.Helper() select { case status := <-sr.statusC: t.Errorf("%s: Unexpected service status: %d", sr.name, status) @@ -189,16 +192,39 @@ func TestServiceAddRemoveStatusChangeReceiver(t *testing.T) { }) } +type testCloseable struct { + mu sync.Mutex + closed bool +} + +func (tc *testCloseable) Close() error { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.closed = true + return nil +} + +func (tc *testCloseable) IsClosed() bool { + tc.mu.Lock() + defer tc.mu.Unlock() + return tc.closed +} + type testCompositeService struct { - receiver *testStatusChangeReceiver + receiver *testStatusChangeReceiver + closeable *testCloseable compositeService } func newTestCompositeService(name string) *testCompositeService { receiver := newTestStatusChangeReceiver(name) - ts := &testCompositeService{receiver: receiver} + ts := &testCompositeService{ + receiver: receiver, + closeable: &testCloseable{}, + } ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange) ts.init() + ts.toClose = ts.closeable return ts } @@ -226,6 +252,13 @@ func (ts *testCompositeService) RemovedLen() int { return len(ts.removed) } +func (ts *testCompositeService) VerifyClosed(t *testing.T, want bool) { + t.Helper() + if got := ts.closeable.IsClosed(); got != want { + t.Errorf("closed: got %v, want %v", got, want) + } +} + func TestCompositeServiceNormalStop(t *testing.T) { child1 := newTestService("child1") child2 := newTestService("child2") @@ -258,6 +291,7 @@ func TestCompositeServiceNormalStop(t *testing.T) { t.Errorf("AddServices() got err: %v", err) } child3.receiver.VerifyStatus(t, serviceStarting) + parent.VerifyClosed(t, false) }) t.Run("Active", func(t *testing.T) { @@ -274,6 +308,7 @@ func TestCompositeServiceNormalStop(t *testing.T) { if err := parent.WaitStarted(); err != nil { t.Errorf("compositeService.WaitStarted() got err: %v", err) } + parent.VerifyClosed(t, false) }) t.Run("Stopping", func(t *testing.T) { @@ -288,6 +323,7 @@ func TestCompositeServiceNormalStop(t *testing.T) { child1.UpdateStatus(serviceTerminated, nil) child2.UpdateStatus(serviceTerminated, nil) parent.receiver.VerifyNoStatusChanges(t) + parent.VerifyClosed(t, false) child3.UpdateStatus(serviceTerminated, nil) child1.receiver.VerifyStatus(t, serviceTerminated) @@ -297,6 +333,7 @@ func TestCompositeServiceNormalStop(t *testing.T) { if err := parent.WaitStopped(); err != nil { t.Errorf("compositeService.WaitStopped() got err: %v", err) } + parent.VerifyClosed(t, true) }) } @@ -314,6 +351,7 @@ func TestCompositeServiceErrorDuringStartup(t *testing.T) { parent.receiver.VerifyStatus(t, serviceStarting) child1.receiver.VerifyStatus(t, serviceStarting) child2.receiver.VerifyStatus(t, serviceStarting) + parent.VerifyClosed(t, false) }) t.Run("Terminating", func(t *testing.T) { @@ -325,6 +363,7 @@ func TestCompositeServiceErrorDuringStartup(t *testing.T) { // This causes parent and child2 to start terminating. parent.receiver.VerifyStatus(t, serviceTerminating) child2.receiver.VerifyStatus(t, serviceTerminating) + parent.VerifyClosed(t, false) // parent has terminated once child2 has terminated. child2.UpdateStatus(serviceTerminated, nil) @@ -333,6 +372,7 @@ func TestCompositeServiceErrorDuringStartup(t *testing.T) { if gotErr := parent.WaitStarted(); !test.ErrorEqual(gotErr, wantErr) { t.Errorf("compositeService.WaitStarted() got err: (%v), want err: (%v)", gotErr, wantErr) } + parent.VerifyClosed(t, true) }) } @@ -350,6 +390,7 @@ func TestCompositeServiceErrorWhileActive(t *testing.T) { child1.receiver.VerifyStatus(t, serviceStarting) child2.receiver.VerifyStatus(t, serviceStarting) parent.receiver.VerifyStatus(t, serviceStarting) + parent.VerifyClosed(t, false) }) t.Run("Active", func(t *testing.T) { @@ -362,6 +403,7 @@ func TestCompositeServiceErrorWhileActive(t *testing.T) { if err := parent.WaitStarted(); err != nil { t.Errorf("compositeService.WaitStarted() got err: %v", err) } + parent.VerifyClosed(t, false) }) t.Run("Terminating", func(t *testing.T) { @@ -373,6 +415,7 @@ func TestCompositeServiceErrorWhileActive(t *testing.T) { // This causes parent and child1 to start terminating. child1.receiver.VerifyStatus(t, serviceTerminating) parent.receiver.VerifyStatus(t, serviceTerminating) + parent.VerifyClosed(t, false) // parent has terminated once both children have terminated. child1.UpdateStatus(serviceTerminated, nil) @@ -383,6 +426,7 @@ func TestCompositeServiceErrorWhileActive(t *testing.T) { if gotErr := parent.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) { t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr) } + parent.VerifyClosed(t, true) }) } @@ -400,6 +444,7 @@ func TestCompositeServiceRemoveService(t *testing.T) { child1.receiver.VerifyStatus(t, serviceStarting) child2.receiver.VerifyStatus(t, serviceStarting) parent.receiver.VerifyStatus(t, serviceStarting) + parent.VerifyClosed(t, false) }) t.Run("Active", func(t *testing.T) { @@ -409,6 +454,7 @@ func TestCompositeServiceRemoveService(t *testing.T) { child1.receiver.VerifyStatus(t, serviceActive) child2.receiver.VerifyStatus(t, serviceActive) parent.receiver.VerifyStatus(t, serviceActive) + parent.VerifyClosed(t, false) }) t.Run("Remove service", func(t *testing.T) { @@ -442,6 +488,7 @@ func TestCompositeServiceRemoveService(t *testing.T) { if got, want := parent.Status(), serviceActive; got != want { t.Errorf("compositeService.Status() got %v, want %v", got, want) } + parent.VerifyClosed(t, false) }) t.Run("Terminating", func(t *testing.T) { @@ -450,6 +497,7 @@ func TestCompositeServiceRemoveService(t *testing.T) { child2.receiver.VerifyStatus(t, serviceTerminating) parent.receiver.VerifyStatus(t, serviceTerminating) + parent.VerifyClosed(t, false) child2.UpdateStatus(serviceTerminated, nil) @@ -458,6 +506,7 @@ func TestCompositeServiceRemoveService(t *testing.T) { if err := parent.WaitStopped(); err != nil { t.Errorf("compositeService.WaitStopped() got err: %v", err) } + parent.VerifyClosed(t, true) if got, want := parent.DependenciesLen(), 1; got != want { t.Errorf("compositeService.dependencies: got len %d, want %d", got, want) @@ -500,6 +549,7 @@ func TestCompositeServiceTree(t *testing.T) { intermediate1.receiver.VerifyStatus(t, serviceStarting) intermediate2.receiver.VerifyStatus(t, serviceStarting) root.receiver.VerifyStatus(t, serviceStarting) + root.VerifyClosed(t, false) }) t.Run("Active", func(t *testing.T) { @@ -519,6 +569,7 @@ func TestCompositeServiceTree(t *testing.T) { if err := root.WaitStarted(); err != nil { t.Errorf("compositeService.WaitStarted() got err: %v", err) } + root.VerifyClosed(t, false) }) t.Run("Leaf fails", func(t *testing.T) { @@ -533,6 +584,7 @@ func TestCompositeServiceTree(t *testing.T) { intermediate1.receiver.VerifyStatus(t, serviceTerminating) intermediate2.receiver.VerifyStatus(t, serviceTerminating) root.receiver.VerifyStatus(t, serviceTerminating) + root.VerifyClosed(t, false) }) t.Run("Terminated", func(t *testing.T) { @@ -547,6 +599,7 @@ func TestCompositeServiceTree(t *testing.T) { intermediate1.receiver.VerifyStatus(t, serviceTerminated) intermediate2.receiver.VerifyStatus(t, serviceTerminated) root.receiver.VerifyStatus(t, serviceTerminated) + root.VerifyClosed(t, true) if gotErr := root.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) { t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr) diff --git a/pubsublite/internal/wire/service_util_test.go b/pubsublite/internal/wire/service_util_test.go index 60a138fb2bc..5e60ff6ee79 100644 --- a/pubsublite/internal/wire/service_util_test.go +++ b/pubsublite/internal/wire/service_util_test.go @@ -38,7 +38,7 @@ type serviceTestProxy struct { terminated chan struct{} } -func (sp *serviceTestProxy) initAndStart(t *testing.T, s service, name string, clients ...apiClient) { +func (sp *serviceTestProxy) initAndStart(t *testing.T, s service, name string, clients ...closeable) { sp.t = t sp.service = s sp.name = name diff --git a/pubsublite/internal/wire/subscriber.go b/pubsublite/internal/wire/subscriber.go index c357b718ad5..da3f4edd8c1 100644 --- a/pubsublite/internal/wire/subscriber.go +++ b/pubsublite/internal/wire/subscriber.go @@ -417,15 +417,15 @@ type multiPartitionSubscriber struct { // Immutable after creation. subscribers map[int]*singlePartitionSubscriber - apiClientService + compositeService } func newMultiPartitionSubscriber(allClients apiClients, subFactory *singlePartitionSubscriberFactory) *multiPartitionSubscriber { ms := &multiPartitionSubscriber{ - subscribers: make(map[int]*singlePartitionSubscriber), - apiClientService: apiClientService{clients: allClients}, + subscribers: make(map[int]*singlePartitionSubscriber), } ms.init() + ms.toClose = allClients for _, partition := range subFactory.settings.Partitions { subscriber := subFactory.New(partition) @@ -468,18 +468,18 @@ type assigningSubscriber struct { // Subscribers keyed by partition number. Updated as assignments change. subscribers map[int]*singlePartitionSubscriber - apiClientService + compositeService } func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, reassignmentHandler ReassignmentHandlerFunc, genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) { as := &assigningSubscriber{ - apiClientService: apiClientService{clients: allClients}, reassignmentHandler: reassignmentHandler, subFactory: subFactory, subscribers: make(map[int]*singlePartitionSubscriber), } as.init() + as.toClose = allClients assigner, err := newAssigner(subFactory.ctx, assignmentClient, genUUID, subFactory.settings, subFactory.subscriptionPath, as.handleAssignment) if err != nil {