diff --git a/chunk.go b/chunk.go
index c611c85..280febe 100644
--- a/chunk.go
+++ b/chunk.go
@@ -21,6 +21,10 @@ type Chunk struct {
 
 	// Path name where this chunk downloaded.
 	Path string
+
+	Downloaded bool
+
+	Merged bool
 }
 
 // Download a chunk, and report to Progress, it returns error if any!
@@ -51,9 +55,5 @@ func (c *Chunk) Download(URL string, client *http.Client, dest *os.File) (err er
 
 	_, err = io.Copy(dest, io.TeeReader(res.Body, c.Progress))
 
-	if err == nil {
-		c.Path = dest.Name()
-	}
-
 	return err
 }
diff --git a/got.go b/got.go
index b7a22ab..be7cf4e 100644
--- a/got.go
+++ b/got.go
@@ -72,9 +72,16 @@ type (
 
 		// Progress...
 		progress *Progress
+
+		// Chunk merge index.
+		index int
+
+		// Sync mutex.
+		mu sync.RWMutex
 	}
 )
 
+
 // Check Download and split file to chunks and set defaults,
 // you should call Init first then call Start
 func (d *Download) Init() error {
@@ -98,7 +105,9 @@ func (d *Download) Init() error {
 	}
 
 	// Init progress.
-	d.progress = new(Progress)
+	d.progress = &Progress{
+		mu: d.mu,
+	}
 
 	// Set default interval.
 	if d.Interval == 0 {
@@ -106,15 +115,13 @@
 	}
 
 	// Get URL info.
-	d.Info, err = d.GetInfo()
-
-	if err != nil {
+	if d.Info, err = d.GetInfo(); err != nil {
 		return err
 	}
 
 	// Partial content not supported 😢!
 	if d.Info.Rangeable == false || d.Info.Length == 0 {
-		return err
+		return nil
 	}
 
 	// Set concurrency default to 10.
@@ -132,11 +139,6 @@ func (d *Download) Init() error {
 			d.ChunkSize = d.ChunkSize / 2
 		}
 
-		// Change ChunkSize if MaxChunkSize are set and ChunkSize > Max size
-		if d.MaxChunkSize > 0 && d.ChunkSize > d.MaxChunkSize {
-			d.ChunkSize = d.MaxChunkSize
-		}
-
 		// Set default min chunk size to 1m, or file size / 2
 		if d.MinChunkSize == 0 {
 
@@ -147,29 +149,39 @@ func (d *Download) Init() error {
 			}
 		}
 
-		// if Chunk size < Min size set chunk size to length / 2
+		// if Chunk size < Min size set chunk size to min.
		if d.ChunkSize < d.MinChunkSize {
 			d.ChunkSize = d.MinChunkSize
 		}
-	}
 
-	// Avoid divide by zero
-	if d.ChunkSize > 0 {
-		chunksLen = d.Info.Length / d.ChunkSize
+		// Change ChunkSize if MaxChunkSize is set and ChunkSize > Max size
+		if d.MaxChunkSize > 0 && d.ChunkSize > d.MaxChunkSize {
+			d.ChunkSize = d.MaxChunkSize
+		}
+
+	} else if d.ChunkSize > d.Info.Length {
+
+		d.ChunkSize = d.Info.Length / 2
 	}
 
-	// Set chunks.
+	chunksLen = d.Info.Length / d.ChunkSize
+
+	// Set chunk ranges.
 	for ; i < chunksLen; i++ {
 
-		startRange = (d.ChunkSize * i) + 1
+		startRange = (d.ChunkSize * i) + i
+		endRange = startRange + d.ChunkSize
 
 		if i == 0 {
+
 			startRange = 0
-		}
 
-		endRange = startRange + d.ChunkSize
+		} else if d.chunks[i - 1].End == 0 {
 
-		if i == (chunksLen - 1) {
+			break
+		}
+
+		if endRange > d.Info.Length || i == (chunksLen - 1) {
 			endRange = 0
 		}
@@ -205,8 +217,30 @@ func (d *Download) Start() (err error) {
 	// Run progress func.
 	go d.progress.Run(d)
 
-	// Start download chunks.
-	go d.work(&errChan, &okChan)
+	// Partial content not supported,
+	// just download the file in one chunk.
+	if len(d.chunks) == 0 {
+
+		file, err := os.Create(d.Dest)
+
+		if err != nil {
+			return err
+		}
+
+		defer file.Close()
+
+		chunk := &Chunk{
+			Progress: d.progress,
+		}
+
+		return chunk.Download(d.URL, d.client, file)
+	}
+
+	// Download chunks.
+	go d.dl(&errChan)
+
+	// Merge chunks.
+	go d.merge(&errChan, &okChan)
 
 	// Wait for chunks...
 	for {
 
@@ -214,7 +248,12 @@
 		select {
 		case err := <-errChan:
-			return err
+
+			if err != nil {
+				return err
+			}
+
+			break
 
 		case <-okChan:
 
@@ -258,115 +297,100 @@ func (d *Download) GetInfo() (*Info, error) {
 	}, nil
 }
 
-// Download chunks and in same time merge them into dest path.
-func (d *Download) work(echan *chan error, done *chan bool) {
-
-	var (
-		// Next chunk index.
-		next int = 0
+// Merge downloaded chunks.
+func (d *Download) merge(echan *chan error, done *chan bool) {
 
-		// Waiting group.
-		swg sync.WaitGroup
-
-		// Concurrency limit.
-		max chan int = make(chan int, d.Concurrency)
-
-		// Chunk file.
-		chunk *os.File
-	)
-
-	go func() {
-
-		chunksLen := len(d.chunks)
-
-		file, err := os.Create(d.Dest)
+	file, err := os.Create(d.Dest)
 
-		if err != nil {
-			*echan <- err
-			return
-		}
-
-		defer file.Close()
-
-		// Partial content not supported or file length is unknown,
-		// so just download it directly in one chunk!
-		if chunksLen == 0 {
-
-			chunk := &Chunk{
-				Progress: d.progress,
-			}
-
-			if err := chunk.Download(d.URL, d.client, file); err != nil {
-				*echan <- err
-				return
-			}
+	if err != nil {
+		*echan <-err
+		return
+	}
 
-			*done <- true
-			return
-		}
+	defer file.Close()
 
-		for {
+	chunksLen := len(d.chunks)
 
-			for i := 0; i < len(d.chunks); i++ {
+	for {
 
-				if next == i && d.chunks[i].Path != "" {
+		for i := range d.chunks {
 
-					chunk, err = os.Open(d.chunks[i].Path)
+			d.mu.RLock()
+			if d.chunks[i].Downloaded && d.chunks[i].Merged == false && i == d.index {
 
-					if err != nil {
+				chunk, err := os.Open(d.chunks[i].Path)
 
-						*echan <- err
-						return
-					}
+				if err != nil {
+					*echan <-err
+					return
+				}
 
-					// Copy chunk content to dest file.
-					_, err = io.Copy(file, chunk)
+				_, err = io.Copy(file, chunk)
 
-					// Close chunk fd.
-					chunk.Close()
+				if err != nil {
+					*echan <-err
+					return
+				}
 
-					if err != nil {
+				go chunk.Close()
 
-						*echan <- err
-						return
-					}
+				// Sync dest file.
+				file.Sync()
 
-					next++
-				}
+				d.chunks[i].Merged = true
+				d.index++
 			}
+			d.mu.RUnlock()
 
-			if next == len(d.chunks) {
+			// done, all chunks merged.
+			if d.index == chunksLen {
 
 				*done <- true
 				return
 			}
-
-			time.Sleep(6 * time.Millisecond)
 		}
-	}()
+	}
+}
+
+
+// Download chunks
+func (d *Download) dl(echan *chan error) {
+
+	var (
+
+		// Waiting group.
+		swg sync.WaitGroup
+
+		// Concurrency limit.
+		max chan int = make(chan int, d.Concurrency)
+	)
 
 	for i := 0; i < len(d.chunks); i++ {
 
-		max <- 1
+		max <-1
 
 		swg.Add(1)
 
 		go func(i int) {
 
 			defer swg.Done()
 
-			chunk, err := ioutil.TempFile(d.temp, fmt.Sprintf("chunk-%d", i))
+			chunk, err := os.Create(fmt.Sprintf("%s/chunk-%d", d.temp, i))
 
 			if err != nil {
-				*echan <- err
+				*echan <-err
 				return
 			}
 
 			// Close chunk fd.
 			defer chunk.Close()
 
-			// Donwload the chunk.
-			if err = d.chunks[i].Download(d.URL, d.client, chunk); err != nil {
-				*echan <- err
-			}
+			// Download chunk.
+			*echan <-d.chunks[i].Download(d.URL, d.client, chunk)
+
+			d.mu.Lock()
+			d.chunks[i].Path = chunk.Name()
+			d.chunks[i].Downloaded = true
+			d.mu.Unlock()
 
 			<-max
 		}(i)
diff --git a/got_test.go b/got_test.go
index b4dac67..6400585 100644
--- a/got_test.go
+++ b/got_test.go
@@ -6,6 +6,10 @@ import (
 	"net/http/httptest"
 	"os"
 	"testing"
+	"strings"
+	"time"
+
+	"io/ioutil"
 
 	"github.com/melbahja/got"
 )
@@ -30,7 +34,8 @@ func TestGot(t *testing.T) {
 		switch r.URL.String() {
 		case "/file1":
 
-			http.ServeContent(w, r, "go.mod", stat.ModTime(), file)
+			// http.ServeContent(w, r, "go.mod", stat.ModTime(), file)
+			http.ServeFile(w, r, "go.mod")
 			return
 
 		case "/file2":
@@ -47,6 +52,16 @@ func TestGot(t *testing.T) {
 
 			w.WriteHeader(http.StatusMethodNotAllowed)
 			return
+
+		case "/file5":
+
+			if strings.Contains(r.Header.Get("range"), "3-") {
+
+				time.Sleep(2 * time.Second)
+			}
+
+			http.ServeFile(w, r, "go.mod")
+			return
 		}
 
 		w.WriteHeader(http.StatusNotFound)
@@ -104,14 +119,22 @@ func TestGot(t *testing.T) {
 
 		// test when partial content not supprted.
 		downloadPartialContentNotSupportedTest(t, httpt.URL+"/file2")
 	})
+
+	t.Run("fileContentTest", func(t *testing.T) {
+
+		// test downloaded file content.
+		fileContentTest(t, httpt.URL+"/file5")
+	})
 	})
 }
 
 func getInfoTest(t *testing.T, url string, expect got.Info) {
 
-	defer clean()
+	tmpFile := createTemp()
+	defer clean(tmpFile)
 
-	d, err := got.New(url, "/tmp/got_dl_test")
+
+	d, err := got.New(url, tmpFile)
 
 	if err != nil {
 		t.Error(err)
@@ -133,11 +156,13 @@ func getInfoTest(t *testing.T, url string, expect got.Info) {
 
 func initTest(t *testing.T, url string) {
 
-	defer clean()
+	tmpFile := createTemp()
+	defer clean(tmpFile)
+
 	d := got.Download{
 		URL:  url,
-		Dest: "/tmp/got_dl_test",
+		Dest: tmpFile,
 	}
 
 	if err := d.Init(); err != nil {
@@ -147,9 +172,11 @@ func initTest(t *testing.T, url string) {
 
 func downloadChunksTest(t *testing.T, url string, size int64) {
 
-	defer clean()
+	tmpFile := createTemp()
+	defer clean(tmpFile)
+
 
-	d, err := got.New(url, "/tmp/got_dl_test")
+	d, err := got.New(url, tmpFile)
 
 	if err != nil {
 
@@ -166,11 +193,13 @@ func downloadChunksTest(t *testing.T, url string, size int64) {
 
 func downloadTest(t *testing.T, url string, size int64) {
 
-	defer clean()
+	tmpFile := createTemp()
+	defer clean(tmpFile)
+
 	d := &got.Download{
 		URL:          url,
-		Dest:         "/tmp/got_dl_test",
+		Dest:         tmpFile,
 		Concurrency:  2,
 		StopProgress: true,
 	}
@@ -185,7 +214,7 @@ func downloadTest(t *testing.T, url string, size int64) {
 		t.Error(err)
 	}
 
-	stat, err := os.Stat("/tmp/got_dl_test")
+	stat, err := os.Stat(tmpFile)
 
 	if err != nil {
 		t.Error(err)
@@ -198,9 +227,11 @@ func downloadTest(t *testing.T, url string, size int64) {
 
 func downloadNotFoundTest(t *testing.T, url string) {
 
-	defer clean()
+	tmpFile := createTemp()
+	defer clean(tmpFile)
+
 
-	_, err := got.New(url, "/tmp/got_dl_test")
+	_, err := got.New(url, tmpFile)
 
 	if err == nil {
 		t.Error("It sould have an error")
@@ -210,10 +241,12 @@ func downloadNotFoundTest(t *testing.T, url string) {
 
 func downloadHeadNotSupported(t *testing.T, url string) {
 
-	defer clean()
+	tmpFile := createTemp()
+	defer clean(tmpFile)
 
 	d := &got.Download{
 		URL: url,
+		Dest: tmpFile,
 	}
 
 	// init
@@ -238,11 +271,12 @@ func downloadHeadNotSupported(t *testing.T, url string) {
 
 func downloadPartialContentNotSupportedTest(t *testing.T, url string) {
 
-	defer clean()
+	tmpFile := createTemp()
+	defer clean(tmpFile)
 
 	d := &got.Download{
 		URL:  url,
-		Dest: "/tmp/got_dl_test",
+		Dest: tmpFile,
 	}
 
 	if err := d.Init(); err != nil {
@@ -258,7 +292,7 @@ func downloadPartialContentNotSupportedTest(t *testing.T, url string) {
 		t.Error(err)
 	}
 
-	stat, err := os.Stat("/tmp/got_dl_test")
+	stat, err := os.Stat(tmpFile)
 
 	if err != nil {
 		t.Error(err)
@@ -269,7 +303,63 @@ func downloadPartialContentNotSupportedTest(t *testing.T, url string) {
 	}
 }
 
-func clean() {
+func fileContentTest(t *testing.T, url string) {
+
+	tmpFile := createTemp()
+	defer clean(tmpFile)
+
+	d := &got.Download{
+		URL:       url,
+		Dest:      tmpFile,
+		ChunkSize: 10,
+	}
+
+	if err := d.Init(); err != nil {
+		t.Error(err)
+		return
+	}
+
+	if err := d.Start(); err != nil {
+		t.Error(err)
+		return
+	}
+
+	mod, err := ioutil.ReadFile("go.mod")
+
+	if err != nil {
+		t.Error(err)
+		return
+	}
+
+	dlFile, err := ioutil.ReadFile(tmpFile)
+
+	if err != nil {
+		t.Error(err)
+		return
+	}
+
+	if string(mod) != string(dlFile) {
+
+		fmt.Println("a", string(mod))
+		fmt.Println("b", string(dlFile))
+		t.Error("Corrupted file")
+	}
+}
+
+func createTemp() string {
+
+	tmp, err := ioutil.TempFile("", "")
+
+	if err != nil {
+		panic(err)
+	}
+
+	defer tmp.Close()
+
+	return tmp.Name()
+}
+
+func clean(tmpFile string) {
 
-	os.Remove("/tmp/got_dl_test")
+	os.Remove(tmpFile)
 }
diff --git a/progress.go b/progress.go
index 77d8489..a35e5b1 100644
--- a/progress.go
+++ b/progress.go
@@ -2,13 +2,15 @@ package got
 
 import (
 	"time"
+	"sync"
 )
 
 type (
 
 	// Download progress.
 	Progress struct {
-		Length int64
+		Size int64
+		mu sync.RWMutex
 	}
 
 	// Progress report func.
@@ -25,7 +27,7 @@ func (p *Progress) Run(d *Download) {
 			break
 		}
 
-		d.ProgressFunc(p.Length, d.Info.Length, d)
+		d.ProgressFunc(p.Size, d.Info.Length, d)
 
 		time.Sleep(time.Duration(d.Interval) * time.Millisecond)
 	}
@@ -33,7 +35,9 @@ func (p *Progress) Run(d *Download) {
 }
 
 func (p *Progress) Write(b []byte) (int, error) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
 	n := len(b)
-	p.Length += int64(n)
+	p.Size += int64(n)
 	return n, nil
 }
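
Illustrative sketch, not part of the patch: the chunk-range arithmetic that Init now uses, run with hypothetical length and chunk-size values. How Start/End become an HTTP Range header lives in chunk.go and is not shown in this diff; treating End == 0 as "read to EOF" is an assumption, and the else-if break on a previous open-ended chunk is omitted for brevity.

package main

import "fmt"

func main() {
	// Hypothetical values, for illustration only.
	var length uint64 = 25
	var chunkSize uint64 = 10

	chunksLen := length / chunkSize

	for i := uint64(0); i < chunksLen; i++ {

		// Same arithmetic as the new Init loop: ranges are inclusive,
		// so each chunk starts one byte past the previous chunk's end.
		start := (chunkSize * i) + i
		end := start + chunkSize

		if i == 0 {
			start = 0
		}

		// The last (or overflowing) chunk gets end = 0, assumed here to
		// mean "download until EOF" in the unchanged chunk.go handling.
		if end > length || i == (chunksLen-1) {
			end = 0
		}

		fmt.Printf("chunk %d: start=%d end=%d\n", i, start, end)
	}
}

For a 25-byte file with a chunk size of 10 this prints chunk 0: start=0 end=10 and chunk 1: start=11 end=0, i.e. consecutive inclusive ranges with an open-ended tail.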
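
Also not from the patch: a minimal usage sketch of the reworked pipeline (Init splits the file into chunk ranges; Start either runs dl and merge concurrently or falls back to a single-chunk download), mirroring the fields exercised by fileContentTest in got_test.go. The URL and destination path below are placeholders.

package main

import (
	"log"

	"github.com/melbahja/got"
)

func main() {
	// Same flow as the tests: build a Download, Init, then Start.
	d := &got.Download{
		URL:  "https://example.com/file.bin", // placeholder URL
		Dest: "/tmp/file.bin",                // placeholder destination
	}

	if err := d.Init(); err != nil {
		log.Fatal(err)
	}

	if err := d.Start(); err != nil {
		log.Fatal(err)
	}
}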