Skip to content

Commit

Permalink
feat(http1): add option to disable request context pool
Browse files Browse the repository at this point in the history
test: add ut

refactor: use env instead of option

optimize: method to get & put
  • Loading branch information
welkeyever committed Aug 21, 2024
1 parent a64f390 commit b43cd0c
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pkg/app/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2307,7 +2307,7 @@ func TestClientDoWithDialFunc(t *testing.T) {

func TestClientState(t *testing.T) {
opt := config.NewOptions([]config.Option{})
opt.Addr = ":10037"
opt.Addr = "127.0.0.1:10037"
engine := route.NewEngine(opt)
go engine.Run()

Expand Down
20 changes: 20 additions & 0 deletions pkg/common/utils/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package utils

import (
"os"
"strconv"
"strings"

"github.com/cloudwego/hertz/pkg/common/errors"
)

// Get bool from env
func GetBoolFromEnv(key string) (bool, error) {
value, isExist := os.LookupEnv(key)
if !isExist {
return false, errors.NewPublic("env not exist")
}

value = strings.TrimSpace(value)
return strconv.ParseBool(value)
}
30 changes: 26 additions & 4 deletions pkg/protocol/http1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/tracer/stats"
"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
Expand All @@ -41,6 +42,12 @@ import (
"github.com/cloudwego/hertz/pkg/protocol/suite"
)

func init() {
if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil {
disabaleRequestContextPool = b
}
}

// NextProtoTLS is the NPN/ALPN protocol negotiated during
// HTTP/1.1's TLS setup.
// Also used for server addressing
Expand All @@ -51,6 +58,8 @@ var (
errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil)
errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection")
errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request")

disabaleRequestContextPool = false
)

type Option struct {
Expand Down Expand Up @@ -80,6 +89,20 @@ type Server struct {
eventStackPool *sync.Pool
}

func (s Server) getRequestContext() *app.RequestContext {
if disabaleRequestContextPool {
return &app.RequestContext{}
}
return s.Core.GetCtxPool().Get().(*app.RequestContext)
}

func (s Server) putRequestContext(ctx *app.RequestContext) {
if disabaleRequestContextPool {
return
}
s.Core.GetCtxPool().Put(ctx)
}

func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
var (
zr network.Reader
Expand All @@ -97,8 +120,8 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
// 1. Get a request context
// 2. Prepare it
// 3. Process it
// 4. Reset and recycle
ctx = s.Core.GetCtxPool().Get().(*app.RequestContext)
// 4. Reset and recycle(in pooled mode)
ctx = s.getRequestContext()

traceCtl = s.Core.GetTracer()
eventsToTrigger *eventStack
Expand Down Expand Up @@ -138,8 +161,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
return
}

ctx.Reset()
s.Core.GetCtxPool().Put(ctx)
s.putRequestContext(ctx)
}()

ctx.HTMLRender = s.HTMLRender
Expand Down
40 changes: 40 additions & 0 deletions pkg/protocol/http1/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,46 @@ func TestDefaultWriter(t *testing.T) {
assert.DeepEqual(t, "hello, hertz", string(response.Body()))
}

func TestServerDisableReqCtxPool(t *testing.T) {
server := &Server{}
reqCtx := &app.RequestContext{}
server.Core = &mockCore{
ctxPool: &sync.Pool{New: func() interface{} {
reqCtx.Set("POOL_KEY", "in pool")
return reqCtx
}},
mockHandler: func(c context.Context, ctx *app.RequestContext) {
if ctx.GetString("POOL_KEY") != "in pool" {
t.Fatal("reqCtx is not in pool")
}
},
isRunning: true,
}
defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
err := server.Serve(context.TODO(), defaultConn)
assert.Nil(t, err)
disabaleRequestContextPool = true
defer func() {
// reset global variable
disabaleRequestContextPool = false
}()
server.Core = &mockCore{
ctxPool: &sync.Pool{New: func() interface{} {
reqCtx.Set("POOL_KEY", "in pool")
return reqCtx
}},
mockHandler: func(c context.Context, ctx *app.RequestContext) {
if len(ctx.GetString("POOL_KEY")) != 0 {
t.Fatal("must not get pool key")
}
},
isRunning: true,
}
defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
err = server.Serve(context.TODO(), defaultConn)
assert.Nil(t, err)
}

func TestHijackResponseWriter(t *testing.T) {
server := &Server{}
reqCtx := &app.RequestContext{}
Expand Down

0 comments on commit b43cd0c

Please sign in to comment.