diff --git a/kv.go b/kv.go index e55250144..5fd65035b 100644 --- a/kv.go +++ b/kv.go @@ -20,7 +20,7 @@ import ( "regexp" "strconv" "strings" - "sync" + "sync/atomic" "time" ) @@ -594,9 +594,9 @@ func (kv *kvs) History(key string, opts ...WatchOpt) ([]KeyValueEntry, error) { // Implementation for Watch type watcher struct { - mu sync.Mutex updates chan KeyValueEntry sub *Subscription + closed uint32 } // Updates returns the interior channel. @@ -604,22 +604,17 @@ func (w *watcher) Updates() <-chan KeyValueEntry { if w == nil { return nil } - w.mu.Lock() - defer w.mu.Unlock() return w.updates } -// close the update chan. func (w *watcher) close() { - if w == nil { - return - } - w.mu.Lock() - if w.updates != nil { + if atomic.CompareAndSwapUint32(&w.closed, 0, 1) { close(w.updates) - w.updates = nil } - w.mu.Unlock() +} + +func (w *watcher) isClosed() bool { + return atomic.LoadUint32(&w.closed) > 0 } // Stop will unsubscribe from the watcher. @@ -627,8 +622,9 @@ func (w *watcher) Stop() error { if w == nil { return nil } + err := w.sub.Unsubscribe() w.close() - return w.sub.Unsubscribe() + return err } // WatchAll watches all keys. @@ -690,11 +686,7 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { delta: delta, op: op, } - w.mu.Lock() - if w.updates != nil { - w.updates <- entry - } - w.mu.Unlock() + w.updates <- entry } // Check if done and initial values. if !initDoneMarker { @@ -705,11 +697,7 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { } if received > initPending || delta == 0 { initDoneMarker = true - w.mu.Lock() - if w.updates != nil { - w.updates <- nil - } - w.mu.Unlock() + w.updates <- nil } } } diff --git a/nats.go b/nats.go index fb9c2d0c3..877427b76 100644 --- a/nats.go +++ b/nats.go @@ -3929,7 +3929,8 @@ func (nc *Conn) removeSub(s *Subscription) { s.mch = nil // If JS subscription then stop HB timer. - if jsi := s.jsi; jsi != nil { + jsi := s.jsi + if jsi != nil { if jsi.hbc != nil { jsi.hbc.Stop() jsi.hbc = nil @@ -3938,11 +3939,6 @@ func (nc *Conn) removeSub(s *Subscription) { jsi.csfct.Stop() jsi.csfct = nil } - // Check on any watcher. If we have one close the update chan. - if jsi.w != nil { - jsi.w.close() - jsi.w = nil - } } // Mark as invalid @@ -3950,6 +3946,13 @@ func (nc *Conn) removeSub(s *Subscription) { if s.pCond != nil { s.pCond.Broadcast() } + + // Check for watchers. + if jsi != nil && jsi.w != nil { + // Check on any watcher. If we have one close the update chan. + jsi.w.close() + jsi.w = nil + } } // SubscriptionType is the type of the Subscription.