Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: http download retries too many times #849

Merged
merged 7 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 35 additions & 24 deletions internal/protocol/http/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"mime"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
Expand All @@ -25,6 +26,9 @@ import (
"golang.org/x/sync/errgroup"
)

const connectTimeout = 15 * time.Second
const readTimeout = 15 * time.Second

type RequestError struct {
Code int
Msg string
Expand Down Expand Up @@ -103,7 +107,7 @@ func (f *Fetcher) Resolve(req *base.Request) error {
if base.HttpCodePartialContent == httpResp.StatusCode || (base.HttpCodeOK == httpResp.StatusCode && httpResp.Header.Get(base.HttpHeaderAcceptRanges) == base.HttpHeaderBytes && strings.HasPrefix(httpResp.Header.Get(base.HttpHeaderContentRange), base.HttpHeaderBytes)) {
// response 206 status code, support breakpoint continuation
res.Range = true
// 解析资源大小: bytes 0-1000/1001 => 1001
// parse content length from Content-Range header, eg: bytes 0-1000/1001
contentTotal := path.Base(httpResp.Header.Get(base.HttpHeaderContentRange))
if contentTotal != "" {
parse, err := strconv.ParseInt(contentTotal, 10, 64)
Expand Down Expand Up @@ -146,7 +150,7 @@ func (f *Fetcher) Resolve(req *base.Request) error {
file.Name = filename
}
}
// Get file filePath by URL
// get file filePath by URL
if file.Name == "" {
file.Name = path.Base(httpReq.URL.Path)
}
Expand Down Expand Up @@ -250,23 +254,42 @@ func (f *Fetcher) Wait() (err error) {
return <-f.doneCh
}

// fetchResult wraps the error outcome of a single chunk fetch.
// NOTE(review): no reference to this type is visible in this diff — the
// chunk results are collected in a plain []error in fetch(); confirm it is
// used elsewhere in the file or remove it.
type fetchResult struct {
	err error
}

func (f *Fetcher) fetch() {
var ctx context.Context
ctx, f.cancel = context.WithCancel(context.Background())
f.eg, _ = errgroup.WithContext(ctx)
chunkErrs := make([]error, len(f.chunks))
for i := 0; i < len(f.chunks); i++ {
i := i
f.eg.Go(func() error {
return f.fetchChunk(i, ctx)
err := f.fetchChunk(i, ctx)
// if canceled, fail fast
if errors.Is(err, context.Canceled) {
return err
}
chunkErrs[i] = err
return nil
})
}

go func() {
err := f.eg.Wait()
// check if canceled
if errors.Is(err, context.Canceled) {
// error returned only if canceled, just return
if err != nil {
return
}
// check all fetch results, if any error, return
for _, chunkErr := range chunkErrs {
if chunkErr != nil {
err = chunkErr
break
}
}

f.file.Close()
// Update file last modified time
if f.config.UseServerCtime && f.meta.Res.Files[0].Ctime != nil {
Expand All @@ -289,24 +312,11 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) {
for {
// if chunk is completed, return
if f.meta.Res.Range && chunk.Downloaded >= chunk.End-chunk.Begin+1 {
return
return nil
}

if chunk.retryTimes >= maxRetries {
if !f.meta.Res.Range {
return
}
// check if all failed
allFailed := true
for _, c := range f.chunks {
if chunk.Downloaded < chunk.End-chunk.Begin+1 && c.retryTimes < maxRetries {
allFailed = false
break
}
}
if allFailed {
return
}
return
}

var (
Expand All @@ -333,10 +343,9 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) {
err = NewRequestError(resp.StatusCode, resp.Status)
return err
}
// Http request success, reset retry times
chunk.retryTimes = 0
reader := NewTimeoutReader(resp.Body, readTimeout)
for {
n, err := resp.Body.Read(buf)
n, err := reader.Read(buf)
if n > 0 {
_, err := f.file.WriteAt(buf[:n], chunk.Begin+chunk.Downloaded)
if err != nil {
Expand All @@ -351,7 +360,6 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) {
return err
}
}
return nil
}()
if err != nil {
// If canceled, do not retry
Expand Down Expand Up @@ -452,6 +460,9 @@ func (f *Fetcher) splitChunk() (chunks []*chunk) {

func (f *Fetcher) buildClient() *http.Client {
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: connectTimeout,
}).DialContext,
Proxy: f.ctl.GetProxy(f.meta.Req.Proxy),
TLSClientConfig: &tls.Config{
InsecureSkipVerify: f.meta.Req.SkipVerifyCert,
Expand Down
8 changes: 8 additions & 0 deletions internal/protocol/http/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ func TestFetcher_DownloadLimit(t *testing.T) {
downloadNormal(listener, 8, t)
}

// TestFetcher_DownloadResponseBodyReadTimeout verifies that a response body
// arriving slower than readTimeout makes the download fail rather than hang.
func TestFetcher_DownloadResponseBodyReadTimeout(t *testing.T) {
	// Presumably a rate-limited test server; the second argument is a delay
	// chosen to exceed readTimeout by 5s so every body read times out —
	// TODO confirm StartTestLimitServer's parameter semantics.
	listener := test.StartTestLimitServer(16, readTimeout.Milliseconds()+5000)
	defer listener.Close()

	// Expect failure for both single- and multi-connection downloads.
	downloadError(listener, 1, t)
	downloadError(listener, 4, t)
}

func TestFetcher_DownloadResume(t *testing.T) {
listener := test.StartTestFileServer()
defer listener.Close()
Expand Down
40 changes: 40 additions & 0 deletions internal/protocol/http/timeout_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package http

import (
"context"
"io"
"time"
)

type TimeoutReader struct {
reader io.Reader
timeout time.Duration
}

func NewTimeoutReader(r io.Reader, timeout time.Duration) *TimeoutReader {
return &TimeoutReader{
reader: r,
timeout: timeout,
}
}

func (tr *TimeoutReader) Read(p []byte) (n int, err error) {
ctx, cancel := context.WithTimeout(context.Background(), tr.timeout)
defer cancel()

done := make(chan struct{})
var readErr error
var bytesRead int

go func() {
bytesRead, readErr = tr.reader.Read(p)
close(done)
}()

select {
case <-done:
return bytesRead, readErr
case <-ctx.Done():
return 0, ctx.Err()
}
}
51 changes: 51 additions & 0 deletions internal/protocol/http/timeout_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package http

import (
"bytes"
"context"
"errors"
"io"
"testing"
"time"
)

// TestTimeoutReader_Read checks that a read completing well within the
// timeout passes the bytes and the count through unchanged.
func TestTimeoutReader_Read(t *testing.T) {
	payload := []byte("Hello, World!")
	tr := NewTimeoutReader(bytes.NewReader(payload), 1*time.Second)

	got := make([]byte, len(payload))
	n, err := tr.Read(got)
	switch {
	case err != nil:
		t.Fatalf("expected no error, got %v", err)
	case n != len(payload):
		t.Fatalf("expected to read %d bytes, read %d", len(payload), n)
	case !bytes.Equal(got, payload):
		t.Fatalf("expected %s, got %s", payload, got)
	}
}

// TestTimeoutReader_ReadTimeout checks that a source slower than the
// timeout makes Read fail with context.DeadlineExceeded.
func TestTimeoutReader_ReadTimeout(t *testing.T) {
	// Delay comfortably exceeds the timeout, but both are kept small so the
	// test finishes in milliseconds (the previous 2s delay / 1s timeout pair
	// spent over a second of wall time for the same coverage).
	reader := &slowReader{delay: 200 * time.Millisecond}
	timeoutReader := NewTimeoutReader(reader, 50*time.Millisecond)

	buf := make([]byte, 8192)
	_, err := timeoutReader.Read(buf)
	if err == nil {
		t.Fatal("expected timeout error, got nil")
	}
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err)
	}
}

type slowReader struct {
delay time.Duration
}

func (sr *slowReader) Read(p []byte) (n int, err error) {
time.Sleep(sr.delay)
return 0, io.EOF
}
Loading
Loading