From cd4e447597af13787e0cef5b1f0c3dc9a519ff93 Mon Sep 17 00:00:00 2001 From: taobig Date: Sat, 8 Jun 2024 10:09:38 +0800 Subject: [PATCH] Update LockedMap & tests --- lockedmap/lockedmap.go | 11 ++ lockedmap/lockedmap_race_bad_practice_test.go | 35 ++++ .../lockedmap_race_best_practice_test.go | 173 ++++++++++++++++++ lockedmap/lockedmap_test.go | 86 +++------ 4 files changed, 243 insertions(+), 62 deletions(-) create mode 100644 lockedmap/lockedmap_race_bad_practice_test.go create mode 100644 lockedmap/lockedmap_race_best_practice_test.go diff --git a/lockedmap/lockedmap.go b/lockedmap/lockedmap.go index 42a9fb3..3e2fc76 100644 --- a/lockedmap/lockedmap.go +++ b/lockedmap/lockedmap.go @@ -80,3 +80,14 @@ func (m *LockedMap[K, V]) Values() []V { return maps.Values(m.data) } + +func (m *LockedMap[K, V]) Range(f func(key K, value V) bool) { + m.mu.Lock() + defer m.mu.Unlock() + + for key, value := range m.data { + if !f(key, value) { + break + } + } +} diff --git a/lockedmap/lockedmap_race_bad_practice_test.go b/lockedmap/lockedmap_race_bad_practice_test.go new file mode 100644 index 0000000..e8d08eb --- /dev/null +++ b/lockedmap/lockedmap_race_bad_practice_test.go @@ -0,0 +1,35 @@ +//go:build go1.20 + +package lockedmap + +import ( + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" +) + +func TestDataRaceError(t *testing.T) { + unsafeLockedMap.Set("a", &TestStruct{Name: "Alice"}) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + val, ok := unsafeLockedMap.Get("a") + assert.True(t, ok) + val.Name = "Alice" + time.Now().Format(time.RFC3339Nano) //write + }() + + wg.Add(1) + go func() { + defer wg.Done() + + val, ok := unsafeLockedMap.Get("a") + assert.True(t, ok) + _ = val.Name //read + }() + + wg.Wait() +} diff --git a/lockedmap/lockedmap_race_best_practice_test.go b/lockedmap/lockedmap_race_best_practice_test.go new file mode 100644 index 0000000..60b2fd0 --- /dev/null +++ b/lockedmap/lockedmap_race_best_practice_test.go @@ -0,0 +1,173 @@ +package lockedmap + +import ( + "sync" + "testing" + "time" +) + +// LockedMap的value是struct,并发修改value的属性不会有问题(因为Get()返回的是value的拷贝)。 +var safeLockedMap = New[string, TestStruct]() + +// 虽然LockedMap的value是指针,但是返回的结构体是只读的,所以不会存在修改value的属性的问题。 +var safeLockedMap2 = New[string, *ReadonlyTestStruct]() + +// LockedMap的value是指针,并发修改value的属性会有问题(因为Get()返回的是value的指针)。 +var unsafeLockedMap = New[string, *TestStruct]() + +type TestStruct struct { + Name string + Age int64 +} + +type ReadonlyTestStruct struct { + name string +} + +func (r ReadonlyTestStruct) Name() string { + return r.name +} + +func TestTestStructDataRaceOK(t *testing.T) { + safeLockedMap.Set("a", TestStruct{Name: "Alice", Age: 20}) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap.Set("a", TestStruct{Name: "Alice" + time.Now().Format(time.RFC3339Nano)}) + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _, _ = safeLockedMap.Get("a") + }() + + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap.Remove("a") + }() + + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap.RemoveAll() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap.Size() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap.Exists("a") + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap.Keys() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap.Values() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap.Range(func(key string, value TestStruct) bool { + return true + }) + }() + + wg.Wait() +} + +func TestReadonlyTestStructDataRaceOK(t *testing.T) { + safeLockedMap2.Set("a", &ReadonlyTestStruct{name: "Alice"}) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap2.Set("a", &ReadonlyTestStruct{name: "Alice" + time.Now().Format(time.RFC3339Nano)}) + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _, _ = safeLockedMap2.Get("a") + }() + + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap2.Remove("a") + }() + + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap2.RemoveAll() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap2.Size() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap2.Exists("a") + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap2.Keys() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + _ = safeLockedMap2.Values() + }() + + wg.Add(1) + go func() { + defer wg.Done() + + safeLockedMap2.Range(func(key string, value *ReadonlyTestStruct) bool { + return true + }) + }() + + wg.Wait() +} diff --git a/lockedmap/lockedmap_test.go b/lockedmap/lockedmap_test.go index 8234a7a..6b69c5f 100644 --- a/lockedmap/lockedmap_test.go +++ b/lockedmap/lockedmap_test.go @@ -5,7 +5,15 @@ import ( "testing" ) -func TestSetGet(t *testing.T) { +func TestLockedMap_Exists(t *testing.T) { + m := New[string, int]() + m.Set("a", 1) + if !m.Exists("a") { + t.Errorf("Exists failed") + } +} + +func TestLockedMap_Set(t *testing.T) { m := New[string, int]() m.Set("a", 1) if v, ok := m.Get("a"); !ok || v != 1 { @@ -13,15 +21,15 @@ func TestSetGet(t *testing.T) { } } -func TestExists(t *testing.T) { +func TestLockedMap_Get(t *testing.T) { m := New[string, int]() m.Set("a", 1) - if !m.Exists("a") { - t.Errorf("Exists failed") + if v, ok := m.Get("a"); !ok || v != 1 { + t.Errorf("Get failed") } } -func TestRemove(t *testing.T) { +func TestLockedMap_Remove(t *testing.T) { m := New[string, int]() m.Set("a", 1) m.Remove("a") @@ -30,7 +38,7 @@ func TestRemove(t *testing.T) { } } -func TestRemoveAll(t *testing.T) { +func TestLockedMap_RemoveAll(t *testing.T) { m := New[string, int]() m.Set("a", 1) m.RemoveAll() @@ -40,7 +48,7 @@ func TestRemoveAll(t *testing.T) { assert.Equal(t, 0, m.Size()) } -func TestSize(t *testing.T) { +func TestLockedMap_Size(t *testing.T) { m := New[string, int]() m.Set("a", 1) if m.Size() != 1 { @@ -48,7 +56,7 @@ func TestSize(t *testing.T) { } } -func TestKeys(t *testing.T) { +func TestLockedMap_Keys(t *testing.T) { m := New[string, int]() m.Set("a", 1) if len(m.Keys()) != 1 { @@ -56,7 +64,7 @@ func TestKeys(t *testing.T) { } } -func TestValues(t *testing.T) { +func TestLockedMap_Values(t *testing.T) { m := New[string, int]() m.Set("a", 1) if len(m.Values()) != 1 { @@ -64,58 +72,12 @@ func TestValues(t *testing.T) { } } -func TestConcurrentSetGet(t *testing.T) { - m := New[string, int]() - for i := 0; i < 1000; i++ { - go m.Set("a", i) - go m.Get("a") - } -} - -func TestConcurrentRemove(t *testing.T) { - m := New[string, int]() - for i := 0; i < 1000; i++ { - go m.Set("a", i) - go m.Remove("a") - } -} - -func TestConcurrentSize(t *testing.T) { +func TestLockedMap_Range(t *testing.T) { m := New[string, int]() - for i := 0; i < 1000; i++ { - go m.Set("a", i) - go m.Size() - } -} - -func TestConcurrentExists(t *testing.T) { - m := New[string, int]() - for i := 0; i < 1000; i++ { - go m.Set("a", i) - go m.Exists("a") - } -} - -func TestConcurrentKeys(t *testing.T) { - m := New[string, int]() - for i := 0; i < 1000; i++ { - go m.Set("a", i) - go m.Keys() - } -} - -func TestConcurrentGet(t *testing.T) { - m := New[string, int]() - for i := 0; i < 1000; i++ { - go m.Set("a", i) - go m.Get("a") - } -} - -func TestConcurrentRemoveAll(t *testing.T) { - m := New[string, int]() - for i := 0; i < 1000; i++ { - go m.Set("a", i) - go m.RemoveAll() - } + m.Set("a", 1) + m.Set("b", 2) + m.Set("c", 3) + m.Range(func(key string, value int) bool { + return true + }) }