From abb47926046442983400ce276c9f9287f9bc4525 Mon Sep 17 00:00:00 2001 From: Ola Rozenfeld Date: Tue, 23 Jan 2024 12:26:08 -0500 Subject: [PATCH] Local disk CAS --- go/pkg/client/BUILD.bazel | 1 + go/pkg/client/cas_download.go | 16 ++ go/pkg/client/client.go | 16 ++ go/pkg/diskcas/BUILD.bazel | 29 ++++ go/pkg/diskcas/atim_darwin.go | 16 ++ go/pkg/diskcas/atim_linux.go | 16 ++ go/pkg/diskcas/atim_windows.go | 18 +++ go/pkg/diskcas/diskcas.go | 281 +++++++++++++++++++++++++++++++++ go/pkg/diskcas/diskcas_test.go | 249 +++++++++++++++++++++++++++++ go/pkg/flags/flags.go | 12 +- 10 files changed, 653 insertions(+), 1 deletion(-) create mode 100644 go/pkg/diskcas/BUILD.bazel create mode 100644 go/pkg/diskcas/atim_darwin.go create mode 100644 go/pkg/diskcas/atim_linux.go create mode 100644 go/pkg/diskcas/atim_windows.go create mode 100644 go/pkg/diskcas/diskcas.go create mode 100644 go/pkg/diskcas/diskcas_test.go diff --git a/go/pkg/client/BUILD.bazel b/go/pkg/client/BUILD.bazel index 3f163c5d7..9d0281cab 100644 --- a/go/pkg/client/BUILD.bazel +++ b/go/pkg/client/BUILD.bazel @@ -25,6 +25,7 @@ go_library( "//go/pkg/command", "//go/pkg/contextmd", "//go/pkg/digest", + "//go/pkg/diskcas", "//go/pkg/filemetadata", "//go/pkg/io/impath", "//go/pkg/io/walker", diff --git a/go/pkg/client/cas_download.go b/go/pkg/client/cas_download.go index fccb8f718..82f6511f0 100644 --- a/go/pkg/client/cas_download.go +++ b/go/pkg/client/cas_download.go @@ -102,6 +102,17 @@ func (c *Client) DownloadOutputs(ctx context.Context, outs map[string]*TreeOutpu symlinks = append(symlinks, out) continue } + if c.diskCas != nil { + absPath := out.Path + if !filepath.IsAbs(absPath) { + absPath = filepath.Join(outDir, absPath) + } + if c.diskCas.Load(out.Digest, absPath) { + fullStats.Requested += out.Digest.Size + fullStats.Cached += out.Digest.Size + continue + } + } if _, ok := downloads[out.Digest]; ok { copies = append(copies, out) // All copies are effectivelly cached @@ -130,6 +141,11 @@ func (c 
*Client) DownloadOutputs(ctx context.Context, outs map[string]*TreeOutpu if err := cache.Update(absPath, md); err != nil { return fullStats, err } + if c.diskCas != nil { + if err := c.diskCas.Store(output.Digest, absPath); err != nil { + return fullStats, err + } + } } for _, out := range copies { perm := c.RegularMode diff --git a/go/pkg/client/client.go b/go/pkg/client/client.go index f3f48401d..ce1c0c6bb 100644 --- a/go/pkg/client/client.go +++ b/go/pkg/client/client.go @@ -19,6 +19,7 @@ import ( "github.com/bazelbuild/remote-apis-sdks/go/pkg/casng" "github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker" "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + "github.com/bazelbuild/remote-apis-sdks/go/pkg/diskcas" "github.com/bazelbuild/remote-apis-sdks/go/pkg/retry" "github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo" "github.com/pkg/errors" @@ -188,6 +189,7 @@ type Client struct { uploadOnce sync.Once downloadOnce sync.Once useBatchCompression UseBatchCompression + diskCas *diskcas.DiskCas } const ( @@ -333,6 +335,20 @@ func (o *TreeSymlinkOpts) Apply(c *Client) { c.TreeSymlinkOpts = o } +type DiskCasOpts struct { + Context context.Context + Path string + MaxCapacityGb float64 +} + +// Apply sets the client's TreeSymlinkOpts. +func (o *DiskCasOpts) Apply(c *Client) { + if o.Path != "" { + capBytes := uint64(o.MaxCapacityGb * 1024 * 1024 * 1024) + c.diskCas = diskcas.New(o.Context, o.Path, capBytes) + } +} + // MaxBatchDigests is maximum amount of digests to batch in upload and download operations. 
type MaxBatchDigests int diff --git a/go/pkg/diskcas/BUILD.bazel b/go/pkg/diskcas/BUILD.bazel new file mode 100644 index 000000000..a8c838769 --- /dev/null +++ b/go/pkg/diskcas/BUILD.bazel @@ -0,0 +1,29 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "diskcas", + srcs = [ + "atim_darwin.go", + "atim_linux.go", + "atim_windows.go", + "diskcas.go", + ], + importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/diskcas", + visibility = ["//visibility:public"], + deps = [ + "//go/pkg/digest", + "@com_github_golang_glog//:go_default_library", + ], +) + +go_test( + name = "diskcas_test", + srcs = ["diskcas_test.go"], + embed = [":diskcas"], + deps = [ + "//go/pkg/digest", + "//go/pkg/testutil", + "@com_github_pborman_uuid//:go_default_library", + "@org_golang_x_sync//errgroup:go_default_library", + ], +) diff --git a/go/pkg/diskcas/atim_darwin.go b/go/pkg/diskcas/atim_darwin.go new file mode 100644 index 000000000..887377c04 --- /dev/null +++ b/go/pkg/diskcas/atim_darwin.go @@ -0,0 +1,16 @@ +// Utility to get the last accessed time on Darwin. +package diskcas + +import ( + "os" + "syscall" + "time" +) + +func GetLastAccessTime(path string) (time.Time, error) { + info, err := os.Stat(path) + if err != nil { + return time.Time{}, err + } + return time.Unix(info.Sys().(*syscall.Stat_t).Atimespec.Unix()), nil +} diff --git a/go/pkg/diskcas/atim_linux.go b/go/pkg/diskcas/atim_linux.go new file mode 100644 index 000000000..c7f0c02f8 --- /dev/null +++ b/go/pkg/diskcas/atim_linux.go @@ -0,0 +1,16 @@ +// Utility to get the last accessed time on Linux. 
+package diskcas + +import ( + "os" + "syscall" + "time" +) + +func GetLastAccessTime(path string) (time.Time, error) { + info, err := os.Stat(path) + if err != nil { + return time.Time{}, err + } + return time.Unix(info.Sys().(*syscall.Stat_t).Atim.Unix()), nil +} diff --git a/go/pkg/diskcas/atim_windows.go b/go/pkg/diskcas/atim_windows.go new file mode 100644 index 000000000..c72d08a12 --- /dev/null +++ b/go/pkg/diskcas/atim_windows.go @@ -0,0 +1,18 @@ +// Utility to get the last accessed time on Windows. +package diskcas + +import ( + "os" + "syscall" + "time" +) + +// This will return correct values only if `fsutil behavior set disablelastaccess 0` is set. +// Tracking of last access time is disabled by default on Windows. +func GetLastAccessTime(path string) (time.Time, error) { + info, err := os.Stat(path) + if err != nil { + return time.Time{}, err + } + return time.Unix(0, info.Sys().(*syscall.Win32FileAttributeData).LastAccessTime.Nanoseconds()), nil +} diff --git a/go/pkg/diskcas/diskcas.go b/go/pkg/diskcas/diskcas.go new file mode 100644 index 000000000..f46b1b701 --- /dev/null +++ b/go/pkg/diskcas/diskcas.go @@ -0,0 +1,281 @@ +// Package diskcas implements a local disk LRU CAS cache. +package diskcas + +import ( + "container/heap" + "context" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + log "github.com/golang/glog" +) + +// An qitem is something we manage in a priority queue. +type qitem struct { + digest digest.Digest + lat time.Time // The last accessed time of the file. + index int // The index of the item in the heap. + mu sync.RWMutex // Protects the data-structure consistency for the given digest. +} + +// A priorityQueue implements heap.Interface and holds qitems. 
+type priorityQueue struct {
+	items []*qitem // backing storage; only items[:n] hold live heap entries.
+	n     int      // number of live entries.
+}
+
+// Len returns the number of items currently in the queue.
+func (q *priorityQueue) Len() int {
+	return q.n
+}
+
+// Less reports whether item i was accessed before item j.
+func (q *priorityQueue) Less(i, j int) bool {
+	// We want Pop to give us the oldest item.
+	return q.items[i].lat.Before(q.items[j].lat)
+}
+
+// Swap exchanges items i and j and updates the index each item records for itself.
+// It uses a pointer receiver like every other method of the type.
+func (q *priorityQueue) Swap(i, j int) {
+	q.items[i], q.items[j] = q.items[j], q.items[i]
+	q.items[i].index = i
+	q.items[j].index = j
+}
+
+// Push stores x, which must be a *qitem, right after the last live entry. It is meant to
+// be called through heap.Push, which then restores the heap order.
+func (q *priorityQueue) Push(x any) {
+	if q.n == len(q.items) {
+		// Double the backing storage. The floor of one slot keeps a queue that started
+		// with no storage usable instead of indexing past an empty slice.
+		newLen := 2 * len(q.items)
+		if newLen == 0 {
+			newLen = 1
+		}
+		grown := make([]*qitem, newLen)
+		copy(grown, q.items)
+		q.items = grown
+	}
+	item := x.(*qitem)
+	item.index = q.n
+	q.items[item.index] = item
+	q.n++
+}
+
+// Pop removes and returns the last live entry. It is meant to be called through heap.Pop,
+// which first moves the oldest item to that position.
+func (q *priorityQueue) Pop() any {
+	item := q.items[q.n-1]
+	q.items[q.n-1] = nil // avoid memory leak
+	item.index = -1      // for safety
+	q.n--
+	return item
+}
+
+// Bump records that item was accessed just now, which makes it the newest entry and
+// therefore the last one Pop will return. It does nothing if item is no longer queued.
+func (q *priorityQueue) Bump(item *qitem) {
+	// Sanity check, necessary because of possible racing between Bump and GC:
+	if item.index < 0 || item.index >= q.n || q.items[item.index].digest != item.digest {
+		return
+	}
+	item.lat = time.Now()
+	heap.Fix(q, item.index)
+}
+
+// maxConcurrentRequests is the buffer size of the GC request channel.
+const maxConcurrentRequests = 1000
+
+// DiskCas is a local disk LRU CAS cache.
+type DiskCas struct {
+	root             string          // path to the root directory of the disk cache.
+	maxCapacityBytes uint64          // if disk size exceeds this, old items will be evicted as needed.
+	mu               sync.Mutex      // protects the queue.
+	store            sync.Map        // map of digests to qitems.
+	queue            *priorityQueue  // digests by last accessed time.
+	sizeBytes        int64           // total size; accessed atomically.
+	ctx              context.Context // the GC goroutine exits when this is done.
+	shutdown         chan bool       // tells the GC goroutine to exit.
+	gcTick           uint64          // number of GC requests issued; accessed atomically.
+	gcReq            chan uint64     // pending GC requests, each carrying its tick.
+	testGcTicks      chan uint64     // if non-nil, receives the tick of every finished GC pass.
+}
+
+// New returns a DiskCas that keeps its files directly under root and evicts the least
+// recently accessed ones once their total size exceeds maxCapacityBytes.
+//
+// The directory is created if needed, and files already in it whose names have the form
+// <hash>.<size> are registered using their last access time. Problems while reading the
+// directory are logged and the affected entries skipped, because cache read errors are not
+// critical. New also starts the GC goroutine, which runs until Shutdown is called or ctx is
+// done.
+func New(ctx context.Context, root string, maxCapacityBytes uint64) *DiskCas {
+	res := &DiskCas{
+		root:             root,
+		maxCapacityBytes: maxCapacityBytes,
+		ctx:              ctx,
+		queue: &priorityQueue{
+			items: make([]*qitem, 1000),
+		},
+		gcReq:    make(chan uint64, maxConcurrentRequests),
+		shutdown: make(chan bool),
+	}
+	heap.Init(res.queue)
+	if err := os.MkdirAll(root, os.ModePerm); err != nil {
+		log.Errorf("Error creating cache directory %s: %v", root, err)
+	}
+	// The callback never returns a real error, so neither does WalkDir.
+	filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
+		// We log and continue on all errors, because cache read errors are not critical.
+		if err != nil {
+			log.Errorf("Error reading cache directory: %v", err)
+			return nil
+		}
+		if d.IsDir() {
+			if path != root {
+				// Cache files are only ever addressed as <root>/<hash>.<size>, so
+				// anything inside a subdirectory could be neither loaded nor evicted.
+				return fs.SkipDir
+			}
+			return nil
+		}
+		fname := d.Name()
+		pair := strings.Split(fname, ".")
+		if len(pair) != 2 {
+			log.Errorf("Expected file name in the form hash.size, got %s", fname)
+			return nil
+		}
+		size, err := strconv.ParseInt(pair[1], 10, 64)
+		if err != nil {
+			log.Errorf("invalid size in digest %s: %s", fname, err)
+			return nil
+		}
+		dg, err := digest.New(pair[0], size)
+		if err != nil {
+			log.Errorf("invalid digest from file name %s: %v", fname, err)
+			return nil
+		}
+		atime, err := GetLastAccessTime(path)
+		if err != nil {
+			log.Errorf("Error getting last accessed time of %s: %v", path, err)
+			return nil
+		}
+		it := &qitem{
+			digest: dg,
+			lat:    atime,
+		}
+		res.store.Store(dg, it)
+		atomic.AddInt64(&res.sizeBytes, dg.Size)
+		heap.Push(res.queue, it)
+		return nil
+	})
+	go res.gc()
+	return res
+}
+
+// Shutdown releases resources and terminates the GC daemon. Should be the last call to the DiskCas.
+func (d *DiskCas) Shutdown() {
+	// The GC goroutine also exits when the context is done. A bare send would then block
+	// forever because nobody is left to receive it, so stop waiting in that case.
+	select {
+	case d.shutdown <- true:
+	case <-d.ctx.Done():
+	}
+}
+
+// TotalSizeBytes returns the total size of the blobs currently accounted for in the cache.
+func (d *DiskCas) TotalSizeBytes() uint64 {
+	return uint64(atomic.LoadInt64(&d.sizeBytes))
+}
+
+// getPath returns the cache file location for dg: <root>/<hash>.<size>.
+func (d *DiskCas) getPath(dg digest.Digest) string {
+	return filepath.Join(d.root, fmt.Sprintf("%s.%d", dg.Hash, dg.Size))
+}
+
+// Store copies the file at path into the cache under the digest dg.
+//
+// It is a no-op if the digest is already registered. It returns an error if the blob is
+// larger than the cache capacity or if the copy fails; in the latter case the digest is
+// unregistered again so that a later Store can retry. When the new total size exceeds the
+// capacity, a GC pass is requested.
+func (d *DiskCas) Store(dg digest.Digest, path string) error {
+	if dg.Size > int64(d.maxCapacityBytes) {
+		return fmt.Errorf("blob size %d exceeds DiskCas capacity %d", dg.Size, d.maxCapacityBytes)
+	}
+	it := &qitem{
+		digest: dg,
+		lat:    time.Now(),
+	}
+	// Hold the item lock for the whole copy so that Load and the GC wait for the file.
+	it.mu.Lock()
+	defer it.mu.Unlock()
+	if _, exists := d.store.LoadOrStore(dg, it); exists {
+		return nil
+	}
+	d.mu.Lock()
+	heap.Push(d.queue, it)
+	d.mu.Unlock()
+	if err := copyFile(path, d.getPath(dg), dg.Size); err != nil {
+		d.abortStore(it)
+		return err
+	}
+	newSize := uint64(atomic.AddInt64(&d.sizeBytes, dg.Size))
+	if newSize > d.maxCapacityBytes {
+		// Non-blocking: if the request buffer is full, a GC pass is already pending.
+		select {
+		case d.gcReq <- atomic.AddUint64(&d.gcTick, 1):
+		default:
+		}
+	}
+	return nil
+}
+
+// abortStore undoes the bookkeeping of a Store whose copy failed, so that the digest does
+// not stay registered without a usable file and without its size having been counted.
+// The caller must hold it.mu.
+func (d *DiskCas) abortStore(it *qitem) {
+	d.mu.Lock()
+	queued := it.index >= 0 && it.index < d.queue.n && d.queue.items[it.index] == it
+	if queued {
+		heap.Remove(d.queue, it.index)
+	}
+	d.mu.Unlock()
+	if !queued {
+		// The GC popped the item while the copy was in flight. It subtracts the item's
+		// size, which was never added, so add it here to keep the total balanced. The GC
+		// removes the file and the map entry itself once it obtains it.mu.
+		atomic.AddInt64(&d.sizeBytes, it.digest.Size)
+		return
+	}
+	// Best effort: the partial file may not exist. Remove it before unregistering the
+	// digest so that a Store starting afterwards cannot have its new file deleted here.
+	os.Remove(d.getPath(it.digest))
+	d.store.Delete(it.digest)
+}
+
+// gc is the body of the GC goroutine. For every request it evicts the least recently
+// accessed items until the total size is within capacity. It returns when Shutdown is
+// called or the context is done.
+func (d *DiskCas) gc() {
+	for {
+		select {
+		case <-d.shutdown:
+			return
+		case <-d.ctx.Done():
+			return
+		case t := <-d.gcReq:
+			// Evict old entries until total size is below cap.
+			for uint64(atomic.LoadInt64(&d.sizeBytes)) > d.maxCapacityBytes {
+				d.mu.Lock()
+				if d.queue.Len() == 0 {
+					// Nothing left to evict; popping an empty heap would panic.
+					d.mu.Unlock()
+					break
+				}
+				it := heap.Pop(d.queue).(*qitem)
+				d.mu.Unlock()
+				atomic.AddInt64(&d.sizeBytes, -it.digest.Size)
+				// Wait for any Store or Load still using the file.
+				it.mu.Lock()
+				os.Remove(d.getPath(it.digest))
+				d.store.Delete(it.digest)
+				it.mu.Unlock()
+			}
+			if d.testGcTicks != nil {
+				select {
+				case d.testGcTicks <- t:
+				default:
+				}
+			}
+		}
+	}
+}
+
+// Copy file contents retaining the source permissions.
+func copyFile(src, dst string, size int64) error {
+	in, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer in.Close()
+	// Stat the open handle so that the mode belongs to the file actually being copied.
+	srcInfo, err := in.Stat()
+	if err != nil {
+		return err
+	}
+	out, err := os.Create(dst)
+	if err != nil {
+		return err
+	}
+	err = out.Chmod(srcInfo.Mode())
+	if err == nil {
+		_, err = io.Copy(out, in)
+	}
+	// Close on every path, including a failed Chmod, and report a Close error because a
+	// failed write may only surface there.
+	if cerr := out.Close(); err == nil {
+		err = cerr
+	}
+	if err == nil {
+		// Required sanity check: sometimes the copy pretends to succeed, but doesn't, if
+		// the file is being concurrently deleted.
+		var dstInfo os.FileInfo
+		dstInfo, err = os.Stat(dst)
+		if err == nil && dstInfo.Size() != size {
+			err = fmt.Errorf("copy of %s to %s failed: src/dst size mismatch: wanted %d, got %d", src, dst, size, dstInfo.Size())
+		}
+	}
+	if err != nil {
+		// Best effort: do not leave a partial or mismatched destination behind. The
+		// removal error is ignored because the copy error is the one worth reporting.
+		os.Remove(dst)
+		return err
+	}
+	return nil
+}
+
+// Load copies the cached contents of dg to path and marks the digest as just accessed.
+// It returns false if the digest is not in the cache or the copy fails.
+func (d *DiskCas) Load(dg digest.Digest, path string) bool {
+	iUntyped, loaded := d.store.Load(dg)
+	if !loaded {
+		return false
+	}
+	it := iUntyped.(*qitem)
+	it.mu.RLock()
+	err := copyFile(d.getPath(dg), path, dg.Size)
+	it.mu.RUnlock()
+	if err != nil {
+		// It is not possible to prevent a race with GC; hence, we return false on copy errors.
+		return false
+	}
+
+	d.mu.Lock()
+	d.queue.Bump(it)
+	d.mu.Unlock()
+	return true
+}
diff --git a/go/pkg/diskcas/diskcas_test.go b/go/pkg/diskcas/diskcas_test.go
new file mode 100644
index 000000000..21b7604ec
--- /dev/null
+++ b/go/pkg/diskcas/diskcas_test.go
@@ -0,0 +1,249 @@
+package diskcas
+
+import (
+	"context"
+	"fmt"
+	"math/rand"
+	"os"
+	"path/filepath"
+	"sync/atomic"
+	"testing"
+
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/testutil"
+	"github.com/pborman/uuid"
+	"golang.org/x/sync/errgroup"
+)
+
+// Test utility only. Assumes all modifications are done, and at least one GC is expected.
+func waitForGc(d *DiskCas) {
+	for t := range d.testGcTicks {
+		// gcTick is updated with atomic.AddUint64 in Store, so it must be read atomically too.
+		if t == atomic.LoadUint64(&d.gcTick) {
+			return
+		}
+	}
+}
+
+func TestStoreLoadPerm(t *testing.T) {
+	tests := []struct {
+		name       string
+		executable bool
+	}{
+		{
+			name:       "+X",
+			executable: true,
+		},
+		{
+			name:       "-X",
+			executable: false,
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			root := t.TempDir()
+			d := New(context.Background(), filepath.Join(root, "cache"), 20)
+			defer d.Shutdown()
+			fname, _ := testutil.CreateFile(t, tc.executable, "12345")
+			srcInfo, err := os.Stat(fname)
+			if err != nil {
+				t.Fatalf("os.Stat() failed: %v", err)
+			}
+			dg, err := digest.NewFromFile(fname)
+			if err != nil {
+				t.Fatalf("digest.NewFromFile failed: %v", err)
+			}
+			if err := d.Store(dg, fname); err != nil {
+				t.Errorf("Store(%s, %s) failed: %v", dg, fname, err)
+			}
+			newName := filepath.Join(root, "new")
+			if !d.Load(dg, newName) {
+				t.Errorf("expected to load %s from the cache to %s", dg, newName)
+			}
+			fileInfo, err := os.Stat(newName)
+			if err != nil {
+				t.Fatalf("os.Stat(%s) failed: %v", newName, err)
+			}
+			if fileInfo.Mode() != srcInfo.Mode() {
+				t.Errorf("expected %s to have %v permissions, got: %v", newName, srcInfo.Mode(), fileInfo.Mode())
+			}
+			contents, err := os.ReadFile(newName)
+			if err != nil {
+				t.Errorf("error reading from %s: %v", newName, err)
+			}
+			if string(contents) != "12345" {
+				t.Errorf("Cached result did not match: want %q, got %q", "12345", string(contents))
+			}
+		})
+	}
+}
+
+func TestLoadNotFound(t *testing.T) {
+	root := t.TempDir()
+	d := New(context.Background(), filepath.Join(root, "cache"), 20)
+	defer d.Shutdown()
+	newName := filepath.Join(root, "new")
+	dg := digest.NewFromBlob([]byte("bla"))
+	if d.Load(dg, newName) {
+		t.Errorf("expected to not load %s from the cache to %s", dg, newName)
+	}
+}
+
+func TestGcOldest(t *testing.T) {
+	root := t.TempDir()
+	d := New(context.Background(), filepath.Join(root, "cache"), 20)
+	defer d.Shutdown()
+	d.testGcTicks = make(chan uint64, 1)
+	for i := 0; i < 5; i++ {
+		fname, _ := testutil.CreateFile(t, false, fmt.Sprintf("aaa %d", i))
+		dg, err := digest.NewFromFile(fname)
+		if err != nil {
+			t.Fatalf("digest.NewFromFile failed: %v", err)
+		}
+		if err := d.Store(dg, fname); err != nil {
+			t.Errorf("Store(%s, %s) failed: %v", dg, fname, err)
+		}
+	}
+	waitForGc(d)
+	if d.TotalSizeBytes() != d.maxCapacityBytes {
+		t.Errorf("expected total size bytes to be %d, got %d", d.maxCapacityBytes, d.TotalSizeBytes())
+	}
+	newName := filepath.Join(root, "new")
+	for i := 0; i < 5; i++ {
+		dg := digest.NewFromBlob([]byte(fmt.Sprintf("aaa %d", i)))
+		if d.Load(dg, newName) != (i > 0) {
+			t.Errorf("expected loaded to be %v for %s from the cache to %s", i > 0, dg, newName)
+		}
+	}
+}
+
+// We say that Last Access Time is behaving accurately on a system if reading from the file
+// bumps the LAT time forward. From experience, Mac and Linux Debian are accurate. Ubuntu -- not.
+// From experience, even when the LAT gets modified on access on Ubuntu, it can be imprecise to
+// an order of seconds (!).
+func isSystemLastAccessTimeAccurate(t *testing.T) bool {
+	fname, _ := testutil.CreateFile(t, false, "foo")
+	lat, _ := GetLastAccessTime(fname)
+	os.ReadFile(fname)
+	newLat, _ := GetLastAccessTime(fname)
+	return lat.Before(newLat)
+}
+
+func TestInitFromExisting(t *testing.T) {
+	if !isSystemLastAccessTimeAccurate(t) {
+		// This effectively skips the test on Ubuntu, because to make the test work there,
+		// we would need to inject too many / too long time.Sleep statements to beat the system's
+		// inaccuracy. t.Skip makes the run report SKIP rather than a misleading PASS.
+		t.Skip("Skipping TestInitFromExisting, because system Last Access Time is unreliable.")
+	}
+	root := t.TempDir()
+	d := New(context.Background(), filepath.Join(root, "cache"), 20)
+	for i := 0; i < 4; i++ {
+		fname, _ := testutil.CreateFile(t, false, fmt.Sprintf("aaa %d", i))
+		dg, err := digest.NewFromFile(fname)
+		if err != nil {
+			t.Fatalf("digest.NewFromFile failed: %v", err)
+		}
+		if err := d.Store(dg, fname); err != nil {
+			t.Errorf("Store(%s, %s) failed: %v", dg, fname, err)
+		}
+	}
+	newName := filepath.Join(root, "new")
+	dg := digest.NewFromBlob([]byte("aaa 0"))
+	if !d.Load(dg, newName) { // Now 0 has been accessed, 1 is the oldest file.
+		t.Errorf("expected %s to be cached", dg)
+	}
+	d.Shutdown()
+
+	// Re-initialize from existing files.
+	d = New(context.Background(), filepath.Join(root, "cache"), 20)
+	defer d.Shutdown()
+	d.testGcTicks = make(chan uint64, 1)
+
+	// Check old files are cached:
+	dg = digest.NewFromBlob([]byte("aaa 1"))
+	if !d.Load(dg, newName) { // Now 1 has been accessed, 2 is the oldest file.
+		t.Errorf("expected %s to be cached", dg)
+	}
+	fname, _ := testutil.CreateFile(t, false, "aaa 4")
+	dg, err := digest.NewFromFile(fname)
+	if err != nil {
+		t.Fatalf("digest.NewFromFile failed: %v", err)
+	}
+	if d.TotalSizeBytes() != d.maxCapacityBytes {
+		t.Errorf("expected total size bytes to be %d, got %d", d.maxCapacityBytes, d.TotalSizeBytes())
+	}
+	// Trigger a GC by adding a new file.
+	if err := d.Store(dg, fname); err != nil {
+		t.Errorf("Store(%s, %s) failed: %v", dg, fname, err)
+	}
+	waitForGc(d)
+	dg = digest.NewFromBlob([]byte("aaa 2"))
+	if d.Load(dg, newName) {
+		t.Errorf("expected to not load %s from the cache to %s", dg, newName)
+	}
+}
+
+func TestThreadSafety(t *testing.T) {
+	root := t.TempDir()
+	os.MkdirAll(filepath.Join(root, "orig"), os.ModePerm)
+	os.MkdirAll(filepath.Join(root, "new"), os.ModePerm)
+	nFiles := 10
+	attempts := 5000
+	// All blobs are size 5 exactly. We will have half the byte capacity we need.
+	d := New(context.Background(), filepath.Join(root, "cache"), uint64(nFiles*5)/2)
+	d.testGcTicks = make(chan uint64, attempts)
+	defer d.Shutdown()
+	var files []string
+	var dgs []digest.Digest
+	for i := 0; i < nFiles; i++ {
+		fname := filepath.Join(root, "orig", fmt.Sprintf("%d", i))
+		if err := os.WriteFile(fname, []byte(fmt.Sprintf("aa %02d", i)), 0644); err != nil {
+			t.Fatalf("os.WriteFile: %v", err)
+		}
+		files = append(files, fname)
+		dg, err := digest.NewFromFile(fname)
+		if err != nil {
+			t.Fatalf("digest.NewFromFile failed: %v", err)
+		}
+		dgs = append(dgs, dg)
+		if err := d.Store(dg, fname); err != nil {
+			t.Errorf("Store(%s, %s) failed: %v", dg, fname, err)
+		}
+	}
+	// Randomly access and store files from different threads. The goroutines share only
+	// the cache, the read-only files/dgs slices and the atomic hit counter.
+	eg, _ := errgroup.WithContext(context.Background())
+	var hits atomic.Uint64
+	for k := 0; k < attempts; k++ {
+		eg.Go(func() error {
+			i := rand.Intn(nFiles)
+			newName := filepath.Join(root, "new", uuid.New())
+			if d.Load(dgs[i], newName) {
+				hits.Add(1)
+				contents, err := os.ReadFile(newName)
+				if err != nil {
+					return fmt.Errorf("os.ReadFile: %v", err)
+				}
+				want := fmt.Sprintf("aa %02d", i)
+				if string(contents) != want {
+					return fmt.Errorf("Cached result did not match: want %q, got %q for digest %v", want, string(contents), dgs[i])
+				}
+			} else if err := d.Store(dgs[i], files[i]); err != nil {
+				return fmt.Errorf("Store: %v", err)
+			}
+			return nil
+		})
+	}
+	if err := eg.Wait(); err != nil {
+		t.Error(err)
+	}
+	waitForGc(d)
+	if d.TotalSizeBytes() != d.maxCapacityBytes {
+		t.Errorf("expected total size bytes to be %d, got %d", d.maxCapacityBytes, d.TotalSizeBytes())
+	}
+	if int(hits.Load()) < attempts/2 {
+		t.Errorf("Unexpectedly low cache hits %d out of %d attempts", hits.Load(), attempts)
+	}
+}
diff --git a/go/pkg/flags/flags.go b/go/pkg/flags/flags.go
index d7acc01b3..c28b37b26 100644
--- a/go/pkg/flags/flags.go
+++ 
b/go/pkg/flags/flags.go @@ -74,6 +74,8 @@ var ( KeepAliveTimeout = flag.Duration("grpc_keepalive_timeout", 20*time.Second, "After having pinged for keepalive check, the client waits for a duration of Timeout and if no activity is seen even after that the connection is closed. Default is 20s.") // KeepAlivePermitWithoutStream specifies gRPCs keepalive permitWithoutStream parameter. KeepAlivePermitWithoutStream = flag.Bool("grpc_keepalive_permit_without_stream", false, "If true, client sends keepalive pings even with no active RPCs; otherwise, doesn't send pings even if time and timeout are set. Default is false.") + DiskCasPath = flag.String("disk_cas_path", "", "If set, will use a local disk cache for downloaded outputs.") + DiskCasCapacityGb = flag.Float64("disk_cas_max_gb", 1.0, "Maximum GB to store in the local disk cache. A Noop if --disk_cas_path is not set.") ) func init() { @@ -89,7 +91,15 @@ func init() { // NewClientFromFlags connects to a remote execution service and returns a client suitable for higher-level // functionality. It uses the flags from above to configure the connection to remote execution. func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client, error) { - opts = append(opts, []client.Opt{client.CASConcurrency(*CASConcurrency), client.StartupCapabilities(*StartupCapabilities)}...) + opts = append(opts, []client.Opt{ + client.CASConcurrency(*CASConcurrency), + client.StartupCapabilities(*StartupCapabilities), + &client.DiskCasOpts{ + Context: ctx, + Path: *DiskCasPath, + MaxCapacityGb: *DiskCasCapacityGb, + }, + }...) if len(RPCTimeouts) > 0 { timeouts := make(map[string]time.Duration) for rpc, d := range client.DefaultRPCTimeouts {