Skip to content

Commit

Permalink
update code for unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanosk07 authored and kevwan committed Dec 23, 2024
1 parent 8de60f9 commit c3098e6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 30 deletions.
41 changes: 22 additions & 19 deletions core/discov/internal/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,9 @@ func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error {
return err
}

ctx, cancel := context.WithCancel(cli.Ctx())
c.lock.Lock()
c.watchCtx[key] = cancel
c.watchFlag[key] = true
c.lock.Unlock()

rev := c.load(cli, key)
c.watchGroup.Run(func() {
c.watch(cli, key, rev, ctx)
c.watch(cli, key, rev)
})
}

Expand Down Expand Up @@ -336,28 +330,20 @@ func (c *cluster) reload(cli EtcdClient) {
for k := range c.listeners {
keys = append(keys, k)
}
for _, cancel := range c.watchCtx {
cancel()
}
c.watchCtx = make(map[string]context.CancelFunc)
c.watchFlag = make(map[string]bool)
c.clearWatch()
c.lock.Unlock()

for _, key := range keys {
k := key
c.watchGroup.Run(func() {
rev := c.load(cli, k)
ctx, cancel := context.WithCancel(cli.Ctx())
c.lock.Lock()
c.watchCtx[k] = cancel
c.watchFlag[k] = true
c.lock.Unlock()
c.watch(cli, k, rev, ctx)
c.watch(cli, k, rev)
})
}
}

func (c *cluster) watch(cli EtcdClient, key string, rev int64, ctx context.Context) {
func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
ctx := c.addWatch(key, cli)
for {
err := c.watchStream(cli, key, rev, ctx)
if err == nil {
Expand Down Expand Up @@ -420,6 +406,23 @@ func (c *cluster) watchConnState(cli EtcdClient) {
watcher.watch(cli.ActiveConnection())
}

func (c *cluster) addWatch(key string, cli EtcdClient) context.Context {
ctx, cancel := context.WithCancel(cli.Ctx())
c.lock.Lock()
c.watchCtx[key] = cancel
c.watchFlag[key] = true
c.lock.Unlock()
return ctx
}

func (c *cluster) clearWatch() {
for _, cancel := range c.watchCtx {
cancel()
}
c.watchCtx = make(map[string]context.CancelFunc)
c.watchFlag = make(map[string]bool)
}

// DialClient dials an etcd cluster with given endpoints.
func DialClient(endpoints []string) (EtcdClient, error) {
cfg := clientv3.Config{
Expand Down
38 changes: 27 additions & 11 deletions core/discov/internal/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ func TestCluster_Watch(t *testing.T) {
defer restore()
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
//cli.EXPECT().Ctx().Return(context.Background())
cli.EXPECT().Ctx().Return(context.Background())
var wg sync.WaitGroup
wg.Add(1)
c := &cluster{
listeners: make(map[string][]UpdateListener),
values: make(map[string]map[string]string),
listeners: make(map[string][]UpdateListener),
watchCtx: make(map[string]context.CancelFunc),
watchFlag: make(map[string]bool),
}
listener := NewMockUpdateListener(ctrl)
c.listeners["any"] = []UpdateListener{listener}
Expand All @@ -173,7 +175,7 @@ func TestCluster_Watch(t *testing.T) {
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) {
wg.Done()
}).MaxTimes(1)
go c.watch(cli, "any", 0, context.Background())
go c.watch(cli, "any", 0)
ch <- clientv3.WatchResponse{
Events: []*clientv3.Event{
{
Expand Down Expand Up @@ -211,13 +213,13 @@ func TestClusterWatch_RespFailures(t *testing.T) {
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := new(cluster)
c := newCluster([]string{})
c.done = make(chan lang.PlaceholderType)
go func() {
ch <- resp
close(c.done)
}()
c.watch(cli, "any", 0, context.Background())
c.watch(cli, "any", 0)
})
}
}
Expand All @@ -231,13 +233,13 @@ func TestClusterWatch_CloseChan(t *testing.T) {
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := new(cluster)
c := newCluster([]string{})
c.done = make(chan lang.PlaceholderType)
go func() {
close(ch)
close(c.done)
}()
c.watch(cli, "any", 0, context.Background())
c.watch(cli, "any", 0)
}

func TestClusterWatch_CtxCancel(t *testing.T) {
Expand All @@ -248,15 +250,27 @@ func TestClusterWatch_CtxCancel(t *testing.T) {
defer restore()
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := new(cluster)
c.done = make(chan lang.PlaceholderType)
ctx, cancelFunc := context.WithCancel(context.Background())
cli.EXPECT().Ctx().Return(ctx).AnyTimes()
c := newCluster([]string{})
c.done = make(chan lang.PlaceholderType)
go func() {
cancelFunc()
close(ch)
}()
c.watch(cli, "any", 0, ctx)
c.watch(cli, "any", 0)
}

func TestCluster_ClearWatch(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
c := &cluster{
watchCtx: map[string]context.CancelFunc{"foo": cancel},
watchFlag: map[string]bool{"foo": true},
}
c.clearWatch()
assert.Equal(t, ctx.Err(), context.Canceled)
assert.Equal(t, 0, len(c.watchCtx))
assert.Equal(t, 0, len(c.watchFlag))
}

func TestValueOnlyContext(t *testing.T) {
Expand Down Expand Up @@ -305,6 +319,8 @@ func TestRegistry_Monitor(t *testing.T) {
"bar": "baz",
},
},
watchCtx: map[string]context.CancelFunc{},
watchFlag: map[string]bool{},
},
}
GetRegistry().lock.Unlock()
Expand Down

0 comments on commit c3098e6

Please sign in to comment.