diff --git a/monitor/redis/watcher.go b/monitor/redis/watcher.go index 47e4925a..e7ab1b91 100644 --- a/monitor/redis/watcher.go +++ b/monitor/redis/watcher.go @@ -71,11 +71,23 @@ func (w *Watcher) monitor(ctx context.Context, ch chan<- []*change.Change) { } func (w *Watcher) getValues(ctx context.Context, ch chan<- []*change.Change) { - values, err := w.client.MGet(ctx, w.keys...).Result() - if err != nil { - log.Errorf("failed to MGET keys %v: %v", w.keys, err) - return + values := make([]*string, len(w.keys)) + for i, key := range w.keys { + strCmd := w.client.Get(ctx, key) + if strCmd == nil { + log.Errorf("failed to get value for key %s: nil strCmd", key) + continue + } + if strCmd.Err() != nil { + if strCmd.Err() != redis.Nil { + log.Errorf("failed to get value for key %s: %s", key, strCmd.Err()) + } + continue + } + val := strCmd.Val() + values[i] = &val } + changes := make([]*change.Change, 0, len(w.keys)) for i, key := range w.keys { @@ -83,7 +95,7 @@ func (w *Watcher) getValues(ctx context.Context, ch chan<- []*change.Change) { continue } - value := values[i].(string) + value := *values[i] hash := w.hash(value) if hash == w.hashes[i] { continue diff --git a/monitor/redis/watcher_test.go b/monitor/redis/watcher_test.go index 7ce6cbed..dfe6e18b 100644 --- a/monitor/redis/watcher_test.go +++ b/monitor/redis/watcher_test.go @@ -2,6 +2,7 @@ package redis import ( "context" + "errors" "sync" "testing" "time" @@ -69,11 +70,46 @@ func TestWatcher_Watch(t *testing.T) { } func TestWatcher_Versioning(t *testing.T) { - client := (&clientStub{t: t}). - WithValues("val1.1", "val2.1", "val3.1"). // Initial values - WithValues("val1.1", "val2.2", "val3.2"). // Only keys 2 and 3 are updated - WithValues("val1.1", "val2.1", "val3.2") // Only 2 is updated, to its previous value - + watchedKeys := []string{"key1", "key2", "key3"} + // each element represent the state of redis server at each subsequent poll + redisInternalState := []map[string]interface{}{ + { + // watch triggers change in key1, key2 and key3 + "key1": "val1.1", + "key2": "val2.1", + "key3": "val3.1", + }, + { + // watch triggers change in key2 and key3 + "key1": "val1.1", // no change + "key2": "val2.2", // change + "key3": "val3.2", // change + }, + { + // whole watch does not trigger change (but errors will be logged as != redis.Nil) + "key1": errors.New("error key1"), // error occurred -> no change should be triggered + "key2": errors.New("error key2"), // error occurred -> no change should be triggered + "key3": errors.New("error key3"), // error occurred -> no change should be triggered + }, + { + // watch does not trigger change or log because key1 watch will lead to redis.Nil + "key2": "val2.2", // no change from previous + "key3": "val3.2", // no change from previous + }, + { + // all subscribed keys deleted -> do not trigger change or log because redis.Nil is ignored + "key4": "val4.1", // no change -> not subscribed to this key + }, + { + // all subscribed keys deleted -> do not trigger change or log because redis.Nil is ignored + "key4": "val4.2", // no change -> not subscribed to this key + }, + { + // all subscribed keys deleted -> do not trigger change but triggers + // log because error is different than redis.Nil is ignored + "key2": errors.New("error"), // error occurred -> no change should be triggered + }, + } expected := [][]*change.Change{ { change.New(config.SourceRedis, "key1", "val1.1", 1), @@ -84,12 +120,13 @@ func TestWatcher_Versioning(t *testing.T) { change.New(config.SourceRedis, "key2", "val2.2", 2), change.New(config.SourceRedis, "key3", "val3.2", 2), }, - { - change.New(config.SourceRedis, "key2", "val2.1", 3), - }, } - w, err := New(client, 1*time.Millisecond, []string{"key1", "key2", "key3"}) + client := clientStub{t: t, m: sync.Mutex{}, watchedKeys: watchedKeys} + for _, mv := range redisInternalState { + client.AppendMockValues(mv) + } + w, err := New(&client, 5*time.Millisecond, []string{"key1", "key2", "key3"}) require.NoError(t, err) assert.Equal(t, []uint64{0, 0, 0}, w.versions) assert.Equal(t, []string{"", "", ""}, w.hashes) @@ -100,29 +137,25 @@ func TestWatcher_Versioning(t *testing.T) { err = w.Watch(ctx, ch) assert.NoError(t, err) + // time for completing all the polling for the different states time.Sleep(100 * time.Millisecond) - cancel() - found := make([][]*change.Change, 0) - wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() for { select { case cc := <-ch: - if len(cc) == 0 { - break - } found = append(found, cc) default: - wg.Done() return } } }() + cancel() wg.Wait() assert.Equal(t, expected, found) @@ -131,20 +164,60 @@ func TestWatcher_Versioning(t *testing.T) { type clientStub struct { t *testing.T *redis.Client + m sync.Mutex + watchedKeys []string + internalGetCalls int - cmds []*redis.SliceCmd + keyToCmd []map[string]*redis.StringCmd } -func (c *clientStub) WithValues(values ...interface{}) *clientStub { - c.cmds = append(c.cmds, redis.NewSliceResult(values, nil)) +func (c *clientStub) AppendMockValues(values map[string]interface{}) *clientStub { + c.m.Lock() + defer c.m.Unlock() + mockResp := make(map[string]*redis.StringCmd) + for k, v := range values { + if v == nil { + mockResp[k] = nil + continue + } + if s, ok := v.(string); ok { + mockResp[k] = redis.NewStringResult(s, nil) + continue + } + if e, ok := v.(error); ok { + mockResp[k] = redis.NewStringResult("", e) + continue + } + mockResp[k] = redis.NewStringResult("", errors.New("Unknown Error")) + } + + if c.keyToCmd == nil { + c.keyToCmd = make([]map[string]*redis.StringCmd, 0) + } + c.keyToCmd = append(c.keyToCmd, mockResp) return c } -func (c *clientStub) MGet(_ context.Context, keys ...string) *redis.SliceCmd { - if len(c.cmds) == 0 { - return redis.NewSliceResult(make([]interface{}, len(keys)), nil) +func (c *clientStub) Get(_ context.Context, key string) *redis.StringCmd { + c.m.Lock() + defer c.m.Unlock() + c.internalGetCalls++ + defer c.rollInternalRedisState() + if len(c.keyToCmd) == 0 { + return redis.NewStringResult("", redis.Nil) + } + shifted := c.keyToCmd[0] + if v, ok := shifted[key]; ok { + return v + } + + return redis.NewStringResult("", redis.Nil) + +} + +func (c *clientStub) rollInternalRedisState() { + // replace redis virtual state every len(watchedKeys) calls to Get + if len(c.keyToCmd) > 0 && (c.internalGetCalls)%len(c.watchedKeys) == 0 { + c.keyToCmd = c.keyToCmd[1:] } - shifted := c.cmds[0] - c.cmds = c.cmds[1:] - return shifted }