diff --git a/pkg/app/context.go b/pkg/app/context.go index 4e2d4be33..510c9725f 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -328,19 +328,71 @@ func (ctx *RequestContext) GetIndex() int8 { return ctx.index } +// SetIndex reset the handler's execution index +// Disclaimer: You can loop yourself to deal with this, use wisely. +func (ctx *RequestContext) SetIndex(index int8) { + ctx.index = index +} + type HandlerFunc func(c context.Context, ctx *RequestContext) // HandlersChain defines a HandlerFunc array. type HandlersChain []HandlerFunc -var handlerNames = make(map[uintptr]string) +type HandlerNameOperator interface { + SetHandlerName(handler HandlerFunc, name string) + GetHandlerName(handler HandlerFunc) string +} + +func SetHandlerNameOperator(o HandlerNameOperator) { + inbuiltHandlerNameOperator = o +} + +type inbuiltHandlerNameOperatorStruct struct { + handlerNames map[uintptr]string +} + +func (o *inbuiltHandlerNameOperatorStruct) SetHandlerName(handler HandlerFunc, name string) { + o.handlerNames[getFuncAddr(handler)] = name +} + +func (o *inbuiltHandlerNameOperatorStruct) GetHandlerName(handler HandlerFunc) string { + return o.handlerNames[getFuncAddr(handler)] +} + +type concurrentHandlerNameOperatorStruct struct { + handlerNames map[uintptr]string + lock sync.RWMutex +} + +func (o *concurrentHandlerNameOperatorStruct) SetHandlerName(handler HandlerFunc, name string) { + o.lock.Lock() + defer o.lock.Unlock() + o.handlerNames[getFuncAddr(handler)] = name +} + +func (o *concurrentHandlerNameOperatorStruct) GetHandlerName(handler HandlerFunc) string { + o.lock.RLock() + defer o.lock.RUnlock() + return o.handlerNames[getFuncAddr(handler)] +} + +func SetConcurrentHandlerNameOperator() { + SetHandlerNameOperator(&concurrentHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}}) +} + +func init() { + inbuiltHandlerNameOperator = &inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}} +} + +var inbuiltHandlerNameOperator HandlerNameOperator func SetHandlerName(handler HandlerFunc, name string) { - handlerNames[getFuncAddr(handler)] = name + inbuiltHandlerNameOperator.SetHandlerName(handler, name) } func GetHandlerName(handler HandlerFunc) string { - return handlerNames[getFuncAddr(handler)] + return inbuiltHandlerNameOperator.GetHandlerName(handler) } func getFuncAddr(v interface{}) uintptr { diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index c065d482c..85d686c08 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1344,16 +1344,44 @@ func TestGetWriter(t *testing.T) { func TestIndex(t *testing.T) { ctx := NewContext(0) ctx.ResetWithoutConn() - res := ctx.GetIndex() exc := int8(-1) + res := ctx.GetIndex() + assert.DeepEqual(t, exc, res) + ctx.SetIndex(int8(1)) + res = ctx.GetIndex() + exc = int8(1) assert.DeepEqual(t, exc, res) } +func TestConcurrentHandlerName(t *testing.T) { + SetConcurrentHandlerNameOperator() + defer SetHandlerNameOperator(&inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}}) + h := func(c context.Context, ctx *RequestContext) {} + SetHandlerName(h, "test1") + for i := 0; i < 50; i++ { + go func() { + name := GetHandlerName(h) + assert.DeepEqual(t, "test1", name) + }() + } + + time.Sleep(time.Second) + + go func() { + SetHandlerName(h, "test2") + }() + + time.Sleep(time.Second) + + name := GetHandlerName(h) + assert.DeepEqual(t, "test2", name) +} + func TestHandlerName(t *testing.T) { h := func(c context.Context, ctx *RequestContext) {} - SetHandlerName(h, "test") + SetHandlerName(h, "test1") name := GetHandlerName(h) - assert.DeepEqual(t, "test", name) + assert.DeepEqual(t, "test1", name) } func TestHijack(t *testing.T) { @@ -1644,3 +1672,23 @@ func TestRequestContext_VisitAll(t *testing.T) { }) }) } + +func BenchmarkInbuiltHandlerNameOperator(b *testing.B) { + for n := 0; n < b.N; n++ { + fn := func(c context.Context, ctx *RequestContext) { + } + SetHandlerName(fn, fmt.Sprintf("%d", n)) + GetHandlerName(fn) + } +} + +func BenchmarkConcurrentHandlerNameOperator(b *testing.B) { + SetConcurrentHandlerNameOperator() + defer SetHandlerNameOperator(&inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}}) + for n := 0; n < b.N; n++ { + fn := func(c context.Context, ctx *RequestContext) { + } + SetHandlerName(fn, fmt.Sprintf("%d", n)) + GetHandlerName(fn) + } +}