diff --git a/core/discov/internal/registry.go b/core/discov/internal/registry.go index 9fcb5aabbde3..7a4134e3ae15 100644 --- a/core/discov/internal/registry.go +++ b/core/discov/internal/registry.go @@ -59,6 +59,15 @@ func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener, exa return c.monitor(key, l, exactMatch) } +func (r *Registry) UnMonitor(endpoints []string, key string, l UpdateListener) { + c, exists := r.getCluster(endpoints) + if !exists { + return + } + + c.unMonitor(key, l) +} + func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) { clusterKey := getClusterKey(endpoints) r.lock.RLock() @@ -88,6 +97,8 @@ type cluster struct { done chan lang.PlaceholderType lock sync.RWMutex exactMatch bool + watchCtx map[string]context.CancelFunc + watchFlag map[string]bool } func newCluster(endpoints []string) *cluster { @@ -98,6 +109,8 @@ func newCluster(endpoints []string) *cluster { listeners: make(map[string][]UpdateListener), watchGroup: threading.NewRoutineGroup(), done: make(chan lang.PlaceholderType), + watchCtx: make(map[string]context.CancelFunc), + watchFlag: make(map[string]bool), } } @@ -260,19 +273,42 @@ func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error { c.exactMatch = exactMatch c.lock.Unlock() - cli, err := c.getClient() - if err != nil { - return err - } + if !c.watchFlag[key] { + cli, err := c.getClient() + if err != nil { + return err + } - rev := c.load(cli, key) - c.watchGroup.Run(func() { - c.watch(cli, key, rev) - }) + rev := c.load(cli, key) + c.watchGroup.Run(func() { + c.watch(cli, key, rev) + }) + } return nil } +func (c *cluster) unMonitor(key string, l UpdateListener) { + c.lock.Lock() + defer c.lock.Unlock() + + listeners := c.listeners[key] + for i, listener := range listeners { + if listener == l { + c.listeners[key] = append(listeners[:i], listeners[i+1:]...) + break + } + } + + if len(c.listeners[key]) == 0 && c.watchFlag[key] { + if cancel, ok := c.watchCtx[key]; ok { + cancel() + delete(c.watchCtx, key) + } + c.watchFlag[key] = false + } +} + func (c *cluster) newClient() (EtcdClient, error) { cli, err := NewClient(c.endpoints) if err != nil { @@ -294,6 +330,7 @@ func (c *cluster) reload(cli EtcdClient) { for k := range c.listeners { keys = append(keys, k) } + c.clearWatch() c.lock.Unlock() for _, key := range keys { @@ -306,8 +343,9 @@ func (c *cluster) reload(cli EtcdClient) { } func (c *cluster) watch(cli EtcdClient, key string, rev int64) { + ctx := c.addWatch(key, cli) for { - err := c.watchStream(cli, key, rev) + err := c.watchStream(cli, key, rev, ctx) if err == nil { return } @@ -322,7 +360,7 @@ func (c *cluster) watch(cli EtcdClient, key string, rev int64) { } } -func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { +func (c *cluster) watchStream(cli EtcdClient, key string, rev int64, ctx context.Context) error { var ( rch clientv3.WatchChan ops []clientv3.OpOption @@ -336,7 +374,7 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { ops = append(ops, clientv3.WithRev(rev+1)) } - rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), watchKey, ops...) + rch = cli.Watch(clientv3.WithRequireLeader(ctx), watchKey, ops...) for { select { @@ -354,6 +392,8 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { c.handleWatchEvents(key, wresp.Events) case <-c.done: return nil + case <-ctx.Done(): + return nil } } } @@ -366,6 +406,23 @@ func (c *cluster) watchConnState(cli EtcdClient) { watcher.watch(cli.ActiveConnection()) } +func (c *cluster) addWatch(key string, cli EtcdClient) context.Context { + ctx, cancel := context.WithCancel(cli.Ctx()) + c.lock.Lock() + c.watchCtx[key] = cancel + c.watchFlag[key] = true + c.lock.Unlock() + return ctx +} + +func (c *cluster) clearWatch() { + for _, cancel := range c.watchCtx { + cancel() + } + c.watchCtx = make(map[string]context.CancelFunc) + c.watchFlag = make(map[string]bool) +} + // DialClient dials an etcd cluster with given endpoints. func DialClient(endpoints []string) (EtcdClient, error) { cfg := clientv3.Config{ diff --git a/core/discov/internal/registry_test.go b/core/discov/internal/registry_test.go index bb9fd629b3bd..197aad5a4de9 100644 --- a/core/discov/internal/registry_test.go +++ b/core/discov/internal/registry_test.go @@ -160,8 +160,10 @@ func TestCluster_Watch(t *testing.T) { var wg sync.WaitGroup wg.Add(1) c := &cluster{ - listeners: make(map[string][]UpdateListener), values: make(map[string]map[string]string), + listeners: make(map[string][]UpdateListener), + watchCtx: make(map[string]context.CancelFunc), + watchFlag: make(map[string]bool), } listener := NewMockUpdateListener(ctrl) c.listeners["any"] = []UpdateListener{listener} @@ -211,7 +213,7 @@ func TestClusterWatch_RespFailures(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := new(cluster) + c := newCluster([]string{}) c.done = make(chan lang.PlaceholderType) go func() { ch <- resp @@ -231,7 +233,7 @@ func TestClusterWatch_CloseChan(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := new(cluster) + c := newCluster([]string{}) c.done = make(chan lang.PlaceholderType) go func() { close(ch) @@ -240,6 +242,37 @@ func TestClusterWatch_CloseChan(t *testing.T) { c.watch(cli, "any", 0) } +func TestClusterWatch_CtxCancel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + ch := make(chan clientv3.WatchResponse) + cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() + ctx, cancelFunc := context.WithCancel(context.Background()) + cli.EXPECT().Ctx().Return(ctx).AnyTimes() + c := newCluster([]string{}) + c.done = make(chan lang.PlaceholderType) + go func() { + cancelFunc() + close(ch) + }() + c.watch(cli, "any", 0) +} + +func TestCluster_ClearWatch(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := &cluster{ + watchCtx: map[string]context.CancelFunc{"foo": cancel}, + watchFlag: map[string]bool{"foo": true}, + } + c.clearWatch() + assert.Equal(t, ctx.Err(), context.Canceled) + assert.Equal(t, 0, len(c.watchCtx)) + assert.Equal(t, 0, len(c.watchFlag)) +} + func TestValueOnlyContext(t *testing.T) { ctx := contextx.ValueOnlyFrom(context.Background()) ctx.Done() @@ -286,12 +319,38 @@ func TestRegistry_Monitor(t *testing.T) { "bar": "baz", }, }, + watchCtx: map[string]context.CancelFunc{}, + watchFlag: map[string]bool{}, }, } GetRegistry().lock.Unlock() assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener), false)) } +func TestRegistry_UnMonitor(t *testing.T) { + svr, err := mockserver.StartMockServers(1) + assert.NoError(t, err) + svr.StartAt(0) + + endpoints := []string{svr.Servers[0].Address} + l := new(mockListener) + GetRegistry().lock.Lock() + GetRegistry().clusters = map[string]*cluster{ + getClusterKey(endpoints): { + listeners: map[string][]UpdateListener{"foo": {l}}, + values: map[string]map[string]string{ + "foo": { + "bar": "baz", + }, + }, + watchFlag: map[string]bool{"foo": true}, + watchCtx: map[string]context.CancelFunc{"foo": func() {}}, + }, + } + GetRegistry().lock.Unlock() + GetRegistry().UnMonitor(endpoints, "foo", l) +} + type mockListener struct { } diff --git a/core/discov/subscriber.go b/core/discov/subscriber.go index 08f89a601ff7..515ea71d93a8 100644 --- a/core/discov/subscriber.go +++ b/core/discov/subscriber.go @@ -18,6 +18,7 @@ type ( endpoints []string exclusive bool exactMatch bool + key string items *container } ) @@ -29,6 +30,7 @@ type ( func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) { sub := &Subscriber{ endpoints: endpoints, + key: key, } for _, opt := range opts { opt(sub) @@ -52,6 +54,11 @@ func (s *Subscriber) Values() []string { return s.items.getValues() } +// Close the subscriber created watch goroutine and remove listener +func (s *Subscriber) Close() { + internal.GetRegistry().UnMonitor(s.endpoints, s.key, s.items) +} + // Exclusive means that key value can only be 1:1, // which means later added value will remove the keys associated with the same value previously. func Exclusive() SubOption { diff --git a/core/discov/subscriber_test.go b/core/discov/subscriber_test.go index 6dce7cec1c60..1f760979f4bc 100644 --- a/core/discov/subscriber_test.go +++ b/core/discov/subscriber_test.go @@ -214,6 +214,18 @@ func TestSubscriber(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&count)) } +func TestSubscriberClos(t *testing.T) { + l := newContainer(false) + sub := &Subscriber{ + endpoints: []string{"localhost:2379"}, + key: "foo", + items: l, + } + _ = internal.GetRegistry().Monitor(sub.endpoints, sub.key, l, false) + sub.Close() + assert.Empty(t, sub.items.listeners) +} + func TestWithSubEtcdAccount(t *testing.T) { endpoints := []string{"localhost:2379"} user := stringx.Rand() diff --git a/zrpc/resolver/internal/discovbuilder.go b/zrpc/resolver/internal/discovbuilder.go index ed377d138fa1..5a91ee73e40c 100644 --- a/zrpc/resolver/internal/discovbuilder.go +++ b/zrpc/resolver/internal/discovbuilder.go @@ -38,7 +38,7 @@ func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ sub.AddListener(update) update() - return &nopResolver{cc: cc}, nil + return &nopResolver{cc: cc, closeFunc: func() { sub.Close() }}, nil } func (b *discovBuilder) Scheme() string { diff --git a/zrpc/resolver/internal/resolver.go b/zrpc/resolver/internal/resolver.go index 7868eca8cecf..e04d65d875c0 100644 --- a/zrpc/resolver/internal/resolver.go +++ b/zrpc/resolver/internal/resolver.go @@ -37,10 +37,14 @@ func register() { } type nopResolver struct { - cc resolver.ClientConn + cc resolver.ClientConn + closeFunc func() } func (r *nopResolver) Close() { + if r.closeFunc != nil { + r.closeFunc() + } } func (r *nopResolver) ResolveNow(_ resolver.ResolveNowOptions) { diff --git a/zrpc/resolver/internal/resolver_test.go b/zrpc/resolver/internal/resolver_test.go index 7dd10ee79d10..f71605c40310 100644 --- a/zrpc/resolver/internal/resolver_test.go +++ b/zrpc/resolver/internal/resolver_test.go @@ -1,6 +1,7 @@ package internal import ( + "github.com/zeromicro/go-zero/core/discov" "testing" "github.com/stretchr/testify/assert" @@ -18,6 +19,16 @@ func TestNopResolver(t *testing.T) { }) } +func TestNopResolverClose(t *testing.T) { + assert.NotPanics(t, func() { + sub := &discov.Subscriber{} + r := nopResolver{ + closeFunc: sub.Close, + } + r.Close() + }) +} + type mockedClientConn struct { state resolver.State err error