Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tool.go to create output path if not exist #516

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go/cmd/remotetool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ var (
operation = flag.String("operation", "", fmt.Sprintf("Specifies the operation to perform. Supported values: %v", supportedOps))
digest = flag.String("digest", "", "Digest in <digest/size_bytes> format.")
pathPrefix = flag.String("path", "", "Path to which outputs should be downloaded to.")
overwrite = flag.Bool("overwrite", false, "Overwrite the output path if it already exist.")
actionRoot = flag.String("action_root", "", "For execute_action: the root of the action spec, containing ac.textproto (Action proto), cmd.textproto (Command proto), and input/ (root of the input tree).")
execAttempts = flag.Int("exec_attempts", 10, "For check_determinism: the number of times to remotely execute the action and check for mismatches.")
_ = flag.String("input_root", "", "Deprecated. Use action root instead.")
Expand Down Expand Up @@ -115,7 +116,7 @@ func main() {
os.Stdout.Write([]byte(res))

case downloadAction:
err := c.DownloadAction(ctx, getDigestFlag(), getPathFlag())
err := c.DownloadAction(ctx, getDigestFlag(), getPathFlag(), *overwrite)
if err != nil {
log.Exitf("error fetching action %v: %v", getDigestFlag(), err)
}
Expand Down
33 changes: 32 additions & 1 deletion go/pkg/tool/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package tool

import (
"bufio"
"bytes"
"context"
"fmt"
Expand Down Expand Up @@ -342,7 +343,7 @@ func (c *Client) writeProto(m proto.Message, baseName string) error {
// 4. input_node_properties.txtproto: all the NodeProperties defined on the
// input tree, as an InputSpec proto file in text format. Will be omitted
// if no NodeProperties are defined.
func (c *Client) DownloadAction(ctx context.Context, actionDigest, outputPath string) error {
func (c *Client) DownloadAction(ctx context.Context, actionDigest, outputPath string, overwrite bool) error {
acDg, err := digest.NewFromString(actionDigest)
if err != nil {
return err
Expand All @@ -352,6 +353,36 @@ func (c *Client) DownloadAction(ctx context.Context, actionDigest, outputPath st
if _, err := c.GrpcClient.ReadProto(ctx, acDg, actionProto); err != nil {
return err
}

// Directory already exists, ask the user for confirmation before overwrite it.
if _, err := os.Stat(outputPath); !os.IsNotExist(err) {
fmt.Printf("Directory '%s' already exists. Do you want to overwrite it? (yes/no): ", outputPath)
if !overwrite {
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("error reading user input: %v", err)
}
input = strings.TrimSpace(input)
input = strings.ToLower(input)

if !(input == "yes" || input == "y") {
return errors.Errorf("operation aborted.")
}
}
// If the user confirms, remove the existing directory and create a new one
err = os.RemoveAll(outputPath)
if err != nil {
return fmt.Errorf("error removing existing directory: %v", err)
}
}
// Directory doesn't exist, create it.
err = os.MkdirAll(outputPath, os.ModePerm)
if err != nil {
return fmt.Errorf("error creating the directory: %v", err)
}
log.Infof("Directory created:", outputPath)

if err := c.writeProto(actionProto, filepath.Join(outputPath, "ac.textproto")); err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions go/pkg/tool/tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestTool_DownloadAction(t *testing.T) {
client := &Client{GrpcClient: e.Client.GrpcClient}
tmpDir := filepath.Join(t.TempDir(), "action_root")
os.MkdirAll(tmpDir, os.ModePerm)
if err := client.DownloadAction(context.Background(), acDg.String(), tmpDir); err != nil {
if err := client.DownloadAction(context.Background(), acDg.String(), tmpDir, true); err != nil {
t.Errorf("error DownloadAction: %v", err)
}

Expand Down Expand Up @@ -232,7 +232,7 @@ func TestTool_ExecuteAction(t *testing.T) {
tmpDir := filepath.Join(t.TempDir(), "action_root")
os.MkdirAll(tmpDir, os.ModePerm)
inputRoot := filepath.Join(tmpDir, "input")
if err := client.DownloadAction(context.Background(), acDg.String(), tmpDir); err != nil {
if err := client.DownloadAction(context.Background(), acDg.String(), tmpDir, true); err != nil {
t.Errorf("error DownloadAction: %v", err)
}
if err := os.WriteFile(filepath.Join(inputRoot, "i1"), []byte("i11"), 0644); err != nil {
Expand Down
Loading