diff --git a/actionners/aws/lambda/lambda.go b/actionners/aws/lambda/lambda.go index 338a45fc..9ae5929b 100644 --- a/actionners/aws/lambda/lambda.go +++ b/actionners/aws/lambda/lambda.go @@ -101,12 +101,51 @@ func (a Actionner) Checks(_ *events.Event, action *rules.Action) error { if err != nil { return err } + // would be nice to pass the client here so we can easily test that function return awsChecks.CheckLambdaExist.Run(awsChecks.CheckLambdaExist{}, parameters.AWSLambdaName) } func (a Actionner) Run(event *events.Event, action *rules.Action) (utils.LogLine, *models.Data, error) { lambdaClient := aws.GetLambdaClient() + return a.RunWithClient(lambdaClient, event, action) +} + +func (a Actionner) CheckParameters(action *rules.Action) error { + var parameters Parameters + err := utils.DecodeParams(action.GetParameters(), ¶meters) + if err != nil { + return err + } + + err = utils.ValidateStruct(parameters) + if err != nil { + return err + } + return nil +} + +func GetInvocationType(invocationType string) types.InvocationType { + switch invocationType { + case "RequestResponse": + return types.InvocationTypeRequestResponse + case "Event": + return types.InvocationTypeEvent + case "DryRun": + return types.InvocationTypeDryRun + default: + return types.InvocationTypeRequestResponse // Default + } +} + +func GetLambdaVersion(qualifier *string) *string { + if qualifier == nil || *qualifier == "" { + defaultVal := "$LATEST" + return &defaultVal + } + return qualifier +} +func (a Actionner) RunWithClient(client aws.LambdaClientAPI, event *events.Event, action *rules.Action) (utils.LogLine, *models.Data, error) { var parameters Parameters err := utils.DecodeParams(action.GetParameters(), ¶meters) if err != nil { @@ -134,12 +173,12 @@ func (a Actionner) Run(event *events.Event, action *rules.Action) (utils.LogLine input := &lambda.InvokeInput{ FunctionName: ¶meters.AWSLambdaName, ClientContext: nil, - InvocationType: getInvocationType(parameters.AWSLambdaInvocationType), + InvocationType: GetInvocationType(parameters.AWSLambdaInvocationType), Payload: payload, - Qualifier: getLambdaVersion(¶meters.AWSLambdaAliasOrVersion), + Qualifier: GetLambdaVersion(¶meters.AWSLambdaAliasOrVersion), } - lambdaOutput, err := lambdaClient.Invoke(context.Background(), input) + lambdaOutput, err := client.Invoke(context.Background(), input) if err != nil { return utils.LogLine{ Objects: objects, @@ -158,38 +197,3 @@ func (a Actionner) Run(event *events.Event, action *rules.Action) (utils.LogLine Status: status, }, nil, nil } - -func (a Actionner) CheckParameters(action *rules.Action) error { - var parameters Parameters - err := utils.DecodeParams(action.GetParameters(), ¶meters) - if err != nil { - return err - } - - err = utils.ValidateStruct(parameters) - if err != nil { - return err - } - return nil -} - -func getInvocationType(invocationType string) types.InvocationType { - switch invocationType { - case "RequestResponse": - return types.InvocationTypeRequestResponse - case "Event": - return types.InvocationTypeEvent - case "DryRun": - return types.InvocationTypeDryRun - default: - return types.InvocationTypeRequestResponse // Default - } -} - -func getLambdaVersion(qualifier *string) *string { - if qualifier == nil || *qualifier == "" { - defaultVal := "$LATEST" - return &defaultVal - } - return qualifier -} diff --git a/actionners/aws/lambda/lambda_test.go b/actionners/aws/lambda/lambda_test.go new file mode 100644 index 00000000..07bba5e9 --- /dev/null +++ b/actionners/aws/lambda/lambda_test.go @@ -0,0 +1,187 @@ +package lambda_test + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/lambda" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/falco-talon/falco-talon/internal/models" + + lambdaActionner "github.com/falco-talon/falco-talon/actionners/aws/lambda" + "github.com/falco-talon/falco-talon/internal/events" + "github.com/falco-talon/falco-talon/internal/rules" + "github.com/falco-talon/falco-talon/utils" +) + +// MockLambdaClient is a mock implementation of the LambdaClientAPI interface +type MockLambdaClient struct { + mock.Mock +} + +func (m *MockLambdaClient) Invoke(ctx context.Context, input *lambda.InvokeInput, _ ...func(*lambda.Options)) (*lambda.InvokeOutput, error) { + args := m.Called(ctx, input) + return args.Get(0).(*lambda.InvokeOutput), args.Error(1) +} + +func (m *MockLambdaClient) GetFunction(_ context.Context, _ *lambda.GetFunctionInput, _ ...func(*lambda.Options)) (*lambda.GetFunctionOutput, error) { + return &lambda.GetFunctionOutput{}, nil +} + +type lambdaTestCase struct { + name string + event *events.Event + action *rules.Action + mockInvokeOutput *lambda.InvokeOutput + mockInvokeError error + expectedData *models.Data + expectedLogLine utils.LogLine + expectError bool +} + +var lambdaTestCases = []lambdaTestCase{ + { + name: "Successful Invocation", + event: &events.Event{ + TraceID: "123", + Source: "falco-talon", + Rule: "sample-rule", + }, + action: &rules.Action{ + Parameters: map[string]interface{}{ + "aws_lambda_name": "sample-function", + "aws_lambda_alias_or_version": "$LATEST", + "aws_lambda_invocation_type": "RequestResponse", + }, + }, + mockInvokeOutput: &lambda.InvokeOutput{ + StatusCode: 200, + Payload: []byte(`{"message":"success"}`), + }, + mockInvokeError: nil, + expectedLogLine: utils.LogLine{ + Status: utils.SuccessStr, + Output: "{\"message\":\"success\"}", + Objects: map[string]string{"name": "sample-function", "version": "$LATEST"}, + }, + expectedData: nil, + expectError: false, + }, + { + name: "Successful invocation of custom version", + event: &events.Event{ + TraceID: "123", + Source: "falco-talon", + Rule: "sample-rule", + }, + action: &rules.Action{ + Parameters: map[string]interface{}{ + "aws_lambda_name": "sample-function", + "aws_lambda_alias_or_version": "1", + "aws_lambda_invocation_type": "RequestResponse", + }, + }, + mockInvokeOutput: &lambda.InvokeOutput{ + StatusCode: 200, + Payload: []byte(`{"message":"success"}`), + }, + mockInvokeError: nil, + expectedLogLine: utils.LogLine{ + Status: utils.SuccessStr, + Output: "{\"message\":\"success\"}", + Objects: map[string]string{"name": "sample-function", "version": "1"}, + }, + expectedData: nil, + expectError: false, + }, + { + name: "Successful invocation of event", + event: &events.Event{ + TraceID: "123", + Source: "falco-talon", + Rule: "sample-rule", + }, + action: &rules.Action{ + Parameters: map[string]interface{}{ + "aws_lambda_name": "sample-function", + "aws_lambda_alias_or_version": "1", + "aws_lambda_invocation_type": " Event", + }, + }, + mockInvokeOutput: &lambda.InvokeOutput{ + StatusCode: 200, + Payload: []byte(`{"message":"success"}`), + }, + mockInvokeError: nil, + expectedLogLine: utils.LogLine{ + Status: utils.SuccessStr, + Output: "{\"message\":\"success\"}", + Objects: map[string]string{"name": "sample-function", "version": "1"}, + }, + expectedData: nil, + expectError: false, + }, + { + name: "Invocation Error", + event: &events.Event{}, // Provide event data as needed + action: &rules.Action{ + Parameters: map[string]interface{}{ + "aws_lambda_name": "sample-function", + "aws_lambda_alias_or_version": "$LATEST", + "aws_lambda_invocation_type": "RequestResponse", + }, + }, + mockInvokeOutput: new(lambda.InvokeOutput), + mockInvokeError: errors.New("invoke error"), + expectedLogLine: utils.LogLine{ + Status: utils.FailureStr, + Error: "invoke error", + Objects: map[string]string{"name": "sample-function", "version": "$LATEST"}, + }, + expectedData: nil, + expectError: true, + }, +} + +func TestRunWithClient(t *testing.T) { + for _, tt := range lambdaTestCases { + t.Run(tt.name, func(t *testing.T) { + mockClient := new(MockLambdaClient) + + lambdaName := tt.action.Parameters["aws_lambda_name"].(string) + lambdaVersion := tt.action.Parameters["aws_lambda_alias_or_version"].(string) + lambdaInvocationType := tt.action.Parameters["aws_lambda_invocation_type"].(string) + expectedPayload, _ := json.Marshal(tt.event) + + mockClient.On("Invoke", mock.Anything, &lambda.InvokeInput{ + FunctionName: &lambdaName, + InvocationType: lambdaActionner.GetInvocationType(lambdaInvocationType), + Payload: expectedPayload, + Qualifier: lambdaActionner.GetLambdaVersion(&lambdaVersion), + }).Return(tt.mockInvokeOutput, tt.mockInvokeError) + + actionner := lambdaActionner.Actionner{} + + logLine, data, err := actionner.RunWithClient(mockClient, tt.event, tt.action) + + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, tt.expectedLogLine.Status, logLine.Status) + assert.Contains(t, logLine.Error, tt.expectedLogLine.Error) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedLogLine.Status, logLine.Status) + assert.Equal(t, tt.expectedLogLine.Output, logLine.Output) + assert.Equal(t, tt.expectedLogLine.Objects, logLine.Objects) + } + + assert.Equal(t, tt.expectedData, data) + + mockClient.AssertExpectations(t) + }) + } +} diff --git a/go.mod b/go.mod index ed7cb1c3..9120f316 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/rs/zerolog v1.33.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 go.opentelemetry.io/otel v1.29.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.29.0 @@ -145,6 +146,7 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.6-0.20210604193023-d5e0c0615ace // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/vishvananda/netlink v1.2.1-beta.2.0.20240524165444-4d4ba1473f21 // indirect github.com/vishvananda/netns v0.0.4 // indirect diff --git a/go.sum b/go.sum index 344f5bf3..0c14b775 100644 --- a/go.sum +++ b/go.sum @@ -300,6 +300,7 @@ github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/internal/aws/client/client.go b/internal/aws/client/client.go index 0b888196..fafc8181 100644 --- a/internal/aws/client/client.go +++ b/internal/aws/client/client.go @@ -17,8 +17,13 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" ) +type LambdaClientAPI interface { + Invoke(ctx context.Context, input *lambda.InvokeInput, optFns ...func(*lambda.Options)) (*lambda.InvokeOutput, error) + GetFunction(ctx context.Context, params *lambda.GetFunctionInput, optFns ...func(*lambda.Options)) (*lambda.GetFunctionOutput, error) +} + type AWSClient struct { - lambdaClient *lambda.Client + lambdaClient LambdaClientAPI imdsClient *imds.Client s3Client *s3.Client cfg aws.Config @@ -90,7 +95,7 @@ func GetAWSClient() *AWSClient { return awsClient } -func GetLambdaClient() *lambda.Client { +func GetLambdaClient() LambdaClientAPI { c := GetAWSClient() if c == nil { return nil