Skip to content

Commit

Permalink
Update LockedMap & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
taobig committed Jun 8, 2024
1 parent 77704dd commit cd4e447
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 62 deletions.
11 changes: 11 additions & 0 deletions lockedmap/lockedmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,14 @@ func (m *LockedMap[K, V]) Values() []V {

return maps.Values(m.data)
}

func (m *LockedMap[K, V]) Range(f func(key K, value V) bool) {
m.mu.Lock()
defer m.mu.Unlock()

for key, value := range m.data {
if !f(key, value) {
break
}
}
}
35 changes: 35 additions & 0 deletions lockedmap/lockedmap_race_bad_practice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//go:build go1.20

package lockedmap

import (
"github.com/stretchr/testify/assert"
"sync"
"testing"
"time"
)

func TestDataRaceError(t *testing.T) {
unsafeLockedMap.Set("a", &TestStruct{Name: "Alice"})

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

val, ok := unsafeLockedMap.Get("a")
assert.True(t, ok)
val.Name = "Alice" + time.Now().Format(time.RFC3339Nano) //write
}()

wg.Add(1)
go func() {
defer wg.Done()

val, ok := unsafeLockedMap.Get("a")
assert.True(t, ok)
_ = val.Name //read
}()

wg.Wait()
}
173 changes: 173 additions & 0 deletions lockedmap/lockedmap_race_best_practice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package lockedmap

import (
"sync"
"testing"
"time"
)

// LockedMap的value是struct,并发修改value的属性不会有问题(因为Get()返回的是value的拷贝)。
var safeLockedMap = New[string, TestStruct]()

// 虽然LockedMap的value是指针,但是返回的结构体是只读的,所以不会存在修改value的属性的问题。
var safeLockedMap2 = New[string, *ReadonlyTestStruct]()

// LockedMap的value是指针,并发修改value的属性会有问题(因为Get()返回的是value的指针)。
var unsafeLockedMap = New[string, *TestStruct]()

type TestStruct struct {
Name string
Age int64
}

type ReadonlyTestStruct struct {
name string
}

func (r ReadonlyTestStruct) Name() string {
return r.name
}

func TestTestStructDataRaceOK(t *testing.T) {
safeLockedMap.Set("a", TestStruct{Name: "Alice", Age: 20})

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap.Set("a", TestStruct{Name: "Alice" + time.Now().Format(time.RFC3339Nano)})
}()

wg.Add(1)
go func() {
defer wg.Done()

_, _ = safeLockedMap.Get("a")
}()

wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap.Remove("a")
}()

wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap.RemoveAll()
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap.Size()
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap.Exists("a")
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap.Keys()
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap.Values()
}()

wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap.Range(func(key string, value TestStruct) bool {
return true
})
}()

wg.Wait()
}

func TestReadonlyTestStructDataRaceOK(t *testing.T) {
safeLockedMap2.Set("a", &ReadonlyTestStruct{name: "Alice"})

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap2.Set("a", &ReadonlyTestStruct{name: "Alice" + time.Now().Format(time.RFC3339Nano)})
}()

wg.Add(1)
go func() {
defer wg.Done()

_, _ = safeLockedMap2.Get("a")
}()

wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap2.Remove("a")
}()

wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap2.RemoveAll()
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap2.Size()
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap2.Exists("a")
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap2.Keys()
}()

wg.Add(1)
go func() {
defer wg.Done()

_ = safeLockedMap2.Values()
}()

wg.Add(1)
go func() {
defer wg.Done()

safeLockedMap2.Range(func(key string, value *ReadonlyTestStruct) bool {
return true
})
}()

wg.Wait()
}
86 changes: 24 additions & 62 deletions lockedmap/lockedmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,31 @@ import (
"testing"
)

func TestSetGet(t *testing.T) {
func TestLockedMap_Exists(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
if !m.Exists("a") {
t.Errorf("Exists failed")
}
}

func TestLockedMap_Set(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
if v, ok := m.Get("a"); !ok || v != 1 {
t.Errorf("Get failed")
}
}

func TestExists(t *testing.T) {
func TestLockedMap_Get(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
if !m.Exists("a") {
t.Errorf("Exists failed")
if v, ok := m.Get("a"); !ok || v != 1 {
t.Errorf("Get failed")
}
}

func TestRemove(t *testing.T) {
func TestLockedMap_Remove(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Remove("a")
Expand All @@ -30,7 +38,7 @@ func TestRemove(t *testing.T) {
}
}

func TestRemoveAll(t *testing.T) {
func TestLockedMap_RemoveAll(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.RemoveAll()
Expand All @@ -40,82 +48,36 @@ func TestRemoveAll(t *testing.T) {
assert.Equal(t, 0, m.Size())
}

func TestSize(t *testing.T) {
func TestLockedMap_Size(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
if m.Size() != 1 {
t.Errorf("Size failed")
}
}

func TestKeys(t *testing.T) {
func TestLockedMap_Keys(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
if len(m.Keys()) != 1 {
t.Errorf("Keys failed")
}
}

func TestValues(t *testing.T) {
func TestLockedMap_Values(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
if len(m.Values()) != 1 {
t.Errorf("Values failed")
}
}

func TestConcurrentSetGet(t *testing.T) {
m := New[string, int]()
for i := 0; i < 1000; i++ {
go m.Set("a", i)
go m.Get("a")
}
}

func TestConcurrentRemove(t *testing.T) {
m := New[string, int]()
for i := 0; i < 1000; i++ {
go m.Set("a", i)
go m.Remove("a")
}
}

func TestConcurrentSize(t *testing.T) {
func TestLockedMap_Range(t *testing.T) {
m := New[string, int]()
for i := 0; i < 1000; i++ {
go m.Set("a", i)
go m.Size()
}
}

func TestConcurrentExists(t *testing.T) {
m := New[string, int]()
for i := 0; i < 1000; i++ {
go m.Set("a", i)
go m.Exists("a")
}
}

func TestConcurrentKeys(t *testing.T) {
m := New[string, int]()
for i := 0; i < 1000; i++ {
go m.Set("a", i)
go m.Keys()
}
}

func TestConcurrentGet(t *testing.T) {
m := New[string, int]()
for i := 0; i < 1000; i++ {
go m.Set("a", i)
go m.Get("a")
}
}

func TestConcurrentRemoveAll(t *testing.T) {
m := New[string, int]()
for i := 0; i < 1000; i++ {
go m.Set("a", i)
go m.RemoveAll()
}
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
m.Range(func(key string, value int) bool {
return true
})
}

0 comments on commit cd4e447

Please sign in to comment.