From c3098e6ffe5874cb5107a26ce599c8d0f5dfca91 Mon Sep 17 00:00:00 2001 From: nano <260391798@qq.com> Date: Mon, 23 Dec 2024 14:17:40 +0800 Subject: [PATCH] update code for unit test --- core/discov/internal/registry.go | 41 ++++++++++++++------------- core/discov/internal/registry_test.go | 38 ++++++++++++++++++------- 2 files changed, 49 insertions(+), 30 deletions(-) diff --git a/core/discov/internal/registry.go b/core/discov/internal/registry.go index 2c0848101e6b..7a4134e3ae15 100644 --- a/core/discov/internal/registry.go +++ b/core/discov/internal/registry.go @@ -279,15 +279,9 @@ func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error { return err } - ctx, cancel := context.WithCancel(cli.Ctx()) - c.lock.Lock() - c.watchCtx[key] = cancel - c.watchFlag[key] = true - c.lock.Unlock() - rev := c.load(cli, key) c.watchGroup.Run(func() { - c.watch(cli, key, rev, ctx) + c.watch(cli, key, rev) }) } @@ -336,28 +330,20 @@ func (c *cluster) reload(cli EtcdClient) { for k := range c.listeners { keys = append(keys, k) } - for _, cancel := range c.watchCtx { - cancel() - } - c.watchCtx = make(map[string]context.CancelFunc) - c.watchFlag = make(map[string]bool) + c.clearWatch() c.lock.Unlock() for _, key := range keys { k := key c.watchGroup.Run(func() { rev := c.load(cli, k) - ctx, cancel := context.WithCancel(cli.Ctx()) - c.lock.Lock() - c.watchCtx[k] = cancel - c.watchFlag[k] = true - c.lock.Unlock() - c.watch(cli, k, rev, ctx) + c.watch(cli, k, rev) }) } } -func (c *cluster) watch(cli EtcdClient, key string, rev int64, ctx context.Context) { +func (c *cluster) watch(cli EtcdClient, key string, rev int64) { + ctx := c.addWatch(key, cli) for { err := c.watchStream(cli, key, rev, ctx) if err == nil { @@ -420,6 +406,23 @@ func (c *cluster) watchConnState(cli EtcdClient) { watcher.watch(cli.ActiveConnection()) } +func (c *cluster) addWatch(key string, cli EtcdClient) context.Context { + ctx, cancel := context.WithCancel(cli.Ctx()) + c.lock.Lock() + c.watchCtx[key] = cancel + c.watchFlag[key] = true + c.lock.Unlock() + return ctx +} + +func (c *cluster) clearWatch() { + for _, cancel := range c.watchCtx { + cancel() + } + c.watchCtx = make(map[string]context.CancelFunc) + c.watchFlag = make(map[string]bool) +} + // DialClient dials an etcd cluster with given endpoints. func DialClient(endpoints []string) (EtcdClient, error) { cfg := clientv3.Config{ diff --git a/core/discov/internal/registry_test.go b/core/discov/internal/registry_test.go index c448477a8745..197aad5a4de9 100644 --- a/core/discov/internal/registry_test.go +++ b/core/discov/internal/registry_test.go @@ -156,12 +156,14 @@ func TestCluster_Watch(t *testing.T) { defer restore() ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch) - //cli.EXPECT().Ctx().Return(context.Background()) + cli.EXPECT().Ctx().Return(context.Background()) var wg sync.WaitGroup wg.Add(1) c := &cluster{ - listeners: make(map[string][]UpdateListener), values: make(map[string]map[string]string), + listeners: make(map[string][]UpdateListener), + watchCtx: make(map[string]context.CancelFunc), + watchFlag: make(map[string]bool), } listener := NewMockUpdateListener(ctrl) c.listeners["any"] = []UpdateListener{listener} @@ -173,7 +175,7 @@ func TestCluster_Watch(t *testing.T) { listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) { wg.Done() }).MaxTimes(1) - go c.watch(cli, "any", 0, context.Background()) + go c.watch(cli, "any", 0) ch <- clientv3.WatchResponse{ Events: []*clientv3.Event{ { @@ -211,13 +213,13 @@ func TestClusterWatch_RespFailures(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := new(cluster) + c := newCluster([]string{}) c.done = make(chan lang.PlaceholderType) go func() { ch <- resp close(c.done) }() - c.watch(cli, "any", 0, context.Background()) + c.watch(cli, "any", 0) }) } } @@ -231,13 +233,13 @@ func TestClusterWatch_CloseChan(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := new(cluster) + c := newCluster([]string{}) c.done = make(chan lang.PlaceholderType) go func() { close(ch) close(c.done) }() - c.watch(cli, "any", 0, context.Background()) + c.watch(cli, "any", 0) } func TestClusterWatch_CtxCancel(t *testing.T) { @@ -248,15 +250,27 @@ func TestClusterWatch_CtxCancel(t *testing.T) { defer restore() ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() - cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := new(cluster) - c.done = make(chan lang.PlaceholderType) ctx, cancelFunc := context.WithCancel(context.Background()) + cli.EXPECT().Ctx().Return(ctx).AnyTimes() + c := newCluster([]string{}) + c.done = make(chan lang.PlaceholderType) go func() { cancelFunc() close(ch) }() - c.watch(cli, "any", 0, ctx) + c.watch(cli, "any", 0) +} + +func TestCluster_ClearWatch(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := &cluster{ + watchCtx: map[string]context.CancelFunc{"foo": cancel}, + watchFlag: map[string]bool{"foo": true}, + } + c.clearWatch() + assert.Equal(t, ctx.Err(), context.Canceled) + assert.Equal(t, 0, len(c.watchCtx)) + assert.Equal(t, 0, len(c.watchFlag)) } func TestValueOnlyContext(t *testing.T) { @@ -305,6 +319,8 @@ func TestRegistry_Monitor(t *testing.T) { "bar": "baz", }, }, + watchCtx: map[string]context.CancelFunc{}, + watchFlag: map[string]bool{}, }, } GetRegistry().lock.Unlock()