Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List downloads2 #232

Merged
merged 17 commits into from
Sep 11, 2024
219 changes: 32 additions & 187 deletions internal/cli/serverless/export/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,20 @@ package export

import (
"fmt"
"io"
"os"
"strings"
"sync"
"tidbcloud-cli/internal"
"tidbcloud-cli/internal/config"
"tidbcloud-cli/internal/flag"
"tidbcloud-cli/internal/service/cloud"
"tidbcloud-cli/internal/ui"
"tidbcloud-cli/internal/util"

"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/dustin/go-humanize"
"github.com/fatih/color"
"github.com/juju/errors"
"github.com/spf13/cobra"

"tidbcloud-cli/internal"
"tidbcloud-cli/internal/config"
"tidbcloud-cli/internal/flag"
"tidbcloud-cli/internal/service/cloud"
"tidbcloud-cli/internal/ui"
"tidbcloud-cli/internal/util"
"tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/export"
)

const DefaultConcurrency = 3
Expand Down Expand Up @@ -161,16 +154,16 @@ func DownloadCmd(h *internal.Helper) *cobra.Command {
return errors.Trace(err)
}

resp, err := d.DownloadExport(ctx, clusterID, exportID)
exportFiles, err := cloud.GetAllExportFiles(ctx, clusterID, exportID, d)
if err != nil {
return errors.Trace(err)
}

var totalSize int64
for _, download := range resp.Downloads {
totalSize += *download.Size
for _, file := range exportFiles {
totalSize += *file.Size
}
fileMessage := fmt.Sprintf("There are %d files to download, total size is %s.", len(resp.Downloads), humanize.IBytes(uint64(totalSize)))
fileMessage := fmt.Sprintf("There are %d files to download, total size is %s.", len(exportFiles), humanize.IBytes(uint64(totalSize)))

if !force {
if !h.IOStreams.CanPrompt {
Expand All @@ -196,14 +189,13 @@ func DownloadCmd(h *internal.Helper) *cobra.Command {
} else {
fmt.Fprintf(h.IOStreams.Out, "%s\n", color.BlueString(fileMessage))
}

if h.IOStreams.CanPrompt {
err = DownloadFilesPrompt(h, resp.Downloads, path, concurrency)
err = DownloadFilesPrompt(h, path, concurrency, exportID, clusterID, totalSize, len(exportFiles), d)
if err != nil {
return errors.Trace(err)
}
} else {
err = DownloadFilesWithoutPrompt(h, resp.Downloads, path, concurrency)
err = DownloadFilesWithoutPrompt(h, path, concurrency, exportID, clusterID, len(exportFiles), d)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -221,7 +213,8 @@ func DownloadCmd(h *internal.Helper) *cobra.Command {
return downloadCmd
}

func DownloadFilesPrompt(h *internal.Helper, urls []export.DownloadUrl, path string, concurrency int) error {
func DownloadFilesPrompt(h *internal.Helper, path string,
concurrency int, exportID, clusterID string, totalSize int64, count int, client cloud.TiDBCloudClient) error {
if concurrency <= 0 {
concurrency = DefaultConcurrency
}
Expand All @@ -234,19 +227,14 @@ func DownloadFilesPrompt(h *internal.Helper, urls []export.DownloadUrl, path str

// init the concurrency progress model
var p *tea.Program
urlMsgs := make([]ui.URLMsg, 0)
for _, u := range urls {
url := ui.URLMsg{
Name: *u.Name,
Url: *u.Url,
Size: *u.Size,
}
urlMsgs = append(urlMsgs, url)
}
m := ui.NewProcessDownloadModel(
urlMsgs,
m := NewProcessDownloadModel(
concurrency,
path,
exportID,
clusterID,
client,
int(totalSize),
count,
)

// run the program
Expand All @@ -256,7 +244,7 @@ func DownloadFilesPrompt(h *internal.Helper, urls []export.DownloadUrl, path str
if err != nil {
return errors.Trace(err)
}
if m, _ := model.(*ui.ProcessDownloadModel); m.Interrupted {
if m, _ := model.(*ProcessDownloadModel); m.Interrupted {
return util.InterruptError
}

Expand All @@ -265,20 +253,20 @@ func DownloadFilesPrompt(h *internal.Helper, urls []export.DownloadUrl, path str
skippedCount := 0
for _, f := range m.GetFinishedJobs() {
switch f.GetStatus() {
case ui.Succeeded:
case Succeeded:
succeededCount++
case ui.Failed:
case Failed:
failedCount++
case ui.Skipped:
case Skipped:
skippedCount++
}
}
fmt.Fprint(h.IOStreams.Out, generateDownloadSummary(succeededCount, skippedCount, failedCount))
index := 0
for _, f := range m.GetFinishedJobs() {
if f.GetStatus() != ui.Succeeded {
if f.GetStatus() != Succeeded {
index++
fmt.Fprintf(h.IOStreams.Out, "%d.%s\n", index, f.GetErrorString())
fmt.Fprintf(h.IOStreams.Out, "%d.%s\n", index, f.GetResult())
}
}

Expand All @@ -288,158 +276,15 @@ func DownloadFilesPrompt(h *internal.Helper, urls []export.DownloadUrl, path str
return nil
}

func initialDownloadPathInputModel() ui.TextInputModel {
m := ui.TextInputModel{
Inputs: make([]textinput.Model, len(DownloadPathInputFields)),
}
for k, v := range DownloadPathInputFields {
t := textinput.New()
switch k {
case flag.OutputPath:
t.Placeholder = "Where you want to download the file. Press Enter to skip and download to the current directory"
t.Focus()
t.PromptStyle = config.FocusedStyle
t.TextStyle = config.FocusedStyle
}
m.Inputs[v] = t
}
return m
}

func GetDownloadPathInput() (tea.Model, error) {
p := tea.NewProgram(initialDownloadPathInputModel())
inputModel, err := p.Run()
func DownloadFilesWithoutPrompt(h *internal.Helper, path string,
concurrency int, exportID, clusterID string, count int, client cloud.TiDBCloudClient) error {
exportDownloadPool, err := NewDownloadPool(h, path, concurrency, exportID, clusterID, count, client)
if err != nil {
return nil, errors.Trace(err)
}
if inputModel.(ui.TextInputModel).Interrupted {
return nil, util.InterruptError
}
return inputModel, nil
}

var wg sync.WaitGroup

type downloadJob struct {
url export.DownloadUrl
path string
}

type downloadResult struct {
name string
err error
status ui.JobStatus
}

func (r *downloadResult) GetErrorString() string {
if r.status == ui.Succeeded {
return ""
}
if r.err == nil {
return fmt.Sprintf("%s %s", r.name, r.status)
}
return fmt.Sprintf("%s %s: %s", r.name, r.status, r.err.Error())
}

func DownloadFilesWithoutPrompt(h *internal.Helper, urls []export.DownloadUrl, path string, concurrency int) error {
if concurrency <= 0 {
concurrency = DefaultConcurrency
return errors.Trace(err)
}
// create the path if not exist
err := util.CreateFolder(path)
err = exportDownloadPool.Start()
if err != nil {
return err
}

jobs := make(chan *downloadJob, len(urls))
results := make(chan *downloadResult, len(urls))
// Start consumers:
for i := 0; i < concurrency; i++ {
wg.Add(1)
go consume(h, jobs, results)
}
// Start producing
for _, u := range urls {
jobs <- &downloadJob{url: u, path: path}
}
close(jobs)
wg.Wait()
close(results)

succeededCount := 0
failedCount := 0
skippedCount := 0
downloadResults := make([]*downloadResult, 0)
for result := range results {
switch result.status {
case ui.Succeeded:
succeededCount++
case ui.Failed:
failedCount++
case ui.Skipped:
skippedCount++
}
downloadResults = append(downloadResults, result)
}
fmt.Fprint(h.IOStreams.Out, generateDownloadSummary(succeededCount, skippedCount, failedCount))
index := 0
for _, f := range downloadResults {
if f.status != ui.Succeeded {
index++
fmt.Fprintf(h.IOStreams.Out, "%d.%s\n", index, f.GetErrorString())
}
}
if failedCount > 0 {
return errors.New(fmt.Sprintf("%d file(s) failed to download", failedCount))
return errors.Trace(err)
}
return nil
}

func consume(h *internal.Helper, jobs <-chan *downloadJob, results chan *downloadResult) {
defer wg.Done()
for job := range jobs {
func() {
var err error
defer func() {
if err != nil {
if strings.Contains(err.Error(), "file already exists") {
fmt.Fprintf(h.IOStreams.Out, "download %s skipped: %s\n", *job.url.Name, err.Error())
results <- &downloadResult{name: *job.url.Name, err: err, status: ui.Skipped}
} else {
fmt.Fprintf(h.IOStreams.Out, "download %s failed: %s\n", *job.url.Name, err.Error())
results <- &downloadResult{name: *job.url.Name, err: err, status: ui.Failed}
}
} else {
fmt.Fprintf(h.IOStreams.Out, "download %s succeeded\n", *job.url.Name)
results <- &downloadResult{name: *job.url.Name, err: nil, status: ui.Succeeded}
}
}()

fmt.Fprintf(h.IOStreams.Out, "downloading %s | %s\n", *job.url.Name, humanize.IBytes(uint64(*job.url.Size)))

// request the url
resp, err := util.GetResponse(*job.url.Url, os.Getenv(config.DebugEnv) != "")
if err != nil {
return
}
defer resp.Body.Close()

file, err := util.CreateFile(job.path, *job.url.Name)
if err != nil {
return
}
defer file.Close()
_, err = io.Copy(file, resp.Body)
}()
}
}

func generateDownloadSummary(succeededCount, skippedCount, failedCount int) string {
summaryMessage := fmt.Sprintf("%s %s %s", color.BlueString("download summary:"), color.GreenString("succeeded: %d", succeededCount), color.GreenString("skipped: %d", skippedCount))
if failedCount > 0 {
summaryMessage += color.RedString(" failed: %d", failedCount)
} else {
summaryMessage += fmt.Sprintf(" failed: %d", failedCount)
}
return summaryMessage + "\n"
}
Loading
Loading