diff --git a/statistic/memory/memory.go b/statistic/memory/memory.go index e4c1f52ba..4de5b6ae6 100644 --- a/statistic/memory/memory.go +++ b/statistic/memory/memory.go @@ -21,13 +21,14 @@ type User struct { recv uint64 lastSent uint64 lastRecv uint64 - speedLock sync.Mutex + speedLock sync.RWMutex sendSpeed uint64 recvSpeed uint64 hash string - ipTableLock sync.Mutex + ipTableLock sync.RWMutex ipTable map[string]struct{} maxIPNum int + limiterLock sync.RWMutex sendLimiter *rate.Limiter recvLimiter *rate.Limiter ctx context.Context @@ -72,8 +73,8 @@ func (u *User) DelIP(ip string) bool { } func (u *User) GetIP() int { - u.ipTableLock.Lock() - defer u.ipTableLock.Unlock() + u.ipTableLock.RLock() + defer u.ipTableLock.RUnlock() return len(u.ipTable) } @@ -86,9 +87,12 @@ func (u *User) GetIPLimit() int { } func (u *User) AddTraffic(sent, recv int) { - if u.sendLimiter != nil && sent != 0 { + u.limiterLock.Lock() + defer u.limiterLock.Unlock() + + if u.sendLimiter != nil && sent >= 0 { u.sendLimiter.WaitN(u.ctx, sent) - } else if u.recvLimiter != nil && recv != 0 { + } else if u.recvLimiter != nil && recv >= 0 { u.recvLimiter.WaitN(u.ctx, recv) } atomic.AddUint64(&u.sent, uint64(sent)) @@ -96,6 +100,9 @@ func (u *User) AddTraffic(sent, recv int) { } func (u *User) SetSpeedLimit(send, recv int) { + u.limiterLock.Lock() + defer u.limiterLock.Unlock() + if send <= 0 { u.sendLimiter = nil } else { @@ -109,6 +116,9 @@ func (u *User) SetSpeedLimit(send, recv int) { } func (u *User) GetSpeedLimit() (send, recv int) { + u.limiterLock.RLock() + defer u.limiterLock.RUnlock() + sendLimit := 0 recvLimit := 0 if u.sendLimiter != nil { @@ -159,8 +169,8 @@ func (u *User) speedUpdater() { } func (u *User) GetSpeed() (uint64, uint64) { - u.speedLock.Lock() - defer u.speedLock.Unlock() + u.speedLock.RLock() + defer u.speedLock.RUnlock() return u.sendSpeed, u.recvSpeed } diff --git a/statistic/memory/memory_test.go b/statistic/memory/memory_test.go index 8d735ff20..a8a71419b 100644 --- a/statistic/memory/memory_test.go +++ b/statistic/memory/memory_test.go @@ -76,10 +76,8 @@ func TestMemoryAuth(t *testing.T) { go func() { for { k := 100 - select { - case <-time.After(time.Second / time.Duration(k)): - user.AddTraffic(200/k, 100/k) - } + time.Sleep(time.Second / time.Duration(k)) + user.AddTraffic(200/k, 100/k) } }() time.Sleep(time.Second * 4)