diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index 66cb01e85..ef19cd751 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -369,11 +369,6 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { return } - // general case - if s.EnableTrace { - traceCtl.DoFinish(cc, ctx, err) - } - if connectionClose { return errShortConnection } @@ -382,6 +377,10 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { if s.IdleTimeout == 0 { return } + // general case + if s.EnableTrace { + traceCtl.DoFinish(cc, ctx, err) + } ctx.ResetWithoutConn() } diff --git a/pkg/protocol/http1/server_test.go b/pkg/protocol/http1/server_test.go index e34bb3f9a..5f8f78247 100644 --- a/pkg/protocol/http1/server_test.go +++ b/pkg/protocol/http1/server_test.go @@ -359,6 +359,52 @@ func TestExpect100ContinueHandler(t *testing.T) { assert.DeepEqual(t, "", string(response.Body())) } +type mockController struct { + FinishTimes int +} + +func (m *mockController) Append(col tracer.Tracer) {} + +func (m *mockController) DoStart(ctx context.Context, c *app.RequestContext) context.Context { + return ctx +} + +func (m *mockController) DoFinish(ctx context.Context, c *app.RequestContext, err error) { + m.FinishTimes++ +} + +func (m *mockController) HasTracer() bool { return true } + +func (m *mockController) reset() { m.FinishTimes = 0 } + +func TestTraceDoFinishTimes(t *testing.T) { + server := &Server{} + server.eventStackPool = pool + server.EnableTrace = true + reqCtx := &app.RequestContext{} + controller := &mockController{} + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + ti := traceinfo.NewTraceInfo() + ti.Stats().SetLevel(2) + reqCtx.SetTraceInfo(&mockTraceInfo{ti}) + return reqCtx + }}, + controller: controller, + } + // for disableKeepAlive case + server.DisableKeepalive = true + err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) + assert.True(t, errors.Is(err, errs.ErrShortConnection)) + assert.DeepEqual(t, 1, controller.FinishTimes) + // for IdleTimeout==0 case + server.IdleTimeout = 0 + controller.reset() + err = server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) + assert.True(t, errors.Is(err, errs.ErrShortConnection)) + assert.DeepEqual(t, 1, controller.FinishTimes) +} + type mockCore struct { ctxPool *sync.Pool controller tracer.Controller