Fix stats in BatchDownloadBlobs call #509

Merged: 1 commit, Oct 13, 2023
2 changes: 1 addition & 1 deletion go/pkg/client/capabilities.go
@@ -46,7 +46,7 @@ func (c *Client) CheckCapabilities(ctx context.Context) (err error) {
}
for _, compressor := range c.serverCaps.CacheCapabilities.SupportedBatchUpdateCompressors {
if compressor == repb.Compressor_ZSTD {
-c.batchCompression = true
+c.useBatchCompression = UseBatchCompression(true)
}
}
}
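For context, the hunk above enables batch compression only when the server's advertised capabilities include ZSTD. A minimal sketch of a ServerCapabilities value that would trigger that branch, using the repb field names from this diff (the literal value is illustrative, not taken from the PR):

caps := &repb.ServerCapabilities{
	CacheCapabilities: &repb.CacheCapabilities{
		SupportedBatchUpdateCompressors: []repb.Compressor_Value{repb.Compressor_ZSTD},
	},
}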
58 changes: 40 additions & 18 deletions go/pkg/client/cas_download.go
@@ -186,16 +186,19 @@ func (c *Client) DownloadDirectory(ctx context.Context, d digest.Digest, outDir
// zstdDecoder is a shared instance that should only be used in stateless mode, i.e. only by calling DecodeAll()
var zstdDecoder, _ = zstd.NewReader(nil)

-// BatchDownloadBlobs downloads a number of blobs from the CAS to memory. They must collectively be below the
-// maximum total size for a batch read, which is about 4 MB (see MaxBatchSize). Digests must be
-// computed in advance by the caller. In case multiple errors occur during the blob read, the
-// last error will be returned.
-func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (map[digest.Digest][]byte, error) {
+// CompressedBlobInfo is primarily used to store stats about compressed blob size
+// in addition to the actual blob data.
+type CompressedBlobInfo struct {
+CompressedSize int64
+Data []byte
+}
+
+func (c *Client) BatchDownloadBlobsWithStats(ctx context.Context, dgs []digest.Digest) (map[digest.Digest]CompressedBlobInfo, error) {
if len(dgs) > int(c.MaxBatchDigests) {
return nil, fmt.Errorf("batch read of %d total blobs exceeds maximum of %d", len(dgs), c.MaxBatchDigests)
}
req := &repb.BatchReadBlobsRequest{InstanceName: c.InstanceName}
-if c.batchCompression {
+if c.useBatchCompression {
req.AcceptableCompressors = []repb.Compressor_Value{repb.Compressor_ZSTD}
}
var sz int64
@@ -211,9 +214,9 @@ func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (m
if sz > int64(c.MaxBatchSize) {
return nil, fmt.Errorf("batch read of %d total bytes exceeds maximum of %d", sz, c.MaxBatchSize)
}
-res := make(map[digest.Digest][]byte)
+res := make(map[digest.Digest]CompressedBlobInfo)
if foundEmpty {
-res[digest.Empty] = nil
+res[digest.Empty] = CompressedBlobInfo{}
}
opts := c.RPCOpts()
closure := func() error {
@@ -244,10 +247,12 @@ func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (m
errDg = r.Digest
errMsg = r.Status.Message
} else {
+CompressedSize := len(r.Data)
switch r.Compressor {
case repb.Compressor_IDENTITY:
// do nothing
case repb.Compressor_ZSTD:
+CompressedSize = len(r.Data)
b, err := zstdDecoder.DecodeAll(r.Data, nil)
if err != nil {
errDg = r.Digest
@@ -260,7 +265,11 @@ func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (m
errMsg = fmt.Sprintf("blob returned with unsupported compressor %s", r.Compressor)
continue
}
-res[digest.NewFromProtoUnvalidated(r.Digest)] = r.Data
+bi := CompressedBlobInfo{
+CompressedSize: int64(CompressedSize),
+Data: r.Data,
+}
+res[digest.NewFromProtoUnvalidated(r.Digest)] = bi
}
}
req.Digests = failedDgs
@@ -275,6 +284,19 @@ func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (m
return res, c.Retrier.Do(ctx, closure)
}

+// BatchDownloadBlobs downloads a number of blobs from the CAS to memory. They must collectively be below the
+// maximum total size for a batch read, which is about 4 MB (see MaxBatchSize). Digests must be
+// computed in advance by the caller. In case multiple errors occur during the blob read, the
+// last error will be returned.
+func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (map[digest.Digest][]byte, error) {
+biRes, err := c.BatchDownloadBlobsWithStats(ctx, dgs)
+res := make(map[digest.Digest][]byte)
+for dg, bi := range biRes {
+res[dg] = bi.Data
+}
+return res, err
+}

// ReadBlob fetches a blob from the CAS into a byte slice.
// Returns the size of the blob and the amount of bytes moved through the wire.
func (c *Client) ReadBlob(ctx context.Context, d digest.Digest) ([]byte, *MovedBytesMetadata, error) {
@@ -727,20 +749,20 @@ func (c *Client) download(ctx context.Context, data []*downloadRequest) {

func (c *Client) downloadBatch(ctx context.Context, batch []digest.Digest, reqs map[digest.Digest][]*downloadRequest) {
contextmd.Infof(ctx, log.Level(3), "Downloading batch of %d files", len(batch))
-bchMap, err := c.BatchDownloadBlobs(ctx, batch)
+bchMap, err := c.BatchDownloadBlobsWithStats(ctx, batch)
if err != nil {
afterDownload(batch, reqs, map[digest.Digest]*MovedBytesMetadata{}, err)
return
}
for _, dg := range batch {
+bi := bchMap[dg]
stats := &MovedBytesMetadata{
Requested: dg.Size,
LogicalMoved: dg.Size,
-// There's no compression for batch requests, and there's no such thing as "partial" data for
-// a blob since they're all inlined in the response.
-RealMoved: dg.Size,
+RealMoved: bi.CompressedSize,
}
-data := bchMap[dg]
for i, r := range reqs[dg] {
perm := c.RegularMode
if r.output.IsExecutable {
@@ -750,7 +772,7 @@ func (c *Client) downloadBatch(ctx context.Context, batch []digest.Digest, reqs
// We only report it to the first client to prevent double accounting.
r.wait <- &downloadResponse{
stats: stats,
-err: os.WriteFile(filepath.Join(r.outDir, r.output.Path), data, perm),
+err: os.WriteFile(filepath.Join(r.outDir, r.output.Path), bi.Data, perm),
}
if i == 0 {
// Prevent races by not writing to the original stats.
@@ -859,20 +881,20 @@ func (c *Client) downloadNonUnified(ctx context.Context, outDir string, outputs
}
if len(batch) > 1 {
contextmd.Infof(ctx, log.Level(3), "Downloading batch of %d files", len(batch))
-bchMap, err := c.BatchDownloadBlobs(eCtx, batch)
+bchMap, err := c.BatchDownloadBlobsWithStats(eCtx, batch)
for _, dg := range batch {
-data := bchMap[dg]
+bi := bchMap[dg]
out := outputs[dg]
perm := c.RegularMode
if out.IsExecutable {
perm = c.ExecutableMode
}
-if err := os.WriteFile(filepath.Join(outDir, out.Path), data, perm); err != nil {
+if err := os.WriteFile(filepath.Join(outDir, out.Path), bi.Data, perm); err != nil {
return err
}
statsMu.Lock()
-fullStats.LogicalMoved += int64(len(data))
-fullStats.RealMoved += int64(len(data))
+fullStats.LogicalMoved += int64(len(bi.Data))
+fullStats.RealMoved += bi.CompressedSize
statsMu.Unlock()
}
if err != nil {
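A minimal usage sketch of the new call: reportBatchStats is a hypothetical helper (the name, c, and dgs are placeholders, not from the PR), and its accounting mirrors what downloadNonUnified does above:

func reportBatchStats(ctx context.Context, c *client.Client, dgs []digest.Digest) error {
	infos, err := c.BatchDownloadBlobsWithStats(ctx, dgs)
	if err != nil {
		return err
	}
	var logical, real int64
	for _, bi := range infos {
		logical += int64(len(bi.Data)) // decompressed payload bytes
		real += bi.CompressedSize      // bytes that actually crossed the wire
	}
	log.Printf("batch: %d logical bytes, %d real bytes", logical, real)
	return nil
}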
46 changes: 46 additions & 0 deletions go/pkg/client/cas_test.go
@@ -1867,3 +1867,49 @@
})
}
}

+func TestBatchDownloadBlobsCompressed(t *testing.T) {
+t.Parallel()
+ctx := context.Background()
+listener, err := net.Listen("tcp", ":0")
+if err != nil {
+t.Fatalf("Cannot listen: %v", err)
+}
+fakeCAS := fakes.NewCAS()
+defer listener.Close()
+server := grpc.NewServer()
+repb.RegisterContentAddressableStorageServer(server, fakeCAS)
+go server.Serve(listener)
[CI check failure on line 1882 in go/pkg/client/cas_test.go, GitHub Actions / lint: Error return value of `server.Serve` is not checked (errcheck)]
+defer server.Stop()
+c, err := client.NewClient(ctx, instance, client.DialParams{
+Service: listener.Addr().String(),
+NoSecurity: true,
+}, client.StartupCapabilities(false))
+if err != nil {
+t.Fatalf("Error connecting to server: %v", err)
+}
+defer c.Close()
+
+fooDigest := fakeCAS.Put([]byte("foo"))
+barDigest := fakeCAS.Put([]byte("bar"))
+digests := []digest.Digest{fooDigest, barDigest}
+client.UseBatchCompression(true).Apply(c)
+
+wantBlobs := map[digest.Digest]client.CompressedBlobInfo{
+fooDigest: client.CompressedBlobInfo{
+CompressedSize: 16,
+Data: []byte("foo"),
+},
+barDigest: client.CompressedBlobInfo{
+CompressedSize: 16,
+Data: []byte("bar"),
+},
+}
+gotBlobs, err := c.BatchDownloadBlobsWithStats(ctx, digests)
+if err != nil {
+t.Errorf("client.BatchDownloadBlobsWithStats(ctx, digests) failed: %v", err)
+}
+if diff := cmp.Diff(wantBlobs, gotBlobs); diff != "" {
+t.Errorf("client.BatchDownloadBlobsWithStats(ctx, digests) had diff (want -> got):\n%s", diff)
+}
+}
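The expected CompressedSize of 16 for each 3-byte blob is dominated by zstd framing (magic number, frame and block headers, content checksum) rather than the payload itself. A sketch of how one could verify the figure with github.com/klauspost/compress/zstd, the library behind zstdEncoder, assuming default encoder options:

enc, _ := zstd.NewWriter(nil)
fmt.Println(len(enc.EncodeAll([]byte("foo"), nil))) // expected: 16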
2 changes: 1 addition & 1 deletion go/pkg/client/cas_upload.go
@@ -139,7 +139,7 @@ func (c *Client) BatchWriteBlobs(ctx context.Context, blobs map[digest.Digest][]
Digest: k.ToProto(),
Data: b,
}
-if c.batchCompression && c.shouldCompress(k.Size) {
+if bool(c.useBatchCompression) && c.shouldCompress(k.Size) {
r.Data = zstdEncoder.EncodeAll(r.Data, nil)
r.Compressor = repb.Compressor_ZSTD
sz += int64(len(r.Data))
14 changes: 12 additions & 2 deletions go/pkg/client/client.go
@@ -173,7 +173,8 @@ type Client struct {
// UnifiedDownloadTickDuration specifies how often the unified download daemon flushes the pending requests.
UnifiedDownloadTickDuration UnifiedDownloadTickDuration
// TreeSymlinkOpts controls how symlinks are handled when constructing a tree.
TreeSymlinkOpts *TreeSymlinkOpts
+
serverCaps *repb.ServerCapabilities
useBatchOps UseBatchOps
casConcurrency int64
Expand All @@ -186,7 +187,7 @@ type Client struct {
creds credentials.PerRPCCredentials
uploadOnce sync.Once
downloadOnce sync.Once
-batchCompression bool
+useBatchCompression UseBatchCompression
}

const (
@@ -389,6 +390,15 @@ func (u UseBatchOps) Apply(c *Client) {
c.useBatchOps = u
}

+// UseBatchCompression is currently set to true when the server has
+// SupportedBatchUpdateCompressors capability and supports ZSTD compression.
+type UseBatchCompression bool
+
+// Apply sets the useBatchCompression flag on a client.
+func (u UseBatchCompression) Apply(c *Client) {
+c.useBatchCompression = u
+}

// CASConcurrency is the number of simultaneous requests that will be issued for CAS upload and
// download operations.
type CASConcurrency int
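Like UseBatchOps above, the new option follows the package's option pattern: a typed value with an Apply(*Client) method that can be passed at client construction or applied afterwards, as the test does. A sketch with placeholder dial parameters:

c, err := client.NewClient(ctx, "instance", client.DialParams{
	Service:    "localhost:8980", // placeholder address
	NoSecurity: true,
}, client.StartupCapabilities(false), client.UseBatchCompression(true))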
20 changes: 17 additions & 3 deletions go/pkg/fakes/cas.go
@@ -505,10 +505,24 @@ func (f *CAS) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsReques
f.mu.Lock()
f.reads[dg]++
f.mu.Unlock()

+useZSTDCompression := false
+compressor := repb.Compressor_IDENTITY
+for _, c := range req.AcceptableCompressors {
+if c == repb.Compressor_ZSTD {
+compressor = repb.Compressor_ZSTD
+useZSTDCompression = true
+break
+}
+}
+if useZSTDCompression {
+data = zstdEncoder.EncodeAll(data, nil)
+}
resps = append(resps, &repb.BatchReadBlobsResponse_Response{
Digest: dgPb,
Status: status.New(codes.OK, "").Proto(),
Data: data,
+Compressor: compressor,
})
}
return &repb.BatchReadBlobsResponse{Responses: resps}, nil
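The fake mirrors the real BatchReadBlobs negotiation: the client lists AcceptableCompressors in its request, and the server tags each response with the Compressor it actually applied. A sketch of the request side, reusing field names that appear in this diff (instance name and digest are placeholders):

req := &repb.BatchReadBlobsRequest{
	InstanceName:          "instance",
	Digests:               []*repb.Digest{fooDigest.ToProto()},
	AcceptableCompressors: []repb.Compressor_Value{repb.Compressor_ZSTD},
}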