From d831adf1161f14c75928f6cf48da464a8271ba7d Mon Sep 17 00:00:00 2001 From: liwei Date: Mon, 23 Dec 2024 01:28:58 +0800 Subject: [PATCH 1/6] fix: http response body read timeout handle --- internal/protocol/http/fetcher.go | 21 ++++---- internal/protocol/http/fetcher_test.go | 7 +++ internal/protocol/http/timeout_reader.go | 40 +++++++++++++++ internal/protocol/http/timeout_reader_test.go | 51 +++++++++++++++++++ internal/test/httptest.go | 4 ++ 5 files changed, 113 insertions(+), 10 deletions(-) create mode 100644 internal/protocol/http/timeout_reader.go create mode 100644 internal/protocol/http/timeout_reader_test.go diff --git a/internal/protocol/http/fetcher.go b/internal/protocol/http/fetcher.go index 1d3def5ae..0119f0011 100644 --- a/internal/protocol/http/fetcher.go +++ b/internal/protocol/http/fetcher.go @@ -333,10 +333,18 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { err = NewRequestError(resp.StatusCode, resp.Status) return err } - // Http request success, reset retry times - chunk.retryTimes = 0 + reader := NewTimeoutReader(resp.Body, 30*time.Second) for { - n, err := resp.Body.Read(buf) + n, err := reader.Read(buf) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + // download success, reset retry times + chunk.retryTimes = 0 if n > 0 { _, err := f.file.WriteAt(buf[:n], chunk.Begin+chunk.Downloaded) if err != nil { @@ -344,14 +352,7 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { } chunk.Downloaded += int64(n) } - if err != nil { - if err == io.EOF { - return nil - } - return err - } } - return nil }() if err != nil { // If canceled, do not retry diff --git a/internal/protocol/http/fetcher_test.go b/internal/protocol/http/fetcher_test.go index 8109cbb3d..f6b75dac9 100644 --- a/internal/protocol/http/fetcher_test.go +++ b/internal/protocol/http/fetcher_test.go @@ -132,6 +132,13 @@ func TestFetcher_DownloadLimit(t *testing.T) { downloadNormal(listener, 8, t) } +func TestFetcher_DownloadResponseBodyReadTimeout(t *testing.T) { + listener := test.StartTestLimitServer(16, 35000) + defer listener.Close() + + downloadError(listener, 1, t) +} + func TestFetcher_DownloadResume(t *testing.T) { listener := test.StartTestFileServer() defer listener.Close() diff --git a/internal/protocol/http/timeout_reader.go b/internal/protocol/http/timeout_reader.go new file mode 100644 index 000000000..f23e16e9a --- /dev/null +++ b/internal/protocol/http/timeout_reader.go @@ -0,0 +1,40 @@ +package http + +import ( + "context" + "io" + "time" +) + +type TimeoutReader struct { + reader io.Reader + timeout time.Duration +} + +func NewTimeoutReader(r io.Reader, timeout time.Duration) *TimeoutReader { + return &TimeoutReader{ + reader: r, + timeout: timeout, + } +} + +func (tr *TimeoutReader) Read(p []byte) (n int, err error) { + ctx, cancel := context.WithTimeout(context.Background(), tr.timeout) + defer cancel() + + done := make(chan struct{}) + var readErr error + var bytesRead int + + go func() { + bytesRead, readErr = tr.reader.Read(p) + close(done) + }() + + select { + case <-done: + return bytesRead, readErr + case <-ctx.Done(): + return 0, ctx.Err() + } +} diff --git a/internal/protocol/http/timeout_reader_test.go b/internal/protocol/http/timeout_reader_test.go new file mode 100644 index 000000000..337d55b34 --- /dev/null +++ b/internal/protocol/http/timeout_reader_test.go @@ -0,0 +1,51 @@ +package http + +import ( + "bytes" + "context" + "errors" + "io" + "testing" + "time" +) + +func TestTimeoutReader_Read(t *testing.T) { + data := []byte("Hello, World!") + reader := bytes.NewReader(data) + timeoutReader := NewTimeoutReader(reader, 1*time.Second) + + buf := make([]byte, len(data)) + n, err := timeoutReader.Read(buf) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if n != len(data) { + t.Fatalf("expected to read %d bytes, read %d", len(data), n) + } + if !bytes.Equal(buf, data) { + t.Fatalf("expected %s, got %s", data, buf) + } +} + +func TestTimeoutReader_ReadTimeout(t *testing.T) { + reader := &slowReader{delay: 2 * time.Second} + timeoutReader := NewTimeoutReader(reader, 1*time.Second) + + buf := make([]byte, 8192) + _, err := timeoutReader.Read(buf) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err) + } +} + +type slowReader struct { + delay time.Duration +} + +func (sr *slowReader) Read(p []byte) (n int, err error) { + time.Sleep(sr.delay) + return 0, io.EOF +} diff --git a/internal/test/httptest.go b/internal/test/httptest.go index 56142b949..6e2eab1d3 100644 --- a/internal/test/httptest.go +++ b/internal/test/httptest.go @@ -165,6 +165,8 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { if r == "" { writer.Header().Set("Content-Length", fmt.Sprintf("%d", BuildSize)) writer.WriteHeader(200) + (writer.(http.Flusher)).Flush() + file, err := os.Open(BuildFile) if err != nil { panic(err) @@ -204,6 +206,8 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { writer.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, BuildSize)) writer.Header().Set("Accept-Ranges", "bytes") writer.WriteHeader(206) + (writer.(http.Flusher)).Flush() + file, err := os.Open(BuildFile) if err != nil { writer.WriteHeader(500) From 5f72db3663012bbe60d039b1ffca70c465acee9a Mon Sep 17 00:00:00 2001 From: liwei Date: Mon, 23 Dec 2024 02:16:08 +0800 Subject: [PATCH 2/6] fix --- internal/protocol/http/fetcher.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/internal/protocol/http/fetcher.go b/internal/protocol/http/fetcher.go index 0119f0011..f504efcc6 100644 --- a/internal/protocol/http/fetcher.go +++ b/internal/protocol/http/fetcher.go @@ -299,7 +299,7 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { // check if all failed allFailed := true for _, c := range f.chunks { - if chunk.Downloaded < chunk.End-chunk.Begin+1 && c.retryTimes < maxRetries { + if c.Downloaded < c.End-c.Begin+1 && c.retryTimes < maxRetries { allFailed = false break } @@ -336,22 +336,21 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { reader := NewTimeoutReader(resp.Body, 30*time.Second) for { n, err := reader.Read(buf) + if n > 0 { + _, err := f.file.WriteAt(buf[:n], chunk.Begin+chunk.Downloaded) + if err != nil { + return err + } + chunk.Downloaded += int64(n) + } if err != nil { if err == io.EOF { return nil } return err } - // download success, reset retry times chunk.retryTimes = 0 - if n > 0 { - _, err := f.file.WriteAt(buf[:n], chunk.Begin+chunk.Downloaded) - if err != nil { - return err - } - chunk.Downloaded += int64(n) - } } }() if err != nil { From fce79aeb3dae4bcaba27269410d43fcba33b0817 Mon Sep 17 00:00:00 2001 From: liwei Date: Mon, 23 Dec 2024 02:21:00 +0800 Subject: [PATCH 3/6] add connect timeout --- internal/protocol/http/fetcher.go | 6 +++++- internal/protocol/http/fetcher_test.go | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/protocol/http/fetcher.go b/internal/protocol/http/fetcher.go index f504efcc6..0948700c4 100644 --- a/internal/protocol/http/fetcher.go +++ b/internal/protocol/http/fetcher.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "mime" + "net" "net/http" "net/http/cookiejar" "net/url" @@ -333,7 +334,7 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { err = NewRequestError(resp.StatusCode, resp.Status) return err } - reader := NewTimeoutReader(resp.Body, 30*time.Second) + reader := NewTimeoutReader(resp.Body, 15*time.Second) for { n, err := reader.Read(buf) if n > 0 { @@ -452,6 +453,9 @@ func (f *Fetcher) splitChunk() (chunks []*chunk) { func (f *Fetcher) buildClient() *http.Client { transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 15 * time.Second, + }).DialContext, Proxy: f.ctl.GetProxy(f.meta.Req.Proxy), TLSClientConfig: &tls.Config{ InsecureSkipVerify: f.meta.Req.SkipVerifyCert, diff --git a/internal/protocol/http/fetcher_test.go b/internal/protocol/http/fetcher_test.go index f6b75dac9..ef0a45dbb 100644 --- a/internal/protocol/http/fetcher_test.go +++ b/internal/protocol/http/fetcher_test.go @@ -133,7 +133,7 @@ func TestFetcher_DownloadLimit(t *testing.T) { } func TestFetcher_DownloadResponseBodyReadTimeout(t *testing.T) { - listener := test.StartTestLimitServer(16, 35000) + listener := test.StartTestLimitServer(16, 10000) defer listener.Close() downloadError(listener, 1, t) From 6cb81e3f7b4dfea76e58679c4981b3689a1bb5c8 Mon Sep 17 00:00:00 2001 From: liwei Date: Mon, 23 Dec 2024 11:54:51 +0800 Subject: [PATCH 4/6] fix: fail fast on chunk fetch --- internal/protocol/http/fetcher.go | 57 ++++++++++++++------------ internal/protocol/http/fetcher_test.go | 3 +- internal/test/httptest.go | 45 +++++++++++--------- 3 files changed, 59 insertions(+), 46 deletions(-) diff --git a/internal/protocol/http/fetcher.go b/internal/protocol/http/fetcher.go index 0948700c4..aed6beae5 100644 --- a/internal/protocol/http/fetcher.go +++ b/internal/protocol/http/fetcher.go @@ -26,6 +26,9 @@ import ( "golang.org/x/sync/errgroup" ) +const connectTimeout = 15 * time.Second +const readTimeout = 15 * time.Second + type RequestError struct { Code int Msg string @@ -104,7 +107,7 @@ func (f *Fetcher) Resolve(req *base.Request) error { if base.HttpCodePartialContent == httpResp.StatusCode || (base.HttpCodeOK == httpResp.StatusCode && httpResp.Header.Get(base.HttpHeaderAcceptRanges) == base.HttpHeaderBytes && strings.HasPrefix(httpResp.Header.Get(base.HttpHeaderContentRange), base.HttpHeaderBytes)) { // response 206 status code, support breakpoint continuation res.Range = true - // 解析资源大小: bytes 0-1000/1001 => 1001 + // parse content length from Content-Range header, eg: bytes 0-1000/1001 contentTotal := path.Base(httpResp.Header.Get(base.HttpHeaderContentRange)) if contentTotal != "" { parse, err := strconv.ParseInt(contentTotal, 10, 64) @@ -147,7 +150,7 @@ func (f *Fetcher) Resolve(req *base.Request) error { file.Name = filename } } - // Get file filePath by URL + // get file filePath by URL if file.Name == "" { file.Name = path.Base(httpReq.URL.Path) } @@ -251,23 +254,40 @@ func (f *Fetcher) Wait() (err error) { return <-f.doneCh } +type fetchResult struct { + err error +} + func (f *Fetcher) fetch() { var ctx context.Context ctx, f.cancel = context.WithCancel(context.Background()) f.eg, _ = errgroup.WithContext(ctx) + fetchResults := make([]*fetchResult, len(f.chunks)) for i := 0; i < len(f.chunks); i++ { i := i f.eg.Go(func() error { - return f.fetchChunk(i, ctx) + fr := &fetchResult{ + err: f.fetchChunk(i, ctx), + } + fetchResults[i] = fr + return nil }) } go func() { - err := f.eg.Wait() - // check if canceled - if errors.Is(err, context.Canceled) { - return + f.eg.Wait() + var err error + for _, fr := range fetchResults { + // check if canceled + if errors.Is(fr.err, context.Canceled) { + return + } + // return first error + if err == nil && fr.err != nil { + err = fr.err + } } + f.file.Close() // Update file last modified time if f.config.UseServerCtime && f.meta.Res.Files[0].Ctime != nil { @@ -290,24 +310,11 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { for { // if chunk is completed, return if f.meta.Res.Range && chunk.Downloaded >= chunk.End-chunk.Begin+1 { - return + return nil } if chunk.retryTimes >= maxRetries { - if !f.meta.Res.Range { - return - } - // check if all failed - allFailed := true - for _, c := range f.chunks { - if c.Downloaded < c.End-c.Begin+1 && c.retryTimes < maxRetries { - allFailed = false - break - } - } - if allFailed { - return - } + return } var ( @@ -334,7 +341,7 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { err = NewRequestError(resp.StatusCode, resp.Status) return err } - reader := NewTimeoutReader(resp.Body, 15*time.Second) + reader := NewTimeoutReader(resp.Body, readTimeout) for { n, err := reader.Read(buf) if n > 0 { @@ -350,8 +357,6 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { } return err } - // download success, reset retry times - chunk.retryTimes = 0 } }() if err != nil { @@ -454,7 +459,7 @@ func (f *Fetcher) splitChunk() (chunks []*chunk) { func (f *Fetcher) buildClient() *http.Client { transport := &http.Transport{ DialContext: (&net.Dialer{ - Timeout: 15 * time.Second, + Timeout: connectTimeout, }).DialContext, Proxy: f.ctl.GetProxy(f.meta.Req.Proxy), TLSClientConfig: &tls.Config{ diff --git a/internal/protocol/http/fetcher_test.go b/internal/protocol/http/fetcher_test.go index ef0a45dbb..89e4e6ef9 100644 --- a/internal/protocol/http/fetcher_test.go +++ b/internal/protocol/http/fetcher_test.go @@ -133,10 +133,11 @@ func TestFetcher_DownloadLimit(t *testing.T) { } func TestFetcher_DownloadResponseBodyReadTimeout(t *testing.T) { - listener := test.StartTestLimitServer(16, 10000) + listener := test.StartTestLimitServer(16, readTimeout.Milliseconds()+5000) defer listener.Close() downloadError(listener, 1, t) + downloadError(listener, 4, t) } func TestFetcher_DownloadResume(t *testing.T) { diff --git a/internal/test/httptest.go b/internal/test/httptest.go index 6e2eab1d3..77f183eb8 100644 --- a/internal/test/httptest.go +++ b/internal/test/httptest.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/json" + "errors" "fmt" "github.com/GopeedLab/gopeed/pkg/base" "github.com/armon/go-socks5" @@ -36,7 +37,7 @@ const ( ) func StartTestFileServer() net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { return http.FileServer(http.Dir(Dir)) }) } @@ -52,7 +53,7 @@ func (s *SlowFileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func StartTestSlowFileServer(delay time.Duration) net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { return &SlowFileServer{ delay: delay, handler: http.FileServer(http.Dir(Dir)), @@ -61,7 +62,7 @@ func StartTestSlowFileServer(delay time.Duration) net.Listener { } func StartTestCustomServer() net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { file, err := os.Open(BuildFile) @@ -88,7 +89,7 @@ func StartTestCustomServer() net.Listener { func StartTestRetryServer() net.Listener { counter := 0 - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { counter++ @@ -108,7 +109,7 @@ func StartTestRetryServer() net.Listener { } func StartTestPostServer() net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { if request.Method == "POST" && request.Header.Get("Authorization") != "" { @@ -132,7 +133,7 @@ func StartTestPostServer() net.Listener { func StartTestErrorServer() net.Listener { counter := 0 - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { counter++ @@ -149,7 +150,7 @@ func StartTestErrorServer() net.Listener { func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { var connections atomic.Int32 - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { defer func() { @@ -172,7 +173,7 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { panic(err) } defer file.Close() - slowCopy(writer, file, delay) + slowCopy(sl, writer, file, delay) } else { // split range s := strings.Split(r, "=") @@ -215,7 +216,7 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { } defer file.Close() file.Seek(start, 0) - slowCopyN(writer, file, end-start+1, delay) + slowCopyN(sl, writer, file, end-start+1, delay) } }) return mux @@ -223,9 +224,12 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { } // slowCopyN copies n bytes from src to dst, speed limit is bytes per second -func slowCopy(dst io.Writer, src io.Reader, delay int64) (written int64, err error) { +func slowCopy(sl *shutdownListener, dst io.Writer, src io.Reader, delay int64) (written int64, err error) { buf := make([]byte, 32*1024) for { + if sl.isShutdown { + return 0, errors.New("server shutdown") + } nr, er := src.Read(buf) if nr > 0 { nw, ew := dst.Write(buf[0:nr]) @@ -254,8 +258,8 @@ func slowCopy(dst io.Writer, src io.Reader, delay int64) (written int64, err err return written, err } -func slowCopyN(dst io.Writer, src io.Reader, n int64, delay int64) (written int64, err error) { - written, err = slowCopy(dst, io.LimitReader(src, n), delay) +func slowCopyN(sl *shutdownListener, dst io.Writer, src io.Reader, n int64, delay int64) (written int64, err error) { + written, err = slowCopy(sl, dst, io.LimitReader(src, n), delay) if written == n { return n, nil } @@ -266,7 +270,7 @@ func slowCopyN(dst io.Writer, src io.Reader, n int64, delay int64) (written int6 return } -func startTestServer(serverHandle func() http.Handler) net.Listener { +func startTestServer(serverHandle func(sl *shutdownListener) http.Handler) net.Listener { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { panic(err) @@ -276,7 +280,7 @@ func startTestServer(serverHandle func() http.Handler) net.Listener { panic(err) } defer file.Close() - // 随机生成一个文件 + // Write random data l := int64(8192) buf := make([]byte, l) size := int64(0) @@ -293,21 +297,24 @@ func startTestServer(serverHandle func() http.Handler) net.Listener { size += l } server := &http.Server{} - server.Handler = serverHandle() - go server.Serve(listener) - - return &shutdownListener{ + sl := &shutdownListener{ server: server, Listener: listener, } + server.Handler = serverHandle(sl) + go server.Serve(listener) + + return sl } type shutdownListener struct { - server *http.Server + server *http.Server + isShutdown bool net.Listener } func (c *shutdownListener) Close() error { + c.isShutdown = true closeErr := c.server.Shutdown(context.Background()) if err := ifExistAndRemove(BuildFile); err != nil { fmt.Println(err) From 351c47b7caad95e524175be80edabb99b6bc8e67 Mon Sep 17 00:00:00 2001 From: liwei Date: Mon, 23 Dec 2024 14:13:03 +0800 Subject: [PATCH 5/6] fix: fail fast on cancel --- internal/protocol/http/fetcher.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/internal/protocol/http/fetcher.go b/internal/protocol/http/fetcher.go index aed6beae5..31284b19a 100644 --- a/internal/protocol/http/fetcher.go +++ b/internal/protocol/http/fetcher.go @@ -266,8 +266,13 @@ func (f *Fetcher) fetch() { for i := 0; i < len(f.chunks); i++ { i := i f.eg.Go(func() error { + err := f.fetchChunk(i, ctx) + // if canceled, fail fast + if errors.Is(err, context.Canceled) { + return err + } fr := &fetchResult{ - err: f.fetchChunk(i, ctx), + err: err, } fetchResults[i] = fr return nil @@ -275,16 +280,16 @@ func (f *Fetcher) fetch() { } go func() { - f.eg.Wait() - var err error + err := f.eg.Wait() + // error returned only if canceled, just return + if err != nil { + return + } + // check all fetch results, if any error, return for _, fr := range fetchResults { - // check if canceled - if errors.Is(fr.err, context.Canceled) { - return - } - // return first error - if err == nil && fr.err != nil { + if fr.err != nil { err = fr.err + break } } From 32e87457bbf409a47b1724bdd36805a9d31fd325 Mon Sep 17 00:00:00 2001 From: liwei Date: Tue, 24 Dec 2024 20:08:50 +0800 Subject: [PATCH 6/6] update --- internal/protocol/http/fetcher.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/internal/protocol/http/fetcher.go b/internal/protocol/http/fetcher.go index 31284b19a..d40be4525 100644 --- a/internal/protocol/http/fetcher.go +++ b/internal/protocol/http/fetcher.go @@ -262,7 +262,7 @@ func (f *Fetcher) fetch() { var ctx context.Context ctx, f.cancel = context.WithCancel(context.Background()) f.eg, _ = errgroup.WithContext(ctx) - fetchResults := make([]*fetchResult, len(f.chunks)) + chunkErrs := make([]error, len(f.chunks)) for i := 0; i < len(f.chunks); i++ { i := i f.eg.Go(func() error { @@ -271,10 +271,7 @@ func (f *Fetcher) fetch() { if errors.Is(err, context.Canceled) { return err } - fr := &fetchResult{ - err: err, - } - fetchResults[i] = fr + chunkErrs[i] = err return nil }) } @@ -286,9 +283,9 @@ func (f *Fetcher) fetch() { return } // check all fetch results, if any error, return - for _, fr := range fetchResults { - if fr.err != nil { - err = fr.err + for _, chunkErr := range chunkErrs { + if chunkErr != nil { + err = chunkErr break } }