From bf1864171f7d48cd8dba30a271be4544029ff3af Mon Sep 17 00:00:00 2001 From: Eddie Torres Date: Tue, 31 Oct 2023 21:48:51 +0000 Subject: [PATCH] Batch DescribeVolume API requests Signed-off-by: Eddie Torres --- pkg/cloud/cloud.go | 193 +++++++++++++++++++-- pkg/cloud/cloud_test.go | 268 +++++++++++++++++++++++++++--- pkg/driver/controller.go | 10 +- pkg/driver/controller_test.go | 8 +- tests/e2e/dynamic_provisioning.go | 2 +- tests/e2e/pre_provsioning.go | 2 +- 6 files changed, 433 insertions(+), 50 deletions(-) diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 9dbd0ef9a..d8d688d0f 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -33,6 +33,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/batcher" dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" "k8s.io/apimachinery/pkg/util/wait" @@ -137,6 +138,12 @@ const ( AwsEbsDriverTagKey = "ebs.csi.aws.com/cluster" ) +// Batcher +const ( + volumeIDBatcher batcherType = iota + volumeTagBatcher +) + var ( // ErrMultiDisks is an error that is returned when multiple // disks are found with the same volume name. @@ -234,21 +241,41 @@ type ec2ListSnapshotsResponse struct { NextToken *string } +// batcherType is an enum representing the types of batchers available. +type batcherType int + +// batcherManager maintains a collection of batchers for different types of tasks. +type batcherManager struct { + batchers map[batcherType]*batcher.Batcher[string, *ec2.Volume] +} + type cloud struct { region string ec2 ec2iface.EC2API dm dm.DeviceManager + bm *batcherManager } var _ Cloud = &cloud{} // NewCloud returns a new instance of AWS cloud // It panics if session is invalid -func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string) (Cloud, error) { - return newEC2Cloud(region, awsSdkDebugLog, userAgentExtra) +func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (Cloud, error) { + c := newEC2Cloud(region, awsSdkDebugLog, userAgentExtra) + + if batching { + klog.V(4).InfoS("NewCloud: batching enabled") + cloudInstance, ok := c.(*cloud) + if !ok { + return nil, fmt.Errorf("expected *cloud type but got %T", c) + } + cloudInstance.bm = newBatcherManager(cloudInstance.ec2) + } + + return c, nil } -func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) (Cloud, error) { +func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) Cloud { awsConfig := &aws.Config{ Region: aws.String(region), CredentialsChainVerboseErrors: aws.Bool(true), @@ -295,7 +322,135 @@ func newEC2Cloud(region string, awsSdkDebugLog bool, userAgentExtra string) (Clo region: region, dm: dm.NewDeviceManager(), ec2: svc, - }, nil + } +} + +// newBatcherManager initializes a new instance of batcherManager. +func newBatcherManager(svc ec2iface.EC2API) *batcherManager { + return &batcherManager{ + batchers: map[batcherType]*batcher.Batcher[string, *ec2.Volume]{ + volumeIDBatcher: batcher.New(500, 1*time.Second, func(ids []string) (map[string]*ec2.Volume, error) { + return execBatchDescribeVolumes(svc, ids, volumeIDBatcher) + }), + volumeTagBatcher: batcher.New(500, 1*time.Second, func(names []string) (map[string]*ec2.Volume, error) { + return execBatchDescribeVolumes(svc, names, volumeTagBatcher) + }), + }, + } +} + +// getBatcher fetches a specific type of batcher from the batcherManager. +func (bm *batcherManager) getBatcher(b batcherType) *batcher.Batcher[string, *ec2.Volume] { + return bm.batchers[b] +} + +// executes a batched DescribeVolumes API call depending on the type of batcher. +func execBatchDescribeVolumes(svc ec2iface.EC2API, input []string, batcher batcherType) (map[string]*ec2.Volume, error) { + var request *ec2.DescribeVolumesInput + + switch batcher { + case volumeIDBatcher: + klog.V(7).InfoS("execBatchDescribeVolumes", "volumeIds", input) + request = &ec2.DescribeVolumesInput{ + VolumeIds: aws.StringSlice(input), + } + + case volumeTagBatcher: + klog.V(7).InfoS("execBatchDescribeVolumes", "names", input) + filters := []*ec2.Filter{ + { + Name: aws.String("tag:" + VolumeNameTagKey), + Values: aws.StringSlice(input), + }, + } + request = &ec2.DescribeVolumesInput{ + Filters: filters, + } + + default: + return nil, fmt.Errorf("execBatchDescribeVolumes: unsupported request type") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + resp, err := describeVolumes(ctx, svc, request) + if err != nil { + return nil, err + } + + result := make(map[string]*ec2.Volume) + + for _, volume := range resp { + key, err := extractVolumeKey(volume, batcher) + if err != nil { + klog.Warningf("execBatchDescribeVolumes: skipping volume: %v, reason: %v", volume, err) + continue + } + result[key] = volume + } + + klog.V(7).InfoS("execBatchDescribeVolumes: success", "result", result) + return result, nil +} + +// batchDescribeVolumes processes a DescribeVolumes request. Depending on the request, +// it determines the appropriate batcher to use, queues the task, and waits for the result. +func (c *cloud) batchDescribeVolumes(request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { + var bType batcherType + var task string + + switch { + case len(request.VolumeIds) == 1 && request.VolumeIds[0] != nil: + bType = volumeIDBatcher + task = *request.VolumeIds[0] + + case len(request.Filters) == 1 && *request.Filters[0].Name == "tag:"+VolumeNameTagKey && len(request.Filters[0].Values) == 1: + bType = volumeTagBatcher + task = *request.Filters[0].Values[0] + + default: + return nil, fmt.Errorf("batchDescribeVolumes: invalid request, request: %v", request) + } + + ch := make(chan batcher.BatchResult[*ec2.Volume]) + + b := c.bm.getBatcher(bType) + b.AddTask(task, ch) + + r := <-ch + + if r.Err != nil { + return nil, r.Err + } + if r.Result == nil { + return nil, fmt.Errorf("batchDescribeVolumes: no volume found %s", task) + } + return r.Result, nil +} + +// extractVolumeKey retrieves the key associated with a given volume based on the batcher type. +// For the volumeIDBatcher type, it returns the volume's ID. +// For other types, it searches for the VolumeNameTagKey within the volume's tags. +func extractVolumeKey(v *ec2.Volume, batcher batcherType) (string, error) { + if batcher == volumeIDBatcher { + if v.VolumeId == nil { + return "", errors.New("extractVolumeKey: missing volume ID") + } + return *v.VolumeId, nil + } + for _, tag := range v.Tags { + klog.V(7).InfoS("extractVolumeKey: processing tag", "volume", v, "*tag.Key", *tag.Key, "VolumeNameTagKey", VolumeNameTagKey) + if tag.Key == nil || tag.Value == nil { + klog.V(7).InfoS("extractVolumeKey: skipping volume due to missing tag", "volume", v, "tag", tag) + continue + } + if *tag.Key == VolumeNameTagKey { + klog.V(7).InfoS("extractVolumeKey: found volume name tag", "volume", v, "tag", tag) + return *tag.Value, nil + } + } + return "", errors.New("extractVolumeKey: missing VolumeNameTagKey in volume tags") } func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions *DiskOptions) (*Disk, error) { @@ -463,6 +618,7 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions * return nil, fmt.Errorf("could not attach tags to volume: %v. %w", volumeID, err) } } + klog.InfoS("volume created", "volumeID", volumeID) return &Disk{CapacityGiB: size, VolumeID: volumeID, AvailabilityZone: zone, SnapshotID: snapshotID, OutpostArn: outpostArn}, nil } @@ -729,7 +885,7 @@ func (c *cloud) WaitForAttachmentState(ctx context.Context, volumeID, expectedSt return true, nil } // continue waiting - klog.V(4).InfoS("Waiting for volume state", "volumeID", volumeID, "actual", attachmentState, "desired", expectedState) + klog.InfoS("Waiting for volume state", "volumeID", volumeID, "actual", attachmentState, "desired", expectedState) return false, nil } @@ -954,11 +1110,11 @@ func (c *cloud) EnableFastSnapshotRestores(ctx context.Context, availabilityZone return response, nil } -func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { +func describeVolumes(ctx context.Context, svc ec2iface.EC2API, request *ec2.DescribeVolumesInput) ([]*ec2.Volume, error) { var volumes []*ec2.Volume var nextToken *string for { - response, err := c.ec2.DescribeVolumesWithContext(ctx, request) + response, err := svc.DescribeVolumesWithContext(ctx, request) if err != nil { return nil, err } @@ -969,14 +1125,25 @@ func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput } request.NextToken = nextToken } + return volumes, nil +} - if l := len(volumes); l > 1 { - return nil, ErrMultiDisks - } else if l < 1 { - return nil, ErrNotFound - } +func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { + if c.bm == nil { + volumes, err := describeVolumes(ctx, c.ec2, request) + if err != nil { + return nil, err + } - return volumes[0], nil + if l := len(volumes); l > 1 { + return nil, ErrMultiDisks + } else if l < 1 { + return nil, ErrNotFound + } + return volumes[0], nil + } else { + return c.batchDescribeVolumes(request) + } } func (c *cloud) getInstance(ctx context.Context, nodeID string) (*ec2.Instance, error) { diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 8e4b89f7b..2e7fb0f30 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -23,6 +23,7 @@ import ( "reflect" "sort" "strings" + "sync" "testing" "github.com/aws/aws-sdk-go/aws" @@ -44,6 +45,208 @@ const ( defaultPath = "/dev/xvdaa" ) +func generateVolumes(volIdCount, volTagCount int) []*ec2.Volume { + volumes := make([]*ec2.Volume, 0, volIdCount+volTagCount) + + for i := 0; i < volIdCount; i++ { + volumeID := fmt.Sprintf("vol-%d", i) + volumes = append(volumes, &ec2.Volume{VolumeId: aws.String(volumeID)}) + } + + for i := 0; i < volTagCount; i++ { + volumeName := fmt.Sprintf("vol-name-%d", i) + volumes = append(volumes, &ec2.Volume{Tags: []*ec2.Tag{{Key: aws.String(VolumeNameTagKey), Value: aws.String(volumeName)}}}) + } + + return volumes +} + +func extractVolumeIdentifiers(volumes []*ec2.Volume) (volumeIDs []string, volumeNames []string) { + for _, volume := range volumes { + if volume.VolumeId != nil { + volumeIDs = append(volumeIDs, *volume.VolumeId) + } + for _, tag := range volume.Tags { + if tag.Key != nil && *tag.Key == VolumeNameTagKey && tag.Value != nil { + volumeNames = append(volumeNames, *tag.Value) + } + } + } + return volumeIDs, volumeNames +} + +func TestBatchDescribeVolumes(t *testing.T) { + testCases := []struct { + name string + volumes []*ec2.Volume + expErr error + mockFunc func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) + }{ + { + name: "TestBatchDescribeVolumes: volume by ID", + volumes: generateVolumes(10, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(1) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: volume by tag", + volumes: generateVolumes(0, 10), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(1) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: volume by ID and tag", + volumes: generateVolumes(10, 10), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(2) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: max capacity", + volumes: generateVolumes(500, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(1) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: capacity exceeded", + volumes: generateVolumes(550, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(2) + }, + expErr: nil, + }, + { + name: "TestBatchDescribeVolumes: EC2 API generic error", + volumes: generateVolumes(4, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(nil, expErr).Times(1) + }, + expErr: fmt.Errorf("Generic EC2 API error"), + }, + { + name: "TestBatchDescribeVolumes: volume not found", + volumes: generateVolumes(1, 0), + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(nil, expErr).Times(1) + }, + expErr: fmt.Errorf("volume not found"), + }, + { + name: "TestBatchDescribeVolumes: invalid tag", + volumes: []*ec2.Volume{ + { + Tags: []*ec2.Tag{ + {Key: aws.String("InvalidKey"), Value: aws.String("InvalidValue")}, + }, + }, + }, + mockFunc: func(mockEC2 *MockEC2API, expErr error, volumes []*ec2.Volume) { + + volumeOutput := &ec2.DescribeVolumesOutput{Volumes: volumes} + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), gomock.Any()).Return(volumeOutput, expErr).Times(0) + }, + expErr: fmt.Errorf("invalid tag"), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockEC2 := NewMockEC2API(mockCtrl) + c := newCloud(mockEC2) + cloudInstance := c.(*cloud) + cloudInstance.bm = newBatcherManager(cloudInstance.ec2) + + tc.mockFunc(mockEC2, tc.expErr, tc.volumes) + volumeIDs, volumeNames := extractVolumeIdentifiers(tc.volumes) + executeDescribeVolumesTest(t, cloudInstance, volumeIDs, volumeNames, tc.expErr) + }) + } +} +func executeDescribeVolumesTest(t *testing.T, c *cloud, volumeIDs, volumeNames []string, expErr error) { + var wg sync.WaitGroup + + getRequestForID := func(id string) *ec2.DescribeVolumesInput { + return &ec2.DescribeVolumesInput{VolumeIds: []*string{&id}} + } + + getRequestForTag := func(volName string) *ec2.DescribeVolumesInput { + return &ec2.DescribeVolumesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("tag:" + VolumeNameTagKey), + Values: []*string{&volName}, + }, + }, + } + } + + requests := make([]*ec2.DescribeVolumesInput, 0, len(volumeIDs)+len(volumeNames)) + for _, volumeID := range volumeIDs { + requests = append(requests, getRequestForID(volumeID)) + } + for _, volumeName := range volumeNames { + requests = append(requests, getRequestForTag(volumeName)) + } + + r := make([]chan *ec2.Volume, len(requests)) + e := make([]chan error, len(requests)) + + for i, request := range requests { + wg.Add(1) + r[i] = make(chan *ec2.Volume, 1) + e[i] = make(chan error, 1) + + go func(req *ec2.DescribeVolumesInput, resultCh chan *ec2.Volume, errCh chan error) { + defer wg.Done() + volume, err := c.batchDescribeVolumes(req) + if err != nil { + errCh <- err + return + } + resultCh <- volume + // passing `request` as a parameter to create a copy + // TODO remove after https://github.com/golang/go/discussions/56010 is implemented + }(request, r[i], e[i]) + } + + wg.Wait() + + for i := range requests { + select { + case result := <-r[i]: + if result == nil { + t.Errorf("Received nil result for a request") + } + case err := <-e[i]: + if expErr == nil { + t.Errorf("Error while processing request: %v", err) + } + if !errors.Is(err, expErr) { + t.Errorf("Expected error %v, but got %v", expErr, err) + } + default: + t.Errorf("Did not receive a result or an error for a request") + } + } +} + func TestCreateDisk(t *testing.T) { testCases := []struct { name string @@ -731,9 +934,9 @@ func TestAttachDisk(t *testing.T) { attachRequest := createAttachRequest(volumeID, nodeID, path) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().AttachVolumeWithContext(ctx, attachRequest).Return(createAttachVolumeOutput(volumeID, nodeID, path, "attached"), nil), - mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput(volumeID, nodeID, path, "attached"), nil), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().AttachVolumeWithContext(gomock.Any(), attachRequest).Return(createAttachVolumeOutput(volumeID, nodeID, path, "attached"), nil), + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, path, "attached"), nil), ) }, }, @@ -752,8 +955,8 @@ func TestAttachDisk(t *testing.T) { assert.NoError(t, err) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID, volumeID), nil), - mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput(volumeID, nodeID, path, "attached"), nil)) + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID, volumeID), nil), + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, path, "attached"), nil)) }, }, { @@ -767,8 +970,8 @@ func TestAttachDisk(t *testing.T) { attachRequest := createAttachRequest(volumeID, nodeID, path) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().AttachVolumeWithContext(ctx, attachRequest).Return(nil, errors.New("AttachVolume error")), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().AttachVolumeWithContext(gomock.Any(), attachRequest).Return(nil, errors.New("AttachVolume error")), ) }, }, @@ -835,9 +1038,9 @@ func TestDetachDisk(t *testing.T) { detachRequest := createDetachRequest(volumeID, nodeID) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().DetachVolumeWithContext(ctx, detachRequest).Return(nil, nil), - mockEC2.EXPECT().DescribeVolumesWithContext(ctx, volumeRequest).Return(createDescribeVolumesOutput(volumeID, nodeID, "", "detached"), nil), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().DetachVolumeWithContext(gomock.Any(), detachRequest).Return(nil, nil), + mockEC2.EXPECT().DescribeVolumesWithContext(gomock.Any(), volumeRequest).Return(createDescribeVolumesOutput([]*string{&volumeID}, nodeID, "", "detached"), nil), ) }, }, @@ -851,8 +1054,8 @@ func TestDetachDisk(t *testing.T) { detachRequest := createDetachRequest(volumeID, nodeID) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().DetachVolumeWithContext(ctx, detachRequest).Return(nil, errors.New("DetachVolume error")), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().DetachVolumeWithContext(gomock.Any(), detachRequest).Return(nil, errors.New("DetachVolume error")), ) }, }, @@ -866,8 +1069,8 @@ func TestDetachDisk(t *testing.T) { detachRequest := createDetachRequest(volumeID, nodeID) gomock.InOrder( - mockEC2.EXPECT().DescribeInstancesWithContext(ctx, instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), - mockEC2.EXPECT().DetachVolumeWithContext(ctx, detachRequest).Return(nil, ErrNotFound), + mockEC2.EXPECT().DescribeInstancesWithContext(gomock.Any(), instanceRequest).Return(newDescribeInstancesOutput(nodeID), nil), + mockEC2.EXPECT().DetachVolumeWithContext(gomock.Any(), detachRequest).Return(nil, ErrNotFound), ) }, }, @@ -939,6 +1142,12 @@ func TestGetDiskByName(t *testing.T) { Size: aws.Int64(util.BytesToGiB(tc.volumeCapacity)), AvailabilityZone: aws.String(tc.availabilityZone), OutpostArn: aws.String(tc.outpostArn), + Tags: []*ec2.Tag{ + { + Key: aws.String(VolumeNameTagKey), + Value: aws.String(tc.volumeName), + }, + }, } ctx := context.Background() @@ -2110,11 +2319,12 @@ func TestWaitForAttachmentState(t *testing.T) { } func newCloud(mockEC2 ec2iface.EC2API) Cloud { - return &cloud{ + c := &cloud{ region: "test-region", dm: dm.NewDeviceManager(), ec2: mockEC2, } + return c } func newDescribeInstancesOutput(nodeID string, volumeID ...string) *ec2.DescribeInstancesOutput { @@ -2187,20 +2397,24 @@ func createDetachRequest(volumeID, nodeID string) *ec2.DetachVolumeInput { } } -func createDescribeVolumesOutput(volumeID, nodeID, path, state string) *ec2.DescribeVolumesOutput { - return &ec2.DescribeVolumesOutput{ - Volumes: []*ec2.Volume{ - { - VolumeId: aws.String(volumeID), - Attachments: []*ec2.VolumeAttachment{ - { - Device: aws.String(path), - InstanceId: aws.String(nodeID), - State: aws.String(state), - }, +func createDescribeVolumesOutput(volumeIDs []*string, nodeID, path, state string) *ec2.DescribeVolumesOutput { + volumes := make([]*ec2.Volume, 0, len(volumeIDs)) + + for _, volumeID := range volumeIDs { + volumes = append(volumes, &ec2.Volume{ + VolumeId: volumeID, + Attachments: []*ec2.VolumeAttachment{ + { + Device: aws.String(path), + InstanceId: aws.String(nodeID), + State: aws.String(state), }, }, - }, + }) + } + + return &ec2.DescribeVolumesOutput{ + Volumes: volumes, } } diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index cff2eb35a..e0b4bd9cf 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -92,7 +92,8 @@ func newControllerService(driverOptions *DriverOptions) controllerService { region = metadata.GetRegion() } - cloudSrv, err := NewCloudFunc(region, driverOptions.awsSdkDebugLog, driverOptions.userAgentExtra) + klog.InfoS("batching", "status", driverOptions.batching) + cloudSrv, err := NewCloudFunc(region, driverOptions.awsSdkDebugLog, driverOptions.userAgentExtra, driverOptions.batching) if err != nil { panic(err) } @@ -420,7 +421,7 @@ func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *cs } return nil, status.Errorf(codes.Internal, "Could not attach volume %q to node %q: %v", volumeID, nodeID, err) } - klog.V(2).InfoS("ControllerPublishVolume: attached", "volumeID", volumeID, "nodeID", nodeID, "devicePath", devicePath) + klog.InfoS("ControllerPublishVolume: attached", "volumeID", volumeID, "nodeID", nodeID, "devicePath", devicePath) pvInfo := map[string]string{DevicePathKey: devicePath} return &csi.ControllerPublishVolumeResponse{PublishContext: pvInfo}, nil @@ -452,6 +453,7 @@ func validateControllerPublishVolumeRequest(req *csi.ControllerPublishVolumeRequ func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { klog.V(4).InfoS("ControllerUnpublishVolume: called", "args", *req) + if err := validateControllerUnpublishVolumeRequest(req); err != nil { return nil, err } @@ -467,12 +469,12 @@ func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req * klog.V(2).InfoS("ControllerUnpublishVolume: detaching", "volumeID", volumeID, "nodeID", nodeID) if err := d.cloud.DetachDisk(ctx, volumeID, nodeID); err != nil { if errors.Is(err, cloud.ErrNotFound) { - klog.V(2).InfoS("ControllerUnpublishVolume: attachment not found", "volumeID", volumeID, "nodeID", nodeID) + klog.InfoS("ControllerUnpublishVolume: attachment not found", "volumeID", volumeID, "nodeID", nodeID) return &csi.ControllerUnpublishVolumeResponse{}, nil } return nil, status.Errorf(codes.Internal, "Could not detach volume %q from node %q: %v", volumeID, nodeID, err) } - klog.V(2).InfoS("ControllerUnpublishVolume: detached", "volumeID", volumeID, "nodeID", nodeID) + klog.InfoS("ControllerUnpublishVolume: detached", "volumeID", volumeID, "nodeID", nodeID) return &csi.ControllerUnpublishVolumeResponse{}, nil } diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index dbf113e4a..b3f575672 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -53,8 +53,8 @@ func TestNewControllerService(t *testing.T) { testErr = errors.New("test error") testRegion = "test-region" - getNewCloudFunc = func(expectedRegion string, _ bool) func(region string, awsSdkDebugLog bool, userAgentExtra string) (cloud.Cloud, error) { - return func(region string, awsSdkDebugLog bool, userAgentExtra string) (cloud.Cloud, error) { + getNewCloudFunc = func(expectedRegion string, _ bool) func(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (cloud.Cloud, error) { + return func(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (cloud.Cloud, error) { if region != expectedRegion { t.Fatalf("expected region %q but got %q", expectedRegion, region) } @@ -66,7 +66,7 @@ func TestNewControllerService(t *testing.T) { testCases := []struct { name string region string - newCloudFunc func(string, bool, string) (cloud.Cloud, error) + newCloudFunc func(string, bool, string, bool) (cloud.Cloud, error) newMetadataFuncErrors bool expectPanic bool }{ @@ -78,7 +78,7 @@ func TestNewControllerService(t *testing.T) { { name: "AWS_REGION variable set, newCloud errors", region: "foo", - newCloudFunc: func(region string, awsSdkDebugLog bool, userAgentExtra string) (cloud.Cloud, error) { + newCloudFunc: func(region string, awsSdkDebugLog bool, userAgentExtra string, batching bool) (cloud.Cloud, error) { return nil, testErr }, expectPanic: true, diff --git a/tests/e2e/dynamic_provisioning.go b/tests/e2e/dynamic_provisioning.go index 9342e4e52..0f2d5ba9f 100644 --- a/tests/e2e/dynamic_provisioning.go +++ b/tests/e2e/dynamic_provisioning.go @@ -374,7 +374,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Dynamic Provisioning", func() { availabilityZones := strings.Split(os.Getenv(awsAvailabilityZonesEnv), ",") availabilityZone := availabilityZones[rand.Intn(len(availabilityZones))] region := availabilityZone[0 : len(availabilityZone)-1] - cloud, err := awscloud.NewCloud(region, false, "") + cloud, err := awscloud.NewCloud(region, false, "", true) if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) } diff --git a/tests/e2e/pre_provsioning.go b/tests/e2e/pre_provsioning.go index 2924f42b1..7553afc4c 100644 --- a/tests/e2e/pre_provsioning.go +++ b/tests/e2e/pre_provsioning.go @@ -88,7 +88,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Pre-Provisioned", func() { Tags: map[string]string{awscloud.VolumeNameTagKey: dummyVolumeName, awscloud.AwsEbsDriverTagKey: "true"}, } var err error - cloud, err = awscloud.NewCloud(region, false, "") + cloud, err = awscloud.NewCloud(region, false, "", true) if err != nil { Fail(fmt.Sprintf("could not get NewCloud: %v", err)) }