diff --git a/charts/aws-ebs-csi-driver/templates/controller.yaml b/charts/aws-ebs-csi-driver/templates/controller.yaml index 90db75120..1cc59deb7 100644 --- a/charts/aws-ebs-csi-driver/templates/controller.yaml +++ b/charts/aws-ebs-csi-driver/templates/controller.yaml @@ -87,6 +87,9 @@ spec: {{- if .Values.controller.sdkDebugLog }} - --aws-sdk-debug-log=true {{- end}} + {{- if .Values.controller.batching }} + - --batching=true + {{- end}} {{- with .Values.controller.loggingFormat }} - --logging-format={{ . }} {{- end }} diff --git a/charts/aws-ebs-csi-driver/values.yaml b/charts/aws-ebs-csi-driver/values.yaml index 27280e7ed..0a64b3f2f 100644 --- a/charts/aws-ebs-csi-driver/values.yaml +++ b/charts/aws-ebs-csi-driver/values.yaml @@ -166,6 +166,7 @@ awsAccessSecret: accessKey: access_key controller: + batching: true volumeModificationFeature: enabled: false # Additional parameters provided by aws-ebs-csi-driver controller. diff --git a/cmd/main.go b/cmd/main.go index e3ba18f6e..797ef2893 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -72,6 +72,7 @@ func main() { driver.WithWarnOnInvalidTag(options.ControllerOptions.WarnOnInvalidTag), driver.WithUserAgentExtra(options.ControllerOptions.UserAgentExtra), driver.WithOtelTracing(options.ServerOptions.EnableOtelTracing), + driver.WithBatching(options.ControllerOptions.Batching), ) if err != nil { klog.ErrorS(err, "failed to create driver") diff --git a/cmd/options/controller_options.go b/cmd/options/controller_options.go index 8d42dec57..ecffd97df 100644 --- a/cmd/options/controller_options.go +++ b/cmd/options/controller_options.go @@ -39,6 +39,8 @@ type ControllerOptions struct { WarnOnInvalidTag bool // flag to set user agent UserAgentExtra string + // flag to enable batching of API calls + Batching bool } func (s *ControllerOptions) AddFlags(fs *flag.FlagSet) { @@ -48,4 +50,5 @@ func (s *ControllerOptions) AddFlags(fs *flag.FlagSet) { fs.BoolVar(&s.AwsSdkDebugLog, "aws-sdk-debug-log", false, "To enable the aws sdk 
debug log level (default to false).") fs.BoolVar(&s.WarnOnInvalidTag, "warn-on-invalid-tag", false, "To warn on invalid tags, instead of returning an error") fs.StringVar(&s.UserAgentExtra, "user-agent-extra", "", "Extra string appended to user agent.") + fs.BoolVar(&s.Batching, "batching", false, "To enable batching of API calls. This is especially helpful for improving performance in workloads that are sensitive to EC2 rate limits.") } diff --git a/cmd/options/controller_options_test.go b/cmd/options/controller_options_test.go index 5b26d9327..ce12d3999 100644 --- a/cmd/options/controller_options_test.go +++ b/cmd/options/controller_options_test.go @@ -43,6 +43,11 @@ func TestControllerOptions(t *testing.T) { flag: "aws-sdk-debug-log", found: true, }, + { + name: "lookup batching", + flag: "batching", + found: true, + }, { name: "lookup user-agent-extra", flag: "user-agent-extra", diff --git a/cmd/options_test.go b/cmd/options_test.go index e87ad3d3b..cf5023706 100644 --- a/cmd/options_test.go +++ b/cmd/options_test.go @@ -55,6 +55,8 @@ func TestGetOptions(t *testing.T) { userAgentExtraFlagValue := "test" otelTracingFlagName := "enable-otel-tracing" otelTracingFlagValue := true + batchingFlagName := "batching" + batchingFlagValue := true args := append([]string{ "aws-ebs-csi-driver", @@ -68,6 +70,7 @@ func TestGetOptions(t *testing.T) { args = append(args, "--"+extraTagsFlagName+"="+extraTagKey+"="+extraTagValue) args = append(args, "--"+awsSdkDebugFlagName+"="+strconv.FormatBool(awsSdkDebugFlagValue)) args = append(args, "--"+userAgentExtraFlag+"="+userAgentExtraFlagValue) + args = append(args, "--"+batchingFlagName+"="+strconv.FormatBool(batchingFlagValue)) } if withNodeOptions { args = append(args, "--"+VolumeAttachLimitFlagName+"="+strconv.FormatInt(VolumeAttachLimit, 10)) @@ -110,6 +113,13 @@ func TestGetOptions(t *testing.T) { if options.ControllerOptions.UserAgentExtra != userAgentExtraFlagValue { t.Fatalf("expected user agent string to be %q but it is 
%q", userAgentExtraFlagValue, options.ControllerOptions.UserAgentExtra) } + batchingFlag := flagSet.Lookup(batchingFlagName) + if batchingFlag == nil { + t.Fatalf("expected %q flag to be added but it is not", batchingFlagName) + } + if options.ControllerOptions.Batching != batchingFlagValue { // parsed value must match the --batching argument passed above + t.Fatalf("expected batching flag to be %v but it is %v", batchingFlagValue, options.ControllerOptions.Batching) + } } if withNodeOptions { diff --git a/deploy/kubernetes/base/controller.yaml b/deploy/kubernetes/base/controller.yaml index 8ebe5469c..fa30459d0 100644 --- a/deploy/kubernetes/base/controller.yaml +++ b/deploy/kubernetes/base/controller.yaml @@ -66,6 +66,7 @@ spec: args: # - {all,controller,node} # specify the driver mode - --endpoint=$(CSI_ENDPOINT) + - --batching=true - --logging-format=text - --user-agent-extra=kustomize - --v=2 diff --git a/docs/options.md b/docs/options.md index 3eb7663e5..88d1deed4 100644 --- a/docs/options.md +++ b/docs/options.md @@ -12,3 +12,4 @@ There are a couple of driver options that can be passed as arguments when starti | logging-format | json | text | Sets the log format. Permitted formats: text, json| | user-agent-extra | csi-ebs | helm | Extra string appended to user agent| | enable-otel-tracing | true | false | If set to true, the driver will enable opentelemetry tracing. Might need [additional env variables](https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/#general-sdk-configuration) to export the traces to the right collector| +| batching | true | true | If set to true, the driver will enable batching of API calls. This is especially helpful for improving performance in workloads that are sensitive to EC2 rate limits| diff --git a/pkg/batcher/batcher.go b/pkg/batcher/batcher.go new file mode 100644 index 000000000..0bfa703e8 --- /dev/null +++ b/pkg/batcher/batcher.go @@ -0,0 +1,162 @@ +// Package batcher facilitates task aggregation and execution. 
+// +// Basic Usage: +// Instantiate a Batcher, set up its constraints, and then start adding tasks. As tasks accumulate, +// they are batched together for execution, either when a maximum task count is reached or a specified +// duration elapses. Results of the executed tasks are communicated asynchronously via channels. +// +// Example: +// Create a Batcher with a maximum of 10 tasks or a 5-second wait: +// +// `b := batcher.New(10, 5*time.Second, execFunc)` +// +// Add a task and receive its result: +// +// resultChan := make(chan batcher.BatchResult[ResultType], 1) // buffer of 1: results are sent without blocking and are dropped if the channel cannot accept them +// b.AddTask(myTask, resultChan) +// result := <-resultChan +// +// Key Components: +// - `Batcher`: The main component that manages task queueing, aggregation, and execution. +// - `BatchResult`: A structure encapsulating the response for a task. +// - `taskEntry`: Internal representation of a task and its associated result channel. +// +// Task Duplication: +// Batcher identifies tasks by content. For multiple identical tasks, each has a unique result channel. +// This distinction ensures that identical tasks return their results to the appropriate callers. Identical tasks occupy a single slot in the pending batch, so maxEntries counts distinct tasks. +package batcher + +import ( + "time" + + "k8s.io/klog/v2" +) + +// Batcher manages the batching and execution of tasks. It collects tasks up to a specified limit (maxEntries) or +// waits for a defined duration (maxDelay) before triggering a batch execution. The actual task execution +// logic is provided by the execFunc, which processes tasks and returns their corresponding results. Tasks are +// queued via the taskChan and stored in pendingTasks until batch execution. +type Batcher[InputType comparable, ResultType interface{}] struct { + // execFunc is the function responsible for executing a batch of tasks. + // It returns a map associating each task with its result. + execFunc func(inputs []InputType) (map[InputType]ResultType, error) + + // pendingTasks holds the tasks that are waiting to be executed in a batch. 
+ // Each task is associated with one or more result channels to account for duplicates. + pendingTasks map[InputType][]chan BatchResult[ResultType] + + // taskChan is the channel through which new tasks are added to the Batcher. + taskChan chan taskEntry[InputType, ResultType] + + // maxEntries is the maximum number of tasks that can be batched together for execution. Duplicate tasks share one pendingTasks key and therefore count once. + maxEntries int + + // maxDelay is the maximum duration the Batcher waits before executing a batch operation, + // regardless of how many tasks are in the batch. + maxDelay time.Duration +} + +// BatchResult encapsulates the response of a batched task. +// A task will either have a result or an error, but not both. +type BatchResult[ResultType interface{}] struct { + Result ResultType + Err error +} + +// taskEntry represents a single task waiting to be batched and its associated result channel. +// The result channel is used to communicate the task's result back to the caller. +type taskEntry[InputType comparable, ResultType interface{}] struct { + task InputType + resultChan chan BatchResult[ResultType] +} + +// New creates and returns a Batcher configured with the specified maxEntries and maxDelay parameters. +// Upon instantiation, it immediately launches the internal task manager as a goroutine to oversee batch operations. That goroutine loops forever; the Batcher offers no way to stop it. +// The provided execFunc is used to execute batch requests. entries is also used as the buffer size of taskChan. +func New[InputType comparable, ResultType interface{}](entries int, delay time.Duration, fn func(inputs []InputType) (map[InputType]ResultType, error)) *Batcher[InputType, ResultType] { + klog.V(7).InfoS("New: initializing Batcher", "maxEntries", entries, "maxDelay", delay) + + b := &Batcher[InputType, ResultType]{ + execFunc: fn, + pendingTasks: make(map[InputType][]chan BatchResult[ResultType]), + taskChan: make(chan taskEntry[InputType, ResultType], entries), + maxEntries: entries, + maxDelay: delay, + } + + go b.taskManager() + return b +} + +// AddTask adds a new task to the Batcher's queue. It blocks while taskChan is full. resultChan should be buffered (capacity >= 1): the result is delivered with a non-blocking send and is dropped if the channel cannot accept it. 
+func (b *Batcher[InputType, ResultType]) AddTask(t InputType, resultChan chan BatchResult[ResultType]) { + klog.V(7).InfoS("AddTask: queueing task", "task", t) + b.taskChan <- taskEntry[InputType, ResultType]{task: t, resultChan: resultChan} +} + +// taskManager runs as a goroutine, continuously managing the Batcher's internal state. +// It batches tasks and triggers their execution based on set constraints (maxEntries and maxDelay). The maxDelay timer is armed only when no timer is pending (timerCh is nil), so a duplicate of the only queued task cannot restart it and push the batch past maxDelay. +func (b *Batcher[InputType, ResultType]) taskManager() { + klog.V(7).InfoS("taskManager: started taskManager") + var timerCh <-chan time.Time + + exec := func() { + timerCh = nil + go b.execute(b.pendingTasks) + b.pendingTasks = make(map[InputType][]chan BatchResult[ResultType]) + } + + for { + select { + case <-timerCh: + klog.V(7).InfoS("taskManager: maxDelay execution") + exec() + + case t := <-b.taskChan: + if _, exists := b.pendingTasks[t.task]; exists { + klog.InfoS("taskManager: duplicate task detected", "task", t.task) + } else { + b.pendingTasks[t.task] = make([]chan BatchResult[ResultType], 0) + } + b.pendingTasks[t.task] = append(b.pendingTasks[t.task], t.resultChan) + + if timerCh == nil { + klog.V(7).InfoS("taskManager: starting maxDelay timer") + timerCh = time.After(b.maxDelay) + } + + if len(b.pendingTasks) == b.maxEntries { + klog.V(7).InfoS("taskManager: maxEntries reached") + exec() + } + } + } +} + +// execute is called by taskManager to execute a batch of tasks. +// It calls the Batcher's internal execFunc and then sends the results of each task to its corresponding result channels. 
+func (b *Batcher[InputType, ResultType]) execute(pendingTasks map[InputType][]chan BatchResult[ResultType]) { + batch := make([]InputType, 0, len(pendingTasks)) + for task := range pendingTasks { + batch = append(batch, task) + } + + klog.V(7).InfoS("execute: calling execFunc", "batchSize", len(batch)) + resultsMap, err := b.execFunc(batch) + if err != nil { + klog.ErrorS(err, "execute: error executing batch") + resultsMap = nil // honor BatchResult's contract: on failure callers receive only the error, never a partial result + } + klog.V(7).InfoS("execute: sending batch results", "batch", batch) + for _, task := range batch { + r := resultsMap[task] // zero value when the map is nil or has no entry for this task + for _, ch := range pendingTasks[task] { + select { + case ch <- BatchResult[ResultType]{Result: r, Err: err}: + default: // non-blocking send: drop the result rather than stall the batch when the channel cannot accept it + klog.V(7).InfoS("execute: ignoring channel with no receiver") + } + } + } + klog.V(7).InfoS("execute: finished execution", "batchSize", len(batch)) +} diff --git a/pkg/batcher/batcher_test.go b/pkg/batcher/batcher_test.go new file mode 100644 index 000000000..7553f1bd8 --- /dev/null +++ b/pkg/batcher/batcher_test.go @@ -0,0 +1,181 @@ +package batcher + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func mockExecution(inputs []string) (map[string]string, error) { + results := make(map[string]string) + for _, input := range inputs { + results[input] = input + } + return results, nil +} + +func mockExecutionWithError(inputs []string) (map[string]string, error) { + results := make(map[string]string) + for _, input := range inputs { + results[input] = input + } + return results, fmt.Errorf("mock execution error") +} + +func TestBatcher(t *testing.T) { + type testCase struct { + name string + mockFunc func(inputs []string) (map[string]string, error) + maxEntries int + maxDelay time.Duration + tasks []string + expectErrors bool + expectResult bool + } + + tests := []testCase{ + { + name: "TestBatcher: single task", + mockFunc: mockExecution, + maxEntries: 10, + maxDelay: 1 * time.Second, + tasks: []string{"task1"}, + expectResult: true, + expectErrors: false, + }, + { + name: "TestBatcher: multiple tasks", + 
mockFunc: mockExecution, + maxEntries: 10, + maxDelay: 1 * time.Second, + tasks: []string{"task1", "task2", "task3"}, + expectResult: true, + expectErrors: false, + }, + { + name: "TestBatcher: same task", + mockFunc: mockExecution, + maxEntries: 10, + maxDelay: 1 * time.Second, + tasks: []string{"task1", "task1", "task1"}, + expectResult: true, + expectErrors: false, + }, + { + name: "TestBatcher: max capacity", + mockFunc: mockExecution, + maxEntries: 5, + maxDelay: 100 * time.Second, + tasks: []string{"task1", "task2", "task3", "task4", "task5"}, + expectResult: true, + expectErrors: false, + }, + { + name: "TestBatcher: max delay", + mockFunc: mockExecution, + maxEntries: 100, + maxDelay: 2 * time.Second, + tasks: []string{"task1", "task2", "task3", "task4"}, + expectResult: true, + expectErrors: false, + }, + { + name: "TestBatcher: no execution without max delay or max entries", + mockFunc: mockExecution, + maxEntries: 10, + maxDelay: 15 * time.Second, + tasks: []string{"task1", "task2", "task3"}, + expectResult: false, + }, + { + name: "TestBatcher: error handling", + mockFunc: mockExecutionWithError, + maxEntries: 10, + maxDelay: 1 * time.Second, + tasks: []string{"errorTask"}, + expectErrors: true, + }, + { + name: "TestBatcher: immediate execution", + mockFunc: mockExecution, + maxEntries: 10, + maxDelay: 0, + tasks: []string{"task1"}, + expectResult: true, + expectErrors: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := New(tc.maxEntries, tc.maxDelay, tc.mockFunc) + resultChans := make([]chan BatchResult[string], len(tc.tasks)) + + var wg sync.WaitGroup + + for i := 0; i < len(tc.tasks); i++ { + wg.Add(1) + go func(taskNum int) { + defer wg.Done() + task := tc.tasks[taskNum] // submit the task named by the test case so the duplicate and error cases exercise their stated inputs + resultChans[taskNum] = make(chan BatchResult[string], 1) + b.AddTask(task, resultChans[taskNum]) + }(i) + } + + wg.Wait() + + for i := 0; i < len(tc.tasks); i++ { + select { + case r := 
<-resultChans[i]: + task := tc.tasks[i] // the mock executors echo each input, so the expected result is the submitted task itself + if tc.expectErrors && r.Err == nil { + t.Errorf("Expected error for task %v, but got %v", task, r.Err) + } + if r.Result != task && tc.expectResult { + t.Errorf("Expected result for task %v, but got %v", task, r.Result) + } + case <-time.After(10 * time.Second): + if tc.expectResult { + t.Errorf("Timed out waiting for result of task %d", i) + } + } + } + }) + } +} + +func TestBatcherConcurrentTaskAdditions(t *testing.T) { + numTasks := 100 + var wg sync.WaitGroup + + b := New(numTasks, 1*time.Second, mockExecution) + resultChans := make([]chan BatchResult[string], numTasks) + + for i := 0; i < numTasks; i++ { + wg.Add(1) + go func(taskNum int) { + defer wg.Done() + task := fmt.Sprintf("task%d", taskNum) + resultChans[taskNum] = make(chan BatchResult[string], 1) + b.AddTask(task, resultChans[taskNum]) + }(i) + } + + wg.Wait() + + for i := 0; i < numTasks; i++ { + r := <-resultChans[i] + task := fmt.Sprintf("task%d", i) + if r.Err != nil { + t.Errorf("Expected no error for task %v, but got %v", task, r.Err) + } + if r.Result != task { + t.Errorf("Expected result %v for task %v, but got %v", task, task, r.Result) + } + } +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index d92964af9..8569e7d51 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -68,6 +68,7 @@ type DriverOptions struct { volumeAttachLimit int64 kubernetesClusterID string awsSdkDebugLog bool + batching bool warnOnInvalidTag bool userAgentExtra string otelTracing bool @@ -193,6 +194,12 @@ func WithVolumeAttachLimit(volumeAttachLimit int64) func(*DriverOptions) { } } +func WithBatching(enableBatching bool) func(*DriverOptions) { + return func(o *DriverOptions) { + o.batching = enableBatching + } +} + func WithKubernetesClusterID(clusterID string) func(*DriverOptions) { return func(o *DriverOptions) { o.kubernetesClusterID = clusterID diff --git a/pkg/driver/driver_test.go b/pkg/driver/driver_test.go index e98229ae0..efe033c63 
100644 --- a/pkg/driver/driver_test.go +++ b/pkg/driver/driver_test.go @@ -112,3 +112,12 @@ func TestWithOtelTracing(t *testing.T) { t.Fatalf("expected otelTracing option got set to %v but is set to %v", enableOtelTracing, options.otelTracing) } } + +func TestWithBatching(t *testing.T) { // WithBatching must store its argument in DriverOptions.batching + const want = true + opts := new(DriverOptions) + WithBatching(want)(opts) + if got := opts.batching; got != want { + t.Fatalf("expected batching option got set to %v but is set to %v", want, got) + } +}