diff --git a/client.go b/client.go index 4b671c2e..520e0a81 100644 --- a/client.go +++ b/client.go @@ -1223,8 +1223,11 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { if r.trace == nil && r.client.trace { r.trace = &clientTrace{} } + + ctx := r.ctx + if r.trace != nil { - r.ctx = r.trace.createContext(r.Context()) + ctx = r.trace.createContext(r.Context()) } // setup url and host @@ -1260,7 +1263,6 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { for _, cookie := range r.Cookies { req.AddCookie(cookie) } - ctx := r.ctx if r.isSaveResponse && r.downloadCallback != nil { var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser { return &callbackReader{ @@ -1275,10 +1277,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { interval: r.downloadCallbackInterval, } } - if ctx == nil { - ctx = context.Background() - } - ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap) + ctx = context.WithValue(r.Context(), wrapResponseBodyKey, wrap) } if ctx != nil { req = req.WithContext(ctx) @@ -1371,7 +1370,9 @@ func (c *Client) do(r *Request) (resp *Response, err error) { if r.dumpBuffer != nil { r.dumpBuffer.Reset() } - r.trace = nil + if r.trace != nil { + r.trace = &clientTrace{} + } resp.body = nil resp.result = nil resp.error = nil