diff --git a/_examples/compose/main.go b/_examples/compose/main.go index 891b881..a58a633 100644 --- a/_examples/compose/main.go +++ b/_examples/compose/main.go @@ -11,10 +11,9 @@ import ( ) func main() { - mstore, err := memlru.New[string]() if err != nil { - panic(err) + log.Fatal(err) } rstore, err := redis.New[string](&redis.Config{ @@ -23,7 +22,7 @@ func main() { Port: 6379, }) if err != nil { - panic(err) + log.Fatal(err) } // Compose a store chain where we set/get keys in order of: @@ -33,7 +32,7 @@ func main() { // local cache in memory, and search the remote redis cache if its not in memory. store, err := cachestore.Compose(mstore, rstore) if err != nil { - panic(err) + log.Fatal(err) } ctx := context.Background() @@ -44,7 +43,7 @@ func main() { for i := 0; i < 100; i++ { err = store.Set(ctx, fmt.Sprintf("foo:%d", i), fmt.Sprintf("value-%d", i)) if err != nil { - panic(err) + log.Fatal(err) } } } diff --git a/cachestore.go b/cachestore.go index 57c1606..9cefd72 100644 --- a/cachestore.go +++ b/cachestore.go @@ -24,6 +24,10 @@ type Store[V any] interface { // for that key. SetEx(ctx context.Context, key string, value V, ttl time.Duration) error + // GetEx returns a stored value with ttl + // duration is nil when key does not have ttl set + GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) + // BatchSet sets all the values associated to the given keys. BatchSet(ctx context.Context, keys []string, values []V) error diff --git a/compose.go b/compose.go index 629d8f3..33978b1 100644 --- a/compose.go +++ b/compose.go @@ -73,6 +73,25 @@ func (cs *ComposeStore[V]) BatchSetEx(ctx context.Context, keys []string, values return nil } +func (cs *ComposeStore[V]) GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) { + var out V + var ttl *time.Duration + var exists bool + var err error + + for _, s := range cs.stores { + out, ttl, exists, err = s.GetEx(ctx, key) + if err != nil { + return out, ttl, exists, err + } + if exists { + break + } + } + + return out, ttl, exists, nil +} + func (cs *ComposeStore[V]) Get(ctx context.Context, key string) (V, bool, error) { var out V var exists bool diff --git a/memlru/memlru.go b/memlru/memlru.go index 58f43ff..6c47eeb 100644 --- a/memlru/memlru.go +++ b/memlru/memlru.go @@ -116,6 +116,30 @@ func (m *MemLRU[V]) BatchSetEx(ctx context.Context, keys []string, values []V, t return nil } +func (c *MemLRU[V]) GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) { + out, exists, err := c.Get(ctx, key) + if err != nil { + return out, nil, false, fmt.Errorf("get %w", err) + } + + if !exists { + return out, nil, false, nil + } + + item, ok := c.expirationQueue.GetItem(key) + if !ok { + return out, nil, true, nil + } + + if item.expiresAt.Before(time.Now()) { + return out, nil, false, nil + } + + ttl := item.expiresAt.Sub(time.Now()) + + return out, &ttl, true, nil +} + func (m *MemLRU[V]) Get(ctx context.Context, key string) (V, bool, error) { var out V m.removeExpiredKeys() diff --git a/memlru/ttl.go b/memlru/ttl.go index c8d75a9..9cb7d58 100644 --- a/memlru/ttl.go +++ b/memlru/ttl.go @@ -74,6 +74,19 @@ func (e *expirationQueue) UpdateLastCheckTime() { e.mu.Unlock() } +func (e *expirationQueue) GetItem(key string) (expirationQueueItem, bool) { + e.mu.RLock() + defer e.mu.RUnlock() + + for _, i := range e.keys { + if i.key == key { + return i, true + } + } + + return expirationQueueItem{}, false +} + func (e *expirationQueue) expiredAt(t time.Time) []string { e.mu.Lock() defer e.mu.Unlock() diff --git a/memlru/ttl_test.go b/memlru/ttl_test.go index a08587d..1887157 100644 --- a/memlru/ttl_test.go +++ b/memlru/ttl_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func init() { @@ -24,24 +24,24 @@ func TestExpirationQueue(t *testing.T) { q.Push("e", time.Millisecond*300) q.Push("f", time.Millisecond*50) - assert.Equal(t, 6, q.Len()) + require.Equal(t, 6, q.Len()) { var lastTime time.Time for _, key := range q.keys { - assert.LessOrEqual(t, lastTime, key.expiresAt) + require.LessOrEqual(t, lastTime, key.expiresAt) lastTime = key.expiresAt } } q.Push("f", time.Millisecond*500) - assert.Equal(t, 6, q.Len()) + require.Equal(t, 6, q.Len()) { var lastTime time.Time for _, key := range q.keys { - assert.LessOrEqual(t, lastTime, key.expiresAt) + require.LessOrEqual(t, lastTime, key.expiresAt) lastTime = key.expiresAt } } @@ -49,34 +49,34 @@ func TestExpirationQueue(t *testing.T) { { keys := q.expiredAt(time.Now()) - assert.Equal(t, 6, q.Len()) - assert.Equal(t, 0, len(keys)) + require.Equal(t, 6, q.Len()) + require.Equal(t, 0, len(keys)) } { keys := q.expiredAt(time.Now().Add(time.Second * -1)) - assert.Equal(t, 6, q.Len()) - assert.Equal(t, 0, len(keys)) + require.Equal(t, 6, q.Len()) + require.Equal(t, 0, len(keys)) } { keys := q.expiredAt(time.Now().Add(time.Millisecond * 200)) - assert.Equal(t, 3, q.Len()) - assert.Equal(t, 3, len(keys)) + require.Equal(t, 3, q.Len()) + require.Equal(t, 3, len(keys)) } for i := 0; i < 100; i++ { q.Push("z", time.Millisecond*time.Duration(50+rand.Intn(500))) } - assert.Equal(t, 4, q.Len()) + require.Equal(t, 4, q.Len()) { var lastTime time.Time for _, key := range q.keys { - assert.LessOrEqual(t, lastTime, key.expiresAt) + require.LessOrEqual(t, lastTime, key.expiresAt) lastTime = key.expiresAt } } @@ -85,26 +85,26 @@ func TestExpirationQueue(t *testing.T) { q.Push(fmt.Sprintf("key-%d", i), time.Millisecond*time.Duration(50+rand.Intn(500))) } - assert.Equal(t, 104, q.Len()) + require.Equal(t, 104, q.Len()) { var lastTime time.Time for _, key := range q.keys { - assert.LessOrEqual(t, lastTime, key.expiresAt) + require.LessOrEqual(t, lastTime, key.expiresAt) lastTime = key.expiresAt } } { keys := q.expiredAt(time.Now()) - assert.Equal(t, 104, q.Len()) - assert.Equal(t, 0, len(keys)) + require.Equal(t, 104, q.Len()) + require.Equal(t, 0, len(keys)) } { keys := q.expiredAt(time.Now().Add(time.Second * 10)) - assert.Equal(t, 0, q.Len()) - assert.Equal(t, 104, len(keys)) + require.Equal(t, 0, q.Len()) + require.Equal(t, 104, len(keys)) } for i := 0; i < 100; i++ { @@ -114,19 +114,19 @@ func TestExpirationQueue(t *testing.T) { { var lastTime time.Time for _, key := range q.keys { - assert.LessOrEqual(t, lastTime, key.expiresAt) + require.LessOrEqual(t, lastTime, key.expiresAt) lastTime = key.expiresAt } } - assert.Equal(t, 100, q.Len()) + require.Equal(t, 100, q.Len()) } func TestSetEx(t *testing.T) { ctx := context.Background() c, err := NewWithSize[[]byte](50) - assert.NoError(t, err) + require.NoError(t, err) { keys := []string{} @@ -136,26 +136,26 @@ func TestSetEx(t *testing.T) { // SetEx with time 0 is the same as just a Set, because there is no expiry time // aka, the key doesn't expire. err := c.SetEx(ctx, key, []byte("a"), time.Duration(0)) - assert.NoError(t, err) + require.NoError(t, err) keys = append(keys, key) } for _, key := range keys { buf, exists, err := c.Get(ctx, key) - assert.True(t, exists) - assert.NoError(t, err) - assert.NotNil(t, buf) + require.True(t, exists) + require.NoError(t, err) + require.NotNil(t, buf) exists, err = c.Exists(ctx, key) - assert.NoError(t, err) - assert.True(t, exists) + require.NoError(t, err) + require.True(t, exists) } values, batchExists, err := c.BatchGet(ctx, keys) - assert.NoError(t, err) + require.NoError(t, err) for i := range values { - assert.NotNil(t, values[i]) - assert.True(t, batchExists[i]) + require.NotNil(t, values[i]) + require.True(t, batchExists[i]) } } @@ -164,27 +164,55 @@ func TestSetEx(t *testing.T) { for i := 0; i < 20; i++ { key := fmt.Sprintf("key-%d", i) err := c.SetEx(ctx, key, []byte("a"), time.Second*10) // a key that expires in 10 seconds - assert.NoError(t, err) + require.NoError(t, err) keys = append(keys, key) } for _, key := range keys { buf, exists, err := c.Get(ctx, key) - assert.NoError(t, err) - assert.NotNil(t, buf) - assert.True(t, exists) + require.NoError(t, err) + require.NotNil(t, buf) + require.True(t, exists) exists, err = c.Exists(ctx, key) - assert.NoError(t, err) - assert.True(t, exists) + require.NoError(t, err) + require.True(t, exists) } values, batchExists, err := c.BatchGet(ctx, keys) - assert.NoError(t, err) + require.NoError(t, err) for i := range values { - assert.NotNil(t, values[i]) - assert.True(t, batchExists[i]) + require.NotNil(t, values[i]) + require.True(t, batchExists[i]) } } } + +func TestGetEx(t *testing.T) { + ctx := context.Background() + + cache, err := NewWithSize[[]byte](50) + require.NoError(t, err) + + err = cache.SetEx(ctx, "hi", []byte("bye"), 10*time.Second) + require.NoError(t, err) + + v, ttl, exists, err := cache.GetEx(ctx, "hi") + require.NoError(t, err) + require.True(t, exists) + require.InDelta(t, 10*time.Second, *ttl, float64(1*time.Second), "TTL are not equal within the allowed delta") + require.Equal(t, []byte("bye"), v) + + v, ttl, exists, err = cache.GetEx(ctx, "not-found") + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, ttl) + + err = cache.Set(ctx, "without-ttl", []byte("hello")) + require.NoError(t, err) + + v, ttl, exists, err = cache.GetEx(ctx, "without-ttl") + require.NoError(t, err) + require.Nil(t, ttl) +} diff --git a/nostore/nostore.go b/nostore/nostore.go index e38888d..b000fcb 100644 --- a/nostore/nostore.go +++ b/nostore/nostore.go @@ -41,6 +41,12 @@ func (s *NoStore[V]) SetEx(ctx context.Context, key string, value V, ttl time.Du return nil } +func (c *NoStore[V]) GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) { + var out V + ttl := time.Duration(0) + return out, &ttl, false, nil +} + func (s *NoStore[V]) BatchSet(ctx context.Context, keys []string, values []V) error { return nil } diff --git a/redis/redis.go b/redis/redis.go index 8ff68be..d7eadcc 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -14,7 +14,7 @@ import ( "github.com/redis/go-redis/v9" ) -const LongTime = time.Second * 24 * 60 * 60 // 1 day in seconds +const DefaultTTL = time.Second * 24 * 60 * 60 // 1 day in seconds var _ cachestore.Store[any] = &RedisStore[any]{} @@ -60,7 +60,7 @@ func New[V any](cfg *Config, opts ...cachestore.StoreOptions) (cachestore.Store[ cfg.Port = 6379 } if cfg.KeyTTL == 0 { - cfg.KeyTTL = LongTime // default setting + cfg.KeyTTL = DefaultTTL // default setting } // Create store and connect to backend @@ -85,7 +85,7 @@ func New[V any](cfg *Config, opts ...cachestore.StoreOptions) (cachestore.Store[ // Set default key expiry for a long time on redis always. This is how we ensure // the cache will always function as a LRU. if store.options.DefaultKeyExpiry == 0 { - store.options.DefaultKeyExpiry = LongTime + store.options.DefaultKeyExpiry = DefaultTTL } return store, nil @@ -181,6 +181,59 @@ func (c *RedisStore[V]) BatchSetEx(ctx context.Context, keys []string, values [] return nil } +func (c *RedisStore[V]) GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) { + var out V + var ttl *time.Duration + + _, err := c.client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + var err error + + getVal := pipe.Get(ctx, key) + getTTL := pipe.TTL(ctx, key) + + if _, err = pipe.Exec(ctx); err != nil { + if errors.Is(err, redis.Nil) { + return err + } + + return fmt.Errorf("exec: %w", err) + } + + ttlRes, err := getTTL.Result() + if err != nil { + return fmt.Errorf("TTL command failed: %w", err) + } + + if ttlRes == -1 { + ttl = nil + } else { + ttl = &ttlRes + } + + data, err := getVal.Bytes() + if err != nil { + return fmt.Errorf("get bytes: %w", err) + } + + out, err = deserialize[V](data) + if err != nil { + return fmt.Errorf("deserialize: %w", err) + } + + return nil + }) + + if err != nil { + if errors.Is(err, redis.Nil) { + return out, ttl, false, nil + } + + return out, ttl, false, fmt.Errorf("GetEx: %w", err) + } + + return out, ttl, true, nil +} + func (c *RedisStore[V]) Get(ctx context.Context, key string) (V, bool, error) { var out V diff --git a/redis/redis_test.go b/redis/redis_test.go index 8b1ebf0..031c8a1 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/goware/cachestore" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) @@ -249,14 +248,45 @@ func TestGetOrSetWithLock(t *testing.T) { } require.NoError(t, wg.Wait()) - assert.Equalf(t, 1, int(counter.Load()), "getter should be called only once") + require.Equalf(t, 1, int(counter.Load()), "getter should be called only once") for i := 0; i < concurrentCalls; i++ { select { case v := <-results: - assert.Equal(t, "result:"+key, v) + require.Equal(t, "result:"+key, v) default: t.Errorf("expected %d results but only got %d", concurrentCalls, i) } } } + +func TestGetEx(t *testing.T) { + ctx := context.Background() + + cache, err := New[string](&Config{Enabled: true, Host: "localhost"}, cachestore.WithDefaultKeyExpiry(-1*time.Second)) + require.NoError(t, err) + + _, ok := cache.(*RedisStore[string]) + require.True(t, ok) + + err = cache.SetEx(ctx, "hi", "bye", 10*time.Second) + require.NoError(t, err) + + v, ttl, exists, err := cache.GetEx(ctx, "hi") + require.NoError(t, err) + require.True(t, exists) + require.InDelta(t, 10*time.Second, *ttl, float64(1*time.Second), "TTL are not equal within the allowed delta") + require.Equal(t, "bye", v) + + v, ttl, exists, err = cache.GetEx(ctx, "not-found") + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, ttl) + + err = cache.Set(ctx, "without-ttl", "hello") + require.NoError(t, err) + + v, ttl, exists, err = cache.GetEx(ctx, "without-ttl") + require.NoError(t, err) + require.Nil(t, ttl) +}