Skip to content

Commit

Permalink
GetEx Implementation: Get value and TTL atomically (#17)
Browse files Browse the repository at this point in the history
* remove TTL

* set max uint when key does not have ttl set

* keep ret values consistent

* check if expiration is in past

* happy path

* handle redis exec err path in a better way

* do not use return named parameters

* when no ttl is set return nil
  • Loading branch information
LukasJenicek authored Sep 23, 2024
1 parent 40af488 commit 509f7d6
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 51 deletions.
9 changes: 4 additions & 5 deletions _examples/compose/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ import (
)

func main() {

mstore, err := memlru.New[string]()
if err != nil {
panic(err)
log.Fatal(err)
}

rstore, err := redis.New[string](&redis.Config{
Expand All @@ -23,7 +22,7 @@ func main() {
Port: 6379,
})
if err != nil {
panic(err)
log.Fatal(err)
}

// Compose a store chain where we set/get keys in order of:
Expand All @@ -33,7 +32,7 @@ func main() {
// local cache in memory, and search the remote redis cache if its not in memory.
store, err := cachestore.Compose(mstore, rstore)
if err != nil {
panic(err)
log.Fatal(err)
}

ctx := context.Background()
Expand All @@ -44,7 +43,7 @@ func main() {
for i := 0; i < 100; i++ {
err = store.Set(ctx, fmt.Sprintf("foo:%d", i), fmt.Sprintf("value-%d", i))
if err != nil {
panic(err)
log.Fatal(err)
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions cachestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ type Store[V any] interface {
// for that key.
SetEx(ctx context.Context, key string, value V, ttl time.Duration) error

// GetEx returns a stored value with ttl
// duration is nil when key does not have ttl set
GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error)

// BatchSet sets all the values associated to the given keys.
BatchSet(ctx context.Context, keys []string, values []V) error

Expand Down
19 changes: 19 additions & 0 deletions compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,25 @@ func (cs *ComposeStore[V]) BatchSetEx(ctx context.Context, keys []string, values
return nil
}

func (cs *ComposeStore[V]) GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) {
var out V
var ttl *time.Duration
var exists bool
var err error

for _, s := range cs.stores {
out, ttl, exists, err = s.GetEx(ctx, key)
if err != nil {
return out, ttl, exists, err
}
if exists {
break
}
}

return out, ttl, exists, nil
}

func (cs *ComposeStore[V]) Get(ctx context.Context, key string) (V, bool, error) {
var out V
var exists bool
Expand Down
24 changes: 24 additions & 0 deletions memlru/memlru.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,30 @@ func (m *MemLRU[V]) BatchSetEx(ctx context.Context, keys []string, values []V, t
return nil
}

func (c *MemLRU[V]) GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) {
out, exists, err := c.Get(ctx, key)
if err != nil {
return out, nil, false, fmt.Errorf("get %w", err)
}

if !exists {
return out, nil, false, nil
}

item, ok := c.expirationQueue.GetItem(key)
if !ok {
return out, nil, true, nil
}

if item.expiresAt.Before(time.Now()) {
return out, nil, false, nil
}

ttl := item.expiresAt.Sub(time.Now())

return out, &ttl, true, nil
}

func (m *MemLRU[V]) Get(ctx context.Context, key string) (V, bool, error) {
var out V
m.removeExpiredKeys()
Expand Down
13 changes: 13 additions & 0 deletions memlru/ttl.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ func (e *expirationQueue) UpdateLastCheckTime() {
e.mu.Unlock()
}

func (e *expirationQueue) GetItem(key string) (expirationQueueItem, bool) {
e.mu.RLock()
defer e.mu.RUnlock()

for _, i := range e.keys {
if i.key == key {
return i, true
}
}

return expirationQueueItem{}, false
}

func (e *expirationQueue) expiredAt(t time.Time) []string {
e.mu.Lock()
defer e.mu.Unlock()
Expand Down
108 changes: 68 additions & 40 deletions memlru/ttl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func init() {
Expand All @@ -24,59 +24,59 @@ func TestExpirationQueue(t *testing.T) {
q.Push("e", time.Millisecond*300)
q.Push("f", time.Millisecond*50)

assert.Equal(t, 6, q.Len())
require.Equal(t, 6, q.Len())

{
var lastTime time.Time
for _, key := range q.keys {
assert.LessOrEqual(t, lastTime, key.expiresAt)
require.LessOrEqual(t, lastTime, key.expiresAt)
lastTime = key.expiresAt
}
}

q.Push("f", time.Millisecond*500)

assert.Equal(t, 6, q.Len())
require.Equal(t, 6, q.Len())

{
var lastTime time.Time
for _, key := range q.keys {
assert.LessOrEqual(t, lastTime, key.expiresAt)
require.LessOrEqual(t, lastTime, key.expiresAt)
lastTime = key.expiresAt
}
}

{
keys := q.expiredAt(time.Now())

assert.Equal(t, 6, q.Len())
assert.Equal(t, 0, len(keys))
require.Equal(t, 6, q.Len())
require.Equal(t, 0, len(keys))
}

{
keys := q.expiredAt(time.Now().Add(time.Second * -1))

assert.Equal(t, 6, q.Len())
assert.Equal(t, 0, len(keys))
require.Equal(t, 6, q.Len())
require.Equal(t, 0, len(keys))
}

{
keys := q.expiredAt(time.Now().Add(time.Millisecond * 200))

assert.Equal(t, 3, q.Len())
assert.Equal(t, 3, len(keys))
require.Equal(t, 3, q.Len())
require.Equal(t, 3, len(keys))
}

for i := 0; i < 100; i++ {
q.Push("z", time.Millisecond*time.Duration(50+rand.Intn(500)))
}

assert.Equal(t, 4, q.Len())
require.Equal(t, 4, q.Len())

{
var lastTime time.Time
for _, key := range q.keys {
assert.LessOrEqual(t, lastTime, key.expiresAt)
require.LessOrEqual(t, lastTime, key.expiresAt)
lastTime = key.expiresAt
}
}
Expand All @@ -85,26 +85,26 @@ func TestExpirationQueue(t *testing.T) {
q.Push(fmt.Sprintf("key-%d", i), time.Millisecond*time.Duration(50+rand.Intn(500)))
}

assert.Equal(t, 104, q.Len())
require.Equal(t, 104, q.Len())

{
var lastTime time.Time
for _, key := range q.keys {
assert.LessOrEqual(t, lastTime, key.expiresAt)
require.LessOrEqual(t, lastTime, key.expiresAt)
lastTime = key.expiresAt
}
}

{
keys := q.expiredAt(time.Now())
assert.Equal(t, 104, q.Len())
assert.Equal(t, 0, len(keys))
require.Equal(t, 104, q.Len())
require.Equal(t, 0, len(keys))
}

{
keys := q.expiredAt(time.Now().Add(time.Second * 10))
assert.Equal(t, 0, q.Len())
assert.Equal(t, 104, len(keys))
require.Equal(t, 0, q.Len())
require.Equal(t, 104, len(keys))
}

for i := 0; i < 100; i++ {
Expand All @@ -114,19 +114,19 @@ func TestExpirationQueue(t *testing.T) {
{
var lastTime time.Time
for _, key := range q.keys {
assert.LessOrEqual(t, lastTime, key.expiresAt)
require.LessOrEqual(t, lastTime, key.expiresAt)
lastTime = key.expiresAt
}
}

assert.Equal(t, 100, q.Len())
require.Equal(t, 100, q.Len())
}

func TestSetEx(t *testing.T) {
ctx := context.Background()

c, err := NewWithSize[[]byte](50)
assert.NoError(t, err)
require.NoError(t, err)

{
keys := []string{}
Expand All @@ -136,26 +136,26 @@ func TestSetEx(t *testing.T) {
// SetEx with time 0 is the same as just a Set, because there is no expiry time
// aka, the key doesn't expire.
err := c.SetEx(ctx, key, []byte("a"), time.Duration(0))
assert.NoError(t, err)
require.NoError(t, err)
keys = append(keys, key)
}

for _, key := range keys {
buf, exists, err := c.Get(ctx, key)
assert.True(t, exists)
assert.NoError(t, err)
assert.NotNil(t, buf)
require.True(t, exists)
require.NoError(t, err)
require.NotNil(t, buf)

exists, err = c.Exists(ctx, key)
assert.NoError(t, err)
assert.True(t, exists)
require.NoError(t, err)
require.True(t, exists)
}

values, batchExists, err := c.BatchGet(ctx, keys)
assert.NoError(t, err)
require.NoError(t, err)
for i := range values {
assert.NotNil(t, values[i])
assert.True(t, batchExists[i])
require.NotNil(t, values[i])
require.True(t, batchExists[i])
}
}

Expand All @@ -164,27 +164,55 @@ func TestSetEx(t *testing.T) {
for i := 0; i < 20; i++ {
key := fmt.Sprintf("key-%d", i)
err := c.SetEx(ctx, key, []byte("a"), time.Second*10) // a key that expires in 10 seconds
assert.NoError(t, err)
require.NoError(t, err)
keys = append(keys, key)
}

for _, key := range keys {
buf, exists, err := c.Get(ctx, key)
assert.NoError(t, err)
assert.NotNil(t, buf)
assert.True(t, exists)
require.NoError(t, err)
require.NotNil(t, buf)
require.True(t, exists)

exists, err = c.Exists(ctx, key)
assert.NoError(t, err)
assert.True(t, exists)
require.NoError(t, err)
require.True(t, exists)
}

values, batchExists, err := c.BatchGet(ctx, keys)
assert.NoError(t, err)
require.NoError(t, err)

for i := range values {
assert.NotNil(t, values[i])
assert.True(t, batchExists[i])
require.NotNil(t, values[i])
require.True(t, batchExists[i])
}
}
}

func TestGetEx(t *testing.T) {
ctx := context.Background()

cache, err := NewWithSize[[]byte](50)
require.NoError(t, err)

err = cache.SetEx(ctx, "hi", []byte("bye"), 10*time.Second)
require.NoError(t, err)

v, ttl, exists, err := cache.GetEx(ctx, "hi")
require.NoError(t, err)
require.True(t, exists)
require.InDelta(t, 10*time.Second, *ttl, float64(1*time.Second), "TTL are not equal within the allowed delta")
require.Equal(t, []byte("bye"), v)

v, ttl, exists, err = cache.GetEx(ctx, "not-found")
require.NoError(t, err)
require.False(t, exists)
require.Nil(t, ttl)

err = cache.Set(ctx, "without-ttl", []byte("hello"))
require.NoError(t, err)

v, ttl, exists, err = cache.GetEx(ctx, "without-ttl")
require.NoError(t, err)
require.Nil(t, ttl)
}
6 changes: 6 additions & 0 deletions nostore/nostore.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ func (s *NoStore[V]) SetEx(ctx context.Context, key string, value V, ttl time.Du
return nil
}

func (c *NoStore[V]) GetEx(ctx context.Context, key string) (V, *time.Duration, bool, error) {
var out V
ttl := time.Duration(0)
return out, &ttl, false, nil
}

func (s *NoStore[V]) BatchSet(ctx context.Context, keys []string, values []V) error {
return nil
}
Expand Down
Loading

0 comments on commit 509f7d6

Please sign in to comment.