diff --git a/internal/persistence_test.go b/internal/persistence_test.go index e616182..73f1630 100644 --- a/internal/persistence_test.go +++ b/internal/persistence_test.go @@ -4,6 +4,7 @@ import ( "os" "strconv" "strings" + "sync/atomic" "testing" "time" @@ -32,6 +33,7 @@ func TestStorePersistence_Simple(t *testing.T) { require.Equal(t, 10, store.policy.window.Len()) require.Equal(t, 10, store.policy.slru.protected.Len()) require.Equal(t, 10, store.policy.slru.probation.Len()) + require.Equal(t, 30, int(store.policy.weightedSize)) require.ElementsMatch(t, strings.Split("9/8/7/6/5/4/3/2/1/0", "/"), strings.Split(store.policy.slru.protected.display(), "/"), @@ -76,6 +78,7 @@ func TestStorePersistence_Simple(t *testing.T) { require.Equal(t, 10, new.policy.window.Len()) require.Equal(t, 10, new.policy.slru.protected.Len()) require.Equal(t, 10, new.policy.slru.probation.Len()) + require.Equal(t, 30, int(new.policy.weightedSize)) require.ElementsMatch(t, strings.Split("9/8/7/6/5/4/3/2/1/0", "/"), @@ -184,3 +187,78 @@ func TestStorePersistence_Resize(t *testing.T) { } } + +func TestStorePersistence_Readonly(t *testing.T) { + store := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil) + for i := 0; i < 1000; i++ { + _ = store.Set(i, i, 1, 0) + } + for i := 0; i < 500; i++ { + _, _ = store.Get(i) + } + store.Wait() + var counter atomic.Uint64 + persistDone := make(chan bool) + + v, ok := store.Get(100) + require.True(t, ok) + require.Equal(t, 100, v) + + go func() { + done := false + for !done { + select { + case <-persistDone: + done = true + default: + store.Get(int(counter.Load()) % 1000) + counter.Add(1) + } + } + }() + + go func() { + done := false + i := 0 + for !done { + select { + case <-persistDone: + done = true + default: + store.Set(100, i, 1, 0) + i++ + } + } + }() + + f, err := os.Create("stest") + defer os.Remove("stest") + require.Nil(t, err) + start := counter.Load() + err = store.Persist(0, f) + require.Nil(t, err) + f.Close() + persistDone <- true + + new := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil) + f, err = os.Open("stest") + require.Nil(t, err) + err = new.Recover(0, f) + require.Nil(t, err) + f.Close() + + require.True(t, counter.Load()-start > 10) + + oldv, ok := store.Get(100) + require.True(t, ok) + newv, ok := new.Get(100) + require.True(t, ok) + require.NotEqual(t, oldv, newv) + + for i := 0; i < 5000; i++ { + new.Get(i) + new.Set(i, 123, 1, 0) + } + new.Wait() + +} diff --git a/internal/store.go b/internal/store.go index c99ec63..002ff34 100644 --- a/internal/store.go +++ b/internal/store.go @@ -779,6 +779,13 @@ func (m *StoreMeta) Persist(writer io.Writer, blockEncoder *gob.Encoder) error { func (s *Store[K, V]) Persist(version uint64, writer io.Writer) error { blockEncoder := gob.NewEncoder(writer) s.policyMu.Lock() + defer s.policyMu.Unlock() + + for _, s := range s.shards { + token := s.mu.RLock() + defer s.mu.RUnlock(token) + } + meta := &StoreMeta{ Version: version, StartNano: s.timerwheel.clock.Start.UnixNano(), @@ -802,7 +809,6 @@ func (s *Store[K, V]) Persist(version uint64, writer io.Writer) error { if err != nil { return err } - s.policyMu.Unlock() // write end block block := NewBlock[int](255, bytes.NewBuffer(make([]byte, 0)), blockEncoder) @@ -914,6 +920,7 @@ func (s *Store[K, V]) Recover(version uint64, reader io.Reader) error { entry := pentry.entry() s.policy.window.PushBack(entry) s.insertSimple(entry) + s.policy.weightedSize += uint(entry.policyWeight) } } case 3: // main-probation @@ -937,6 +944,7 @@ func (s *Store[K, V]) Recover(version uint64, reader io.Reader) error { entry := pentry.entry() l2.PushBack(entry) s.insertSimple(entry) + s.policy.weightedSize += uint(entry.policyWeight) } } case 4: // main protected @@ -959,6 +967,7 @@ func (s *Store[K, V]) Recover(version uint64, reader io.Reader) error { entry := pentry.entry() l.PushBack(entry) s.insertSimple(entry) + s.policy.weightedSize += uint(entry.policyWeight) } }