diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index 5f3e8c62f..831e45c7a 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -2307,7 +2307,7 @@ func TestClientDoWithDialFunc(t *testing.T) { func TestClientState(t *testing.T) { opt := config.NewOptions([]config.Option{}) - opt.Addr = ":10037" + opt.Addr = "127.0.0.1:10037" engine := route.NewEngine(opt) go engine.Run() diff --git a/pkg/common/utils/env.go b/pkg/common/utils/env.go new file mode 100644 index 000000000..72ca2c6eb --- /dev/null +++ b/pkg/common/utils/env.go @@ -0,0 +1,20 @@ +package utils + +import ( + "os" + "strconv" + "strings" + + "github.com/cloudwego/hertz/pkg/common/errors" +) + +// Get bool from env +func GetBoolFromEnv(key string) (bool, error) { + value, isExist := os.LookupEnv(key) + if !isExist { + return false, errors.NewPublic("env not exist") + } + + value = strings.TrimSpace(value) + return strconv.ParseBool(value) +} diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index 3ea659603..186d430b5 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -32,6 +32,7 @@ import ( errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -41,6 +42,12 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/suite" ) +func init() { + if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil { + disabaleRequestContextPool = b + } +} + // NextProtoTLS is the NPN/ALPN protocol negotiated during // HTTP/1.1's TLS setup. // Also used for server addressing @@ -51,6 +58,8 @@ var ( errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil) errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection") errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request") + + disabaleRequestContextPool = false ) type Option struct { @@ -80,6 +89,20 @@ type Server struct { eventStackPool *sync.Pool } +func (s Server) getRequestContext() *app.RequestContext { + if disabaleRequestContextPool { + return &app.RequestContext{} + } + return s.Core.GetCtxPool().Get().(*app.RequestContext) +} + +func (s Server) putRequestContext(ctx *app.RequestContext) { + if disabaleRequestContextPool { + return + } + s.Core.GetCtxPool().Put(ctx) +} + func (s Server) Serve(c context.Context, conn network.Conn) (err error) { var ( zr network.Reader @@ -97,8 +120,8 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { // 1. Get a request context // 2. Prepare it // 3. Process it - // 4. Reset and recycle - ctx = s.Core.GetCtxPool().Get().(*app.RequestContext) + // 4. Reset and recycle(in pooled mode) + ctx = s.getRequestContext() traceCtl = s.Core.GetTracer() eventsToTrigger *eventStack @@ -138,8 +161,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { return } - ctx.Reset() - s.Core.GetCtxPool().Put(ctx) + s.putRequestContext(ctx) }() ctx.HTMLRender = s.HTMLRender diff --git a/pkg/protocol/http1/server_test.go b/pkg/protocol/http1/server_test.go index 2263ece77..d478b36fc 100644 --- a/pkg/protocol/http1/server_test.go +++ b/pkg/protocol/http1/server_test.go @@ -218,6 +218,46 @@ func TestDefaultWriter(t *testing.T) { assert.DeepEqual(t, "hello, hertz", string(response.Body())) } +func TestServerDisableReqCtxPool(t *testing.T) { + server := &Server{} + reqCtx := &app.RequestContext{} + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + reqCtx.Set("POOL_KEY", "in pool") + return reqCtx + }}, + mockHandler: func(c context.Context, ctx *app.RequestContext) { + if ctx.GetString("POOL_KEY") != "in pool" { + t.Fatal("reqCtx is not in pool") + } + }, + isRunning: true, + } + defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") + err := server.Serve(context.TODO(), defaultConn) + assert.Nil(t, err) + disabaleRequestContextPool = true + defer func() { + // reset global variable + disabaleRequestContextPool = false + }() + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + reqCtx.Set("POOL_KEY", "in pool") + return reqCtx + }}, + mockHandler: func(c context.Context, ctx *app.RequestContext) { + if len(ctx.GetString("POOL_KEY")) != 0 { + t.Fatal("must not get pool key") + } + }, + isRunning: true, + } + defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") + err = server.Serve(context.TODO(), defaultConn) + assert.Nil(t, err) +} + func TestHijackResponseWriter(t *testing.T) { server := &Server{} reqCtx := &app.RequestContext{}