diff --git a/got_test.go b/got_test.go index b4dac67..6400585 100644 --- a/got_test.go +++ b/got_test.go @@ -6,6 +6,10 @@ import ( "net/http/httptest" "os" "testing" + "strings" + "time" + + "io/ioutil" "github.com/melbahja/got" ) @@ -30,7 +34,8 @@ func TestGot(t *testing.T) { switch r.URL.String() { case "/file1": - http.ServeContent(w, r, "go.mod", stat.ModTime(), file) + // http.ServeContent(w, r, "go.mod", stat.ModTime(), file) + http.ServeFile(w, r, "go.mod") return case "/file2": @@ -47,6 +52,16 @@ func TestGot(t *testing.T) { w.WriteHeader(http.StatusMethodNotAllowed) return + + case "/file5": + + if strings.Contains(r.Header.Get("range"), "3-") { + + time.Sleep(2 * time.Second) + } + + http.ServeFile(w, r, "go.mod") + return } w.WriteHeader(http.StatusNotFound) @@ -104,14 +119,22 @@ func TestGot(t *testing.T) { // test when partial content not supprted. downloadPartialContentNotSupportedTest(t, httpt.URL+"/file2") }) + + t.Run("fileContentTest", func(t *testing.T) { + + // test when partial content not supprted. + fileContentTest(t, httpt.URL+"/file5") + }) }) } func getInfoTest(t *testing.T, url string, expect got.Info) { - defer clean() + tmpFile := createTemp() + defer clean(tmpFile) - d, err := got.New(url, "/tmp/got_dl_test") + + d, err := got.New(url, tmpFile) if err != nil { t.Error(err) @@ -133,11 +156,13 @@ func getInfoTest(t *testing.T, url string, expect got.Info) { func initTest(t *testing.T, url string) { - defer clean() + tmpFile := createTemp() + defer clean(tmpFile) + d := got.Download{ URL: url, - Dest: "/tmp/got_dl_test", + Dest: tmpFile, } if err := d.Init(); err != nil { @@ -147,9 +172,11 @@ func initTest(t *testing.T, url string) { func downloadChunksTest(t *testing.T, url string, size int64) { - defer clean() + tmpFile := createTemp() + defer clean(tmpFile) + - d, err := got.New(url, "/tmp/got_dl_test") + d, err := got.New(url, tmpFile) if err != nil { @@ -166,11 +193,13 @@ func downloadChunksTest(t *testing.T, url string, size int64) { func downloadTest(t *testing.T, url string, size int64) { - defer clean() + tmpFile := createTemp() + defer clean(tmpFile) + d := &got.Download{ URL: url, - Dest: "/tmp/got_dl_test", + Dest: tmpFile, Concurrency: 2, StopProgress: true, } @@ -185,7 +214,7 @@ func downloadTest(t *testing.T, url string, size int64) { t.Error(err) } - stat, err := os.Stat("/tmp/got_dl_test") + stat, err := os.Stat(tmpFile) if err != nil { t.Error(err) @@ -198,9 +227,11 @@ func downloadTest(t *testing.T, url string, size int64) { func downloadNotFoundTest(t *testing.T, url string) { - defer clean() + tmpFile := createTemp() + defer clean(tmpFile) + - _, err := got.New(url, "/tmp/got_dl_test") + _, err := got.New(url, tmpFile) if err == nil { t.Error("It sould have an error") @@ -210,10 +241,12 @@ func downloadNotFoundTest(t *testing.T, url string) { func downloadHeadNotSupported(t *testing.T, url string) { - defer clean() + tmpFile := createTemp() + defer clean(tmpFile) d := &got.Download{ URL: url, + Dest: tmpFile, } // init @@ -238,11 +271,12 @@ func downloadHeadNotSupported(t *testing.T, url string) { func downloadPartialContentNotSupportedTest(t *testing.T, url string) { - defer clean() + tmpFile := createTemp() + defer clean(tmpFile) d := &got.Download{ URL: url, - Dest: "/tmp/got_dl_test", + Dest: tmpFile, } if err := d.Init(); err != nil { @@ -258,7 +292,7 @@ func downloadPartialContentNotSupportedTest(t *testing.T, url string) { t.Error(err) } - stat, err := os.Stat("/tmp/got_dl_test") + stat, err := os.Stat(tmpFile) if err != nil { t.Error(err) @@ -269,7 +303,63 @@ func downloadPartialContentNotSupportedTest(t *testing.T, url string) { } } -func clean() { +func fileContentTest(t *testing.T, url string) { + + tmpFile := createTemp() + defer clean(tmpFile) + + d := &got.Download{ + URL: url, + Dest: tmpFile, + ChunkSize: 10, + } + + if err := d.Init(); err != nil { + t.Error(err) + return + } + + if err := d.Start(); err != nil { + t.Error(err) + return + } + + mod, err := ioutil.ReadFile("go.mod") + + if err != nil { + t.Error(err) + return + } + + dlFile, err := ioutil.ReadFile(tmpFile) + + if err != nil { + t.Error(err) + return + } + + if string(mod) != string(dlFile) { + + fmt.Println("a", string(mod)) + fmt.Println("b", string(dlFile)) + t.Error("Corrupted file") + } +} + +func createTemp() string { + + tmp, err := ioutil.TempFile("", "") + + if err != nil { + panic(err) + } + + defer tmp.Close() + + return tmp.Name() +} + +func clean(tmpFile string) { - os.Remove("/tmp/got_dl_test") + os.Remove(tmpFile) }