From 6c5332ae536b07b68908163632e9f1c505a5bb81 Mon Sep 17 00:00:00 2001 From: nano <260391798@qq.com> Date: Thu, 19 Dec 2024 17:52:23 +0800 Subject: [PATCH] fix: routinegroup & etcd watch goroutine leak --- core/discov/internal/registry.go | 80 +++++++++++++++++++++---- core/discov/internal/registry_test.go | 51 ++++++++++++++-- core/discov/subscriber.go | 7 +++ core/discov/subscriber_test.go | 12 ++++ zrpc/resolver/internal/discovbuilder.go | 2 +- zrpc/resolver/internal/resolver.go | 6 +- zrpc/resolver/internal/resolver_test.go | 11 ++++ 7 files changed, 150 insertions(+), 19 deletions(-) diff --git a/core/discov/internal/registry.go b/core/discov/internal/registry.go index 9fcb5aabbde32..2c0848101e6ba 100644 --- a/core/discov/internal/registry.go +++ b/core/discov/internal/registry.go @@ -59,6 +59,15 @@ func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener, exa return c.monitor(key, l, exactMatch) } +func (r *Registry) UnMonitor(endpoints []string, key string, l UpdateListener) { + c, exists := r.getCluster(endpoints) + if !exists { + return + } + + c.unMonitor(key, l) +} + func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) { clusterKey := getClusterKey(endpoints) r.lock.RLock() @@ -88,6 +97,8 @@ type cluster struct { done chan lang.PlaceholderType lock sync.RWMutex exactMatch bool + watchCtx map[string]context.CancelFunc + watchFlag map[string]bool } func newCluster(endpoints []string) *cluster { @@ -98,6 +109,8 @@ func newCluster(endpoints []string) *cluster { listeners: make(map[string][]UpdateListener), watchGroup: threading.NewRoutineGroup(), done: make(chan lang.PlaceholderType), + watchCtx: make(map[string]context.CancelFunc), + watchFlag: make(map[string]bool), } } @@ -260,19 +273,48 @@ func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error { c.exactMatch = exactMatch c.lock.Unlock() - cli, err := c.getClient() - if err != nil { - return err - } + if !c.watchFlag[key] { + cli, err := c.getClient() + if err != nil { + return err + } - rev := c.load(cli, key) - c.watchGroup.Run(func() { - c.watch(cli, key, rev) - }) + ctx, cancel := context.WithCancel(cli.Ctx()) + c.lock.Lock() + c.watchCtx[key] = cancel + c.watchFlag[key] = true + c.lock.Unlock() + + rev := c.load(cli, key) + c.watchGroup.Run(func() { + c.watch(cli, key, rev, ctx) + }) + } return nil } +func (c *cluster) unMonitor(key string, l UpdateListener) { + c.lock.Lock() + defer c.lock.Unlock() + + listeners := c.listeners[key] + for i, listener := range listeners { + if listener == l { + c.listeners[key] = append(listeners[:i], listeners[i+1:]...) + break + } + } + + if len(c.listeners[key]) == 0 && c.watchFlag[key] { + if cancel, ok := c.watchCtx[key]; ok { + cancel() + delete(c.watchCtx, key) + } + c.watchFlag[key] = false + } +} + func (c *cluster) newClient() (EtcdClient, error) { cli, err := NewClient(c.endpoints) if err != nil { @@ -294,20 +336,30 @@ func (c *cluster) reload(cli EtcdClient) { for k := range c.listeners { keys = append(keys, k) } + for _, cancel := range c.watchCtx { + cancel() + } + c.watchCtx = make(map[string]context.CancelFunc) + c.watchFlag = make(map[string]bool) c.lock.Unlock() for _, key := range keys { k := key c.watchGroup.Run(func() { rev := c.load(cli, k) - c.watch(cli, k, rev) + ctx, cancel := context.WithCancel(cli.Ctx()) + c.lock.Lock() + c.watchCtx[k] = cancel + c.watchFlag[k] = true + c.lock.Unlock() + c.watch(cli, k, rev, ctx) }) } } -func (c *cluster) watch(cli EtcdClient, key string, rev int64) { +func (c *cluster) watch(cli EtcdClient, key string, rev int64, ctx context.Context) { for { - err := c.watchStream(cli, key, rev) + err := c.watchStream(cli, key, rev, ctx) if err == nil { return } @@ -322,7 +374,7 @@ func (c *cluster) watch(cli EtcdClient, key string, rev int64) { } } -func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { +func (c *cluster) watchStream(cli EtcdClient, key string, rev int64, ctx context.Context) error { var ( rch clientv3.WatchChan ops []clientv3.OpOption @@ -336,7 +388,7 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { ops = append(ops, clientv3.WithRev(rev+1)) } - rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), watchKey, ops...) + rch = cli.Watch(clientv3.WithRequireLeader(ctx), watchKey, ops...) for { select { @@ -354,6 +406,8 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { c.handleWatchEvents(key, wresp.Events) case <-c.done: return nil + case <-ctx.Done(): + return nil } } } diff --git a/core/discov/internal/registry_test.go b/core/discov/internal/registry_test.go index bb9fd629b3bda..c448477a8745e 100644 --- a/core/discov/internal/registry_test.go +++ b/core/discov/internal/registry_test.go @@ -156,7 +156,7 @@ func TestCluster_Watch(t *testing.T) { defer restore() ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch) - cli.EXPECT().Ctx().Return(context.Background()) + //cli.EXPECT().Ctx().Return(context.Background()) var wg sync.WaitGroup wg.Add(1) c := &cluster{ @@ -173,7 +173,7 @@ func TestCluster_Watch(t *testing.T) { listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) { wg.Done() }).MaxTimes(1) - go c.watch(cli, "any", 0) + go c.watch(cli, "any", 0, context.Background()) ch <- clientv3.WatchResponse{ Events: []*clientv3.Event{ { @@ -217,7 +217,7 @@ func TestClusterWatch_RespFailures(t *testing.T) { ch <- resp close(c.done) }() - c.watch(cli, "any", 0) + c.watch(cli, "any", 0, context.Background()) }) } } @@ -237,7 +237,26 @@ func TestClusterWatch_CloseChan(t *testing.T) { close(ch) close(c.done) }() - c.watch(cli, "any", 0) + c.watch(cli, "any", 0, context.Background()) +} + +func TestClusterWatch_CtxCancel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + ch := make(chan clientv3.WatchResponse) + cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() + cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() + c := new(cluster) + c.done = make(chan lang.PlaceholderType) + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + cancelFunc() + close(ch) + }() + c.watch(cli, "any", 0, ctx) } func TestValueOnlyContext(t *testing.T) { @@ -292,6 +311,30 @@ func TestRegistry_Monitor(t *testing.T) { assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener), false)) } +func TestRegistry_UnMonitor(t *testing.T) { + svr, err := mockserver.StartMockServers(1) + assert.NoError(t, err) + svr.StartAt(0) + + endpoints := []string{svr.Servers[0].Address} + l := new(mockListener) + GetRegistry().lock.Lock() + GetRegistry().clusters = map[string]*cluster{ + getClusterKey(endpoints): { + listeners: map[string][]UpdateListener{"foo": {l}}, + values: map[string]map[string]string{ + "foo": { + "bar": "baz", + }, + }, + watchFlag: map[string]bool{"foo": true}, + watchCtx: map[string]context.CancelFunc{"foo": func() {}}, + }, + } + GetRegistry().lock.Unlock() + GetRegistry().UnMonitor(endpoints, "foo", l) +} + type mockListener struct { } diff --git a/core/discov/subscriber.go b/core/discov/subscriber.go index 08f89a601ff70..515ea71d93a84 100644 --- a/core/discov/subscriber.go +++ b/core/discov/subscriber.go @@ -18,6 +18,7 @@ type ( endpoints []string exclusive bool exactMatch bool + key string items *container } ) @@ -29,6 +30,7 @@ type ( func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) { sub := &Subscriber{ endpoints: endpoints, + key: key, } for _, opt := range opts { opt(sub) @@ -52,6 +54,11 @@ func (s *Subscriber) Values() []string { return s.items.getValues() } +// Close the subscriber created watch goroutine and remove listener +func (s *Subscriber) Close() { + internal.GetRegistry().UnMonitor(s.endpoints, s.key, s.items) +} + // Exclusive means that key value can only be 1:1, // which means later added value will remove the keys associated with the same value previously. func Exclusive() SubOption { diff --git a/core/discov/subscriber_test.go b/core/discov/subscriber_test.go index 6dce7cec1c60e..1f760979f4bcf 100644 --- a/core/discov/subscriber_test.go +++ b/core/discov/subscriber_test.go @@ -214,6 +214,18 @@ func TestSubscriber(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&count)) } +func TestSubscriberClos(t *testing.T) { + l := newContainer(false) + sub := &Subscriber{ + endpoints: []string{"localhost:2379"}, + key: "foo", + items: l, + } + _ = internal.GetRegistry().Monitor(sub.endpoints, sub.key, l, false) + sub.Close() + assert.Empty(t, sub.items.listeners) +} + func TestWithSubEtcdAccount(t *testing.T) { endpoints := []string{"localhost:2379"} user := stringx.Rand() diff --git a/zrpc/resolver/internal/discovbuilder.go b/zrpc/resolver/internal/discovbuilder.go index ed377d138fa14..5a91ee73e40c8 100644 --- a/zrpc/resolver/internal/discovbuilder.go +++ b/zrpc/resolver/internal/discovbuilder.go @@ -38,7 +38,7 @@ func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ sub.AddListener(update) update() - return &nopResolver{cc: cc}, nil + return &nopResolver{cc: cc, closeFunc: func() { sub.Close() }}, nil } func (b *discovBuilder) Scheme() string { diff --git a/zrpc/resolver/internal/resolver.go b/zrpc/resolver/internal/resolver.go index 7868eca8cecfc..e04d65d875c09 100644 --- a/zrpc/resolver/internal/resolver.go +++ b/zrpc/resolver/internal/resolver.go @@ -37,10 +37,14 @@ func register() { } type nopResolver struct { - cc resolver.ClientConn + cc resolver.ClientConn + closeFunc func() } func (r *nopResolver) Close() { + if r.closeFunc != nil { + r.closeFunc() + } } func (r *nopResolver) ResolveNow(_ resolver.ResolveNowOptions) { diff --git a/zrpc/resolver/internal/resolver_test.go b/zrpc/resolver/internal/resolver_test.go index 7dd10ee79d101..f71605c40310b 100644 --- a/zrpc/resolver/internal/resolver_test.go +++ b/zrpc/resolver/internal/resolver_test.go @@ -1,6 +1,7 @@ package internal import ( + "github.com/zeromicro/go-zero/core/discov" "testing" "github.com/stretchr/testify/assert" @@ -18,6 +19,16 @@ func TestNopResolver(t *testing.T) { }) } +func TestNopResolverClose(t *testing.T) { + assert.NotPanics(t, func() { + sub := &discov.Subscriber{} + r := nopResolver{ + closeFunc: sub.Close, + } + r.Close() + }) +} + type mockedClientConn struct { state resolver.State err error