Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor lambda, add unit tests #417

Draft
wants to merge 4 commits into
base: v0.2.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 42 additions & 38 deletions actionners/aws/lambda/lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,51 @@ func (a Actionner) Checks(_ *events.Event, action *rules.Action) error {
if err != nil {
return err
}
// would be nice to pass the client here so we can easily test that function
return awsChecks.CheckLambdaExist.Run(awsChecks.CheckLambdaExist{}, parameters.AWSLambdaName)
}

func (a Actionner) Run(event *events.Event, action *rules.Action) (utils.LogLine, *models.Data, error) {
lambdaClient := aws.GetLambdaClient()
return a.RunWithClient(lambdaClient, event, action)
}

func (a Actionner) CheckParameters(action *rules.Action) error {
var parameters Parameters
err := utils.DecodeParams(action.GetParameters(), &parameters)
if err != nil {
return err
}

err = utils.ValidateStruct(parameters)
if err != nil {
return err
}
return nil
}

func GetInvocationType(invocationType string) types.InvocationType {
switch invocationType {
case "RequestResponse":
return types.InvocationTypeRequestResponse
case "Event":
return types.InvocationTypeEvent
case "DryRun":
return types.InvocationTypeDryRun
default:
return types.InvocationTypeRequestResponse // Default
}
}

func GetLambdaVersion(qualifier *string) *string {
if qualifier == nil || *qualifier == "" {
defaultVal := "$LATEST"
return &defaultVal
}
return qualifier
}

func (a Actionner) RunWithClient(client aws.LambdaClientAPI, event *events.Event, action *rules.Action) (utils.LogLine, *models.Data, error) {
var parameters Parameters
err := utils.DecodeParams(action.GetParameters(), &parameters)
if err != nil {
Expand Down Expand Up @@ -134,12 +173,12 @@ func (a Actionner) Run(event *events.Event, action *rules.Action) (utils.LogLine
input := &lambda.InvokeInput{
FunctionName: &parameters.AWSLambdaName,
ClientContext: nil,
InvocationType: getInvocationType(parameters.AWSLambdaInvocationType),
InvocationType: GetInvocationType(parameters.AWSLambdaInvocationType),
Payload: payload,
Qualifier: getLambdaVersion(&parameters.AWSLambdaAliasOrVersion),
Qualifier: GetLambdaVersion(&parameters.AWSLambdaAliasOrVersion),
}

lambdaOutput, err := lambdaClient.Invoke(context.Background(), input)
lambdaOutput, err := client.Invoke(context.Background(), input)
if err != nil {
return utils.LogLine{
Objects: objects,
Expand All @@ -158,38 +197,3 @@ func (a Actionner) Run(event *events.Event, action *rules.Action) (utils.LogLine
Status: status,
}, nil, nil
}

func (a Actionner) CheckParameters(action *rules.Action) error {
var parameters Parameters
err := utils.DecodeParams(action.GetParameters(), &parameters)
if err != nil {
return err
}

err = utils.ValidateStruct(parameters)
if err != nil {
return err
}
return nil
}

func getInvocationType(invocationType string) types.InvocationType {
switch invocationType {
case "RequestResponse":
return types.InvocationTypeRequestResponse
case "Event":
return types.InvocationTypeEvent
case "DryRun":
return types.InvocationTypeDryRun
default:
return types.InvocationTypeRequestResponse // Default
}
}

func getLambdaVersion(qualifier *string) *string {
if qualifier == nil || *qualifier == "" {
defaultVal := "$LATEST"
return &defaultVal
}
return qualifier
}
187 changes: 187 additions & 0 deletions actionners/aws/lambda/lambda_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package lambda_test

import (
"context"
"encoding/json"
"errors"
"testing"

"github.com/aws/aws-sdk-go-v2/service/lambda"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/falco-talon/falco-talon/internal/models"

lambdaActionner "github.com/falco-talon/falco-talon/actionners/aws/lambda"
"github.com/falco-talon/falco-talon/internal/events"
"github.com/falco-talon/falco-talon/internal/rules"
"github.com/falco-talon/falco-talon/utils"
)

// MockLambdaClient is a mock implementation of the LambdaClientAPI interface
type MockLambdaClient struct {
mock.Mock
}

func (m *MockLambdaClient) Invoke(ctx context.Context, input *lambda.InvokeInput, _ ...func(*lambda.Options)) (*lambda.InvokeOutput, error) {
args := m.Called(ctx, input)
return args.Get(0).(*lambda.InvokeOutput), args.Error(1)
}

func (m *MockLambdaClient) GetFunction(_ context.Context, _ *lambda.GetFunctionInput, _ ...func(*lambda.Options)) (*lambda.GetFunctionOutput, error) {
return &lambda.GetFunctionOutput{}, nil
}

type lambdaTestCase struct {
name string
event *events.Event
action *rules.Action
mockInvokeOutput *lambda.InvokeOutput
mockInvokeError error
expectedData *models.Data
expectedLogLine utils.LogLine
expectError bool
}

var lambdaTestCases = []lambdaTestCase{
{
name: "Successful Invocation",
event: &events.Event{
TraceID: "123",
Source: "falco-talon",
Rule: "sample-rule",
},
action: &rules.Action{
Parameters: map[string]interface{}{
"aws_lambda_name": "sample-function",
"aws_lambda_alias_or_version": "$LATEST",
"aws_lambda_invocation_type": "RequestResponse",
},
},
mockInvokeOutput: &lambda.InvokeOutput{
StatusCode: 200,
Payload: []byte(`{"message":"success"}`),
},
mockInvokeError: nil,
expectedLogLine: utils.LogLine{
Status: utils.SuccessStr,
Output: "{\"message\":\"success\"}",
Objects: map[string]string{"name": "sample-function", "version": "$LATEST"},
},
expectedData: nil,
expectError: false,
},
{
name: "Successful invocation of custom version",
event: &events.Event{
TraceID: "123",
Source: "falco-talon",
Rule: "sample-rule",
},
action: &rules.Action{
Parameters: map[string]interface{}{
"aws_lambda_name": "sample-function",
"aws_lambda_alias_or_version": "1",
"aws_lambda_invocation_type": "RequestResponse",
},
},
mockInvokeOutput: &lambda.InvokeOutput{
StatusCode: 200,
Payload: []byte(`{"message":"success"}`),
},
mockInvokeError: nil,
expectedLogLine: utils.LogLine{
Status: utils.SuccessStr,
Output: "{\"message\":\"success\"}",
Objects: map[string]string{"name": "sample-function", "version": "1"},
},
expectedData: nil,
expectError: false,
},
{
name: "Successful invocation of event",
event: &events.Event{
TraceID: "123",
Source: "falco-talon",
Rule: "sample-rule",
},
action: &rules.Action{
Parameters: map[string]interface{}{
"aws_lambda_name": "sample-function",
"aws_lambda_alias_or_version": "1",
"aws_lambda_invocation_type": " Event",
},
},
mockInvokeOutput: &lambda.InvokeOutput{
StatusCode: 200,
Payload: []byte(`{"message":"success"}`),
},
mockInvokeError: nil,
expectedLogLine: utils.LogLine{
Status: utils.SuccessStr,
Output: "{\"message\":\"success\"}",
Objects: map[string]string{"name": "sample-function", "version": "1"},
},
expectedData: nil,
expectError: false,
},
{
name: "Invocation Error",
event: &events.Event{}, // Provide event data as needed
action: &rules.Action{
Parameters: map[string]interface{}{
"aws_lambda_name": "sample-function",
"aws_lambda_alias_or_version": "$LATEST",
"aws_lambda_invocation_type": "RequestResponse",
},
},
mockInvokeOutput: new(lambda.InvokeOutput),
mockInvokeError: errors.New("invoke error"),
expectedLogLine: utils.LogLine{
Status: utils.FailureStr,
Error: "invoke error",
Objects: map[string]string{"name": "sample-function", "version": "$LATEST"},
},
expectedData: nil,
expectError: true,
},
}

func TestRunWithClient(t *testing.T) {
for _, tt := range lambdaTestCases {
t.Run(tt.name, func(t *testing.T) {
mockClient := new(MockLambdaClient)

lambdaName := tt.action.Parameters["aws_lambda_name"].(string)
lambdaVersion := tt.action.Parameters["aws_lambda_alias_or_version"].(string)
lambdaInvocationType := tt.action.Parameters["aws_lambda_invocation_type"].(string)
expectedPayload, _ := json.Marshal(tt.event)

mockClient.On("Invoke", mock.Anything, &lambda.InvokeInput{
FunctionName: &lambdaName,
InvocationType: lambdaActionner.GetInvocationType(lambdaInvocationType),
Payload: expectedPayload,
Qualifier: lambdaActionner.GetLambdaVersion(&lambdaVersion),
}).Return(tt.mockInvokeOutput, tt.mockInvokeError)

actionner := lambdaActionner.Actionner{}

logLine, data, err := actionner.RunWithClient(mockClient, tt.event, tt.action)

if tt.expectError {
assert.Error(t, err)
assert.Equal(t, tt.expectedLogLine.Status, logLine.Status)
assert.Contains(t, logLine.Error, tt.expectedLogLine.Error)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedLogLine.Status, logLine.Status)
assert.Equal(t, tt.expectedLogLine.Output, logLine.Output)
assert.Equal(t, tt.expectedLogLine.Objects, logLine.Objects)
}

assert.Equal(t, tt.expectedData, data)

mockClient.AssertExpectations(t)
})
}
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ require (
github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0
go.opentelemetry.io/otel v1.29.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.29.0
Expand Down Expand Up @@ -145,6 +146,7 @@ require (
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.6-0.20210604193023-d5e0c0615ace // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/vishvananda/netlink v1.2.1-beta.2.0.20240524165444-4d4ba1473f21 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
9 changes: 7 additions & 2 deletions internal/aws/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ import (
"github.com/aws/aws-sdk-go-v2/service/sts"
)

type LambdaClientAPI interface {
Invoke(ctx context.Context, input *lambda.InvokeInput, optFns ...func(*lambda.Options)) (*lambda.InvokeOutput, error)
GetFunction(ctx context.Context, params *lambda.GetFunctionInput, optFns ...func(*lambda.Options)) (*lambda.GetFunctionOutput, error)
}

type AWSClient struct {
lambdaClient *lambda.Client
lambdaClient LambdaClientAPI
imdsClient *imds.Client
s3Client *s3.Client
cfg aws.Config
Expand Down Expand Up @@ -90,7 +95,7 @@ func GetAWSClient() *AWSClient {
return awsClient
}

func GetLambdaClient() *lambda.Client {
func GetLambdaClient() LambdaClientAPI {
c := GetAWSClient()
if c == nil {
return nil
Expand Down
Loading