diff --git a/cmd/aws/command.go b/cmd/aws/command.go index 64244db5..62044131 100644 --- a/cmd/aws/command.go +++ b/cmd/aws/command.go @@ -43,13 +43,13 @@ var ( supportedDNSProviders = []string{"aws", "cloudflare"} supportedGitProviders = []string{"github", "gitlab"} supportedGitProtocolOverride = []string{"https", "ssh"} - supportedAMITypes = []string{ - "AL2_x86_64", - "AL2_ARM_64", - "BOTTLEROCKET_ARM_64", - "BOTTLEROCKET_x86_64", - "BOTTLEROCKET_ARM_64_NVIDIA", - "BOTTLEROCKET_x86_64_NVIDIA", + supportedAMITypes = map[string]string{ + "AL2_x86_64": "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id", + "AL2_ARM_64": "/aws/service/eks/optimized-ami/1.29/amazon-linux-2-arm64/recommended/image_id", + "BOTTLEROCKET_ARM_64": "/aws/service/bottlerocket/aws-k8s-1.29/arm64/latest/image_id", + "BOTTLEROCKET_x86_64": "/aws/service/bottlerocket/aws-k8s-1.29/x86_64/latest/image_id", + "BOTTLEROCKET_ARM_64_NVIDIA": "/aws/service/bottlerocket/aws-k8s-1.29-nvidia/arm64/latest/image_id", + "BOTTLEROCKET_x86_64_NVIDIA": "/aws/service/bottlerocket/aws-k8s-1.29-nvidia/x86_64/latest/image_id", } ) @@ -108,11 +108,19 @@ func Create() *cobra.Command { createCmd.Flags().BoolVar(&useTelemetryFlag, "use-telemetry", true, "whether to emit telemetry") createCmd.Flags().BoolVar(&ecrFlag, "ecr", false, "whether or not to use ecr vs the git provider") createCmd.Flags().BoolVar(&installKubefirstProFlag, "install-kubefirst-pro", true, "whether or not to install kubefirst pro") - createCmd.Flags().StringVar(&amiType, "ami-type", "AL2_x86_64", fmt.Sprintf("the ami type for node group - one of: %q", supportedAMITypes)) + createCmd.Flags().StringVar(&amiType, "ami-type", "AL2_x86_64", fmt.Sprintf("the ami type for node group - one of: %q", getSupportedAMITypes())) return createCmd } +func getSupportedAMITypes() []string { + var amiTypes []string + for k := range supportedAMITypes { + amiTypes = append(amiTypes, k) + } + return amiTypes +} + func Destroy() *cobra.Command { destroyCmd := &cobra.Command{ Use: "destroy", diff --git a/cmd/aws/create.go b/cmd/aws/create.go index 5892816e..b0b5930a 100644 --- a/cmd/aws/create.go +++ b/cmd/aws/create.go @@ -32,15 +32,6 @@ import ( "github.com/spf13/viper" ) -var ssmTypesID = map[string]string{ - "AL2_x86_64": "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id", - "AL2_ARM_64": "/aws/service/eks/optimized-ami/1.29/amazon-linux-2-arm64/recommended/image_id", - "BOTTLEROCKET_ARM_64": "/aws/service/bottlerocket/aws-k8s-1.29/arm64/latest/image_id", - "BOTTLEROCKET_x86_64": "/aws/service/bottlerocket/aws-k8s-1.29/x86_64/latest/image_id", - "BOTTLEROCKET_ARM_64_NVIDIA": "/aws/service/bottlerocket/aws-k8s-1.29-nvidia/arm64/latest/image_id", - "BOTTLEROCKET_x86_64_NVIDIA": "/aws/service/bottlerocket/aws-k8s-1.29-nvidia/x86_64/latest/image_id", -} - func createAws(cmd *cobra.Command, _ []string) error { cliFlags, err := utilities.GetFlags(cmd, "aws") if err != nil { @@ -172,7 +163,7 @@ func ValidateProvidedFlags(ctx context.Context, cfg aws.Config, gitProvider, ami ec2Client := ec2.NewFromConfig(cfg) paginator := ec2.NewDescribeInstanceTypesPaginator(ec2Client, &ec2.DescribeInstanceTypesInput{}) - if err := ValidateAMIType(ctx, amiType, nodeType, ssmClient, ec2Client, paginator); err != nil { + if err := validateAMIType(ctx, amiType, nodeType, ssmClient, ec2Client, paginator); err != nil { progress.Error(err.Error()) return fmt.Errorf("failed to validte ami type for node group: %w", err) } @@ -192,25 +183,25 @@ func getSessionCredentials(ctx context.Context, cfg aws.CredentialsProvider) (*a return &creds, nil } -func ValidateAMIType(ctx context.Context, amiType, nodeType string, ssmClient ssmClienter, ec2Client ec2Clienter, paginator paginater) error { - ssmParameterName, ok := ssmTypesID[amiType] +func validateAMIType(ctx context.Context, amiType, nodeType string, ssmClient ssmClienter, ec2Client ec2Clienter, paginator paginator) error { + ssmParameterName, ok := supportedAMITypes[amiType] if !ok { return fmt.Errorf("not a valid ami type: %q", amiType) } log.Info().Msgf("ami type is %s", amiType) - amiID, err := GetLatestAMIFromSSM(ctx, ssmClient, ssmParameterName) + amiID, err := getLatestAMIFromSSM(ctx, ssmClient, ssmParameterName) if err != nil { return fmt.Errorf("failed to get AMI ID from SSM: %w", err) } - architecture, err := GetAMIArchitecture(ctx, ec2Client, amiID) + architecture, err := getAMIArchitecture(ctx, ec2Client, amiID) if err != nil { return fmt.Errorf("failed to get AMI architecture: %w", err) } - instanceTypes, err := GetSupportedInstanceTypes(ctx, paginator, architecture) + instanceTypes, err := getSupportedInstanceTypes(ctx, paginator, architecture) if err != nil { return fmt.Errorf("failed to get supported instance types: %w", err) } @@ -230,7 +221,7 @@ type ssmClienter interface { GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) } -func GetLatestAMIFromSSM(ctx context.Context, ssmClient ssmClienter, parameterName string) (string, error) { +func getLatestAMIFromSSM(ctx context.Context, ssmClient ssmClienter, parameterName string) (string, error) { input := &ssm.GetParameterInput{ Name: aws.String(parameterName), } @@ -246,7 +237,7 @@ type ec2Clienter interface { DescribeImages(ctx context.Context, params *ec2.DescribeImagesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) } -func GetAMIArchitecture(ctx context.Context, ec2Client ec2Clienter, amiID string) (string, error) { +func getAMIArchitecture(ctx context.Context, ec2Client ec2Clienter, amiID string) (string, error) { input := &ec2.DescribeImagesInput{ ImageIds: []string{amiID}, } @@ -262,15 +253,15 @@ func GetAMIArchitecture(ctx context.Context, ec2Client ec2Clienter, amiID string return string(output.Images[0].Architecture), nil } -type paginater interface { +type paginator interface { HasMorePages() bool NextPage(ctx context.Context, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) } -func GetSupportedInstanceTypes(ctx context.Context, paginator paginater, architecture string) ([]string, error) { +func getSupportedInstanceTypes(ctx context.Context, p paginator, architecture string) ([]string, error) { var instanceTypes []string - for paginator.HasMorePages() { - page, err := paginator.NextPage(ctx) + for p.HasMorePages() { + page, err := p.NextPage(ctx) if err != nil { return nil, fmt.Errorf("failed to load next pages for instance types: %w", err) } diff --git a/cmd/aws/create_test.go b/cmd/aws/create_test.go index e7d19ce6..080ecc63 100644 --- a/cmd/aws/create_test.go +++ b/cmd/aws/create_test.go @@ -173,7 +173,7 @@ func TestGetLatestAMIFromSSM(t *testing.T) { err: tt.err, } - amiID, err := GetLatestAMIFromSSM(context.Background(), mockSSM, tt.parameterName) + amiID, err := getLatestAMIFromSSM(context.Background(), mockSSM, tt.parameterName) if tt.expectedErr != nil { require.EqualError(t, err, tt.expectedErr.Error()) require.Empty(t, amiID) @@ -231,7 +231,7 @@ func TestGetAMIArchitecture(t *testing.T) { err: tt.err, } - architecture, err := GetAMIArchitecture(context.Background(), mockEC2, tt.amiID) + architecture, err := getAMIArchitecture(context.Background(), mockEC2, tt.amiID) fmt.Printf("arch is %s\n", string(architecture)) if tt.expectedErr != "" { require.EqualError(t, err, tt.expectedErr) @@ -305,7 +305,7 @@ func TestGetSupportedInstanceTypes(t *testing.T) { err: tt.paginateErr, } - got, err := GetSupportedInstanceTypes(context.Background(), paginator, tt.architecture) + got, err := getSupportedInstanceTypes(context.Background(), paginator, tt.architecture) if tt.expectedErr != nil { require.EqualError(t, err, tt.expectedErr.Error()) require.Nil(t, got) @@ -389,7 +389,7 @@ func TestValidateAMIType(t *testing.T) { err: tt.ec2Err, } - err := ValidateAMIType(context.Background(), tt.amiType, tt.nodeType, mockSSM, mockEC2, &mockInstanceTypesPaginator{ + err := validateAMIType(context.Background(), tt.amiType, tt.nodeType, mockSSM, mockEC2, &mockInstanceTypesPaginator{ instanceTypes: []ec2Types.InstanceTypeInfo{ { InstanceType: ec2Types.InstanceTypeT2Micro,