From 0923003489f97d26656f6e73533dde04f111ce35 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Tue, 24 Dec 2024 13:12:13 -0800 Subject: [PATCH] migrate AWS Redshift services to AWS SDK v2 --- integrations/event-handler/go.mod | 1 + integrations/event-handler/go.sum | 2 + integrations/terraform/go.mod | 1 + integrations/terraform/go.sum | 2 + lib/cloud/aws/aws.go | 4 +- lib/cloud/aws/errors.go | 15 +- lib/cloud/aws/tags_helpers.go | 8 +- lib/cloud/awstesthelpers/tags.go | 45 +++ lib/cloud/clients.go | 23 -- lib/cloud/mocks/aws.go | 18 +- lib/cloud/mocks/aws_config.go | 54 ++++ lib/cloud/mocks/aws_redshift.go | 81 +++-- lib/cloud/mocks/aws_sts.go | 81 +++++ lib/kube/proxy/kube_creds_test.go | 2 +- lib/srv/db/access_test.go | 19 +- lib/srv/db/cloud/aws.go | 4 +- lib/srv/db/cloud/iam_test.go | 23 +- lib/srv/db/cloud/meta.go | 49 ++- lib/srv/db/cloud/meta_test.go | 49 +-- lib/srv/db/cloud/resource_checker.go | 6 + lib/srv/db/cloud/resource_checker_url.go | 23 +- lib/srv/db/cloud/resource_checker_url_aws.go | 22 +- .../db/cloud/resource_checker_url_aws_test.go | 39 ++- lib/srv/db/common/auth.go | 74 +++-- lib/srv/db/common/auth_test.go | 75 +++-- lib/srv/db/common/errors.go | 4 + lib/srv/db/server.go | 42 ++- lib/srv/db/watcher.go | 2 +- lib/srv/discovery/common/database.go | 278 +++++++++--------- lib/srv/discovery/common/database_test.go | 22 +- lib/srv/discovery/config_test.go | 2 + lib/srv/discovery/discovery.go | 36 ++- lib/srv/discovery/discovery_test.go | 50 +++- lib/srv/discovery/fetchers/db/aws.go | 28 +- lib/srv/discovery/fetchers/db/aws_redshift.go | 57 ++-- .../fetchers/db/aws_redshift_test.go | 45 +-- lib/srv/discovery/fetchers/db/db.go | 69 ++++- lib/srv/discovery/fetchers/db/helpers_test.go | 29 +- .../kube_integration_watcher_test.go | 8 +- 39 files changed, 923 insertions(+), 469 deletions(-) create mode 100644 lib/cloud/awstesthelpers/tags.go create mode 100644 lib/cloud/mocks/aws_config.go create mode 100644 lib/cloud/mocks/aws_sts.go diff --git a/integrations/event-handler/go.mod b/integrations/event-handler/go.mod index c5f1c71cee047..d306a8e1406ac 100644 --- a/integrations/event-handler/go.mod +++ b/integrations/event-handler/go.mod @@ -87,6 +87,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/kms v1.37.7 // indirect github.com/aws/aws-sdk-go-v2/service/organizations v1.36.0 // indirect github.com/aws/aws-sdk-go-v2/service/rds v1.92.0 // indirect + github.com/aws/aws-sdk-go-v2/service/redshift v1.53.0 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 // indirect github.com/aws/aws-sdk-go-v2/service/ssm v1.56.1 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect diff --git a/integrations/event-handler/go.sum b/integrations/event-handler/go.sum index 7965768d49fdb..04a8235c40249 100644 --- a/integrations/event-handler/go.sum +++ b/integrations/event-handler/go.sum @@ -777,6 +777,8 @@ github.com/aws/aws-sdk-go-v2/service/organizations v1.36.0 h1:CVHfN8ZVvWzDkAf/Qj github.com/aws/aws-sdk-go-v2/service/organizations v1.36.0/go.mod h1:SVY+doFrL3KTvVMWzFLKvD7KYQ6GQfwNRPSQS7eA3cA= github.com/aws/aws-sdk-go-v2/service/rds v1.92.0 h1:W0gUYAjO24u/M6tpR041wMHJWGzleOhxtCnNLImdrZs= github.com/aws/aws-sdk-go-v2/service/rds v1.92.0/go.mod h1:ADD2uROOoEIXjbjDPEvDDZWnGmfKFYMddgKwG5RlBGw= +github.com/aws/aws-sdk-go-v2/service/redshift v1.53.0 h1:4/hmROBioc89sKlMVjHgOaH92zAkrAAMZR3BIvYwyD0= +github.com/aws/aws-sdk-go-v2/service/redshift v1.53.0/go.mod h1:UydVhUJOB/DaCJWiaBkPlvuzvWVcUlgbS2Bxn33bcKI= github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 h1:nyuzXooUNJexRT0Oy0UQY6AhOzxPxhtt4DcBIHyCnmw= github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0/go.mod h1:sT/iQz8JK3u/5gZkT+Hmr7GzVZehUMkRZpOaAwYXeGY= github.com/aws/aws-sdk-go-v2/service/ssm v1.56.1 h1:cfVjoEwOMOJOI6VoRQua0nI0KjZV9EAnR8bKaMeSppE= diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index 23f3eb75782e1..383566c661fe6 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -100,6 +100,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/kms v1.37.7 // indirect github.com/aws/aws-sdk-go-v2/service/organizations v1.36.0 // indirect github.com/aws/aws-sdk-go-v2/service/rds v1.92.0 // indirect + github.com/aws/aws-sdk-go-v2/service/redshift v1.53.0 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 // indirect github.com/aws/aws-sdk-go-v2/service/ssm v1.56.1 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 860c97ee36879..e9b72bc846c80 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -846,6 +846,8 @@ github.com/aws/aws-sdk-go-v2/service/organizations v1.36.0 h1:CVHfN8ZVvWzDkAf/Qj github.com/aws/aws-sdk-go-v2/service/organizations v1.36.0/go.mod h1:SVY+doFrL3KTvVMWzFLKvD7KYQ6GQfwNRPSQS7eA3cA= github.com/aws/aws-sdk-go-v2/service/rds v1.92.0 h1:W0gUYAjO24u/M6tpR041wMHJWGzleOhxtCnNLImdrZs= github.com/aws/aws-sdk-go-v2/service/rds v1.92.0/go.mod h1:ADD2uROOoEIXjbjDPEvDDZWnGmfKFYMddgKwG5RlBGw= +github.com/aws/aws-sdk-go-v2/service/redshift v1.53.0 h1:4/hmROBioc89sKlMVjHgOaH92zAkrAAMZR3BIvYwyD0= +github.com/aws/aws-sdk-go-v2/service/redshift v1.53.0/go.mod h1:UydVhUJOB/DaCJWiaBkPlvuzvWVcUlgbS2Bxn33bcKI= github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 h1:nyuzXooUNJexRT0Oy0UQY6AhOzxPxhtt4DcBIHyCnmw= github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0/go.mod h1:sT/iQz8JK3u/5gZkT+Hmr7GzVZehUMkRZpOaAwYXeGY= github.com/aws/aws-sdk-go-v2/service/sns v1.33.7 h1:N3o8mXK6/MP24BtD9sb51omEO9J9cgPM3Ughc293dZc= diff --git a/lib/cloud/aws/aws.go b/lib/cloud/aws/aws.go index b1923a9bb9c3f..27ea56321b7df 100644 --- a/lib/cloud/aws/aws.go +++ b/lib/cloud/aws/aws.go @@ -22,12 +22,12 @@ import ( "slices" "strings" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/coreos/go-semver/semver" "github.com/gravitational/teleport/lib/services" @@ -244,7 +244,7 @@ func IsDBClusterAvailable(clusterStatus, clusterIndetifier *string) bool { } // IsRedshiftClusterAvailable checks if the Redshift cluster is available. -func IsRedshiftClusterAvailable(cluster *redshift.Cluster) bool { +func IsRedshiftClusterAvailable(cluster *redshifttypes.Cluster) bool { // For a full list of status values, see: // https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-clusters.html#rs-mgmt-cluster-status // diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go index 472590d35b1d4..576e7f4350ce2 100644 --- a/lib/cloud/aws/errors.go +++ b/lib/cloud/aws/errors.go @@ -31,14 +31,23 @@ import ( "github.com/gravitational/trace" ) -// ConvertRequestFailureError converts `error` into AWS RequestFailure errors -// to trace errors. If the provided error is not an `RequestFailure` it returns -// the error without modifying it. +// ConvertRequestFailureError converts `err` into AWS errors to trace errors. +// If the provided error is not a [awserr.RequestFailure] it delegates +// error conversion to [ConvertRequestFailureErrorV2] for SDK v2 compatibility. +// Prefer using [ConvertRequestFailureErrorV2] directly for AWS SDK v2 client +// errors. func ConvertRequestFailureError(err error) error { var requestErr awserr.RequestFailure if errors.As(err, &requestErr) { return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr) } + return ConvertRequestFailureErrorV2(err) +} + +// ConvertRequestFailureErrorV2 converts AWS SDK v2 errors to trace errors. +// If the provided error is not a [awshttp.ResponseError] it returns the error +// without modifying it. +func ConvertRequestFailureErrorV2(err error) error { var re *awshttp.ResponseError if errors.As(err, &re) { return convertRequestFailureErrorFromStatusCode(re.HTTPStatusCode(), re.Err) diff --git a/lib/cloud/aws/tags_helpers.go b/lib/cloud/aws/tags_helpers.go index 27dbe8238f178..3e61bd6fc1a42 100644 --- a/lib/cloud/aws/tags_helpers.go +++ b/lib/cloud/aws/tags_helpers.go @@ -25,13 +25,13 @@ import ( ec2TypesV2 "github.com/aws/aws-sdk-go-v2/service/ec2/types" rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/secretsmanager" "golang.org/x/exp/maps" @@ -45,9 +45,9 @@ type ResourceTag interface { // here and use a type switch for now. rdsTypesV2.Tag | ec2TypesV2.Tag | + redshifttypes.Tag | *ec2.Tag | *rds.Tag | - *redshift.Tag | *elasticache.Tag | *memorydb.Tag | *redshiftserverless.Tag | @@ -80,8 +80,6 @@ func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) { return aws.StringValue(v.Key), aws.StringValue(v.Value) case *ec2.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) - case *redshift.Tag: - return aws.StringValue(v.Key), aws.StringValue(v.Value) case *elasticache.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) case *memorydb.Tag: @@ -92,6 +90,8 @@ func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) { return aws.StringValue(v.Key), aws.StringValue(v.Value) case ec2TypesV2.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) + case redshifttypes.Tag: + return aws.StringValue(v.Key), aws.StringValue(v.Value) case *opensearchservice.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) case *secretsmanager.Tag: diff --git a/lib/cloud/awstesthelpers/tags.go b/lib/cloud/awstesthelpers/tags.go new file mode 100644 index 0000000000000..5e1f4aa0e0738 --- /dev/null +++ b/lib/cloud/awstesthelpers/tags.go @@ -0,0 +1,45 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awstesthelpers + +import ( + "maps" + "slices" + + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" +) + +// LabelsToRedshiftTags converts labels into [redshifttypes.Tag] list. +func LabelsToRedshiftTags(labels map[string]string) []redshifttypes.Tag { + keys := slices.Collect(maps.Keys(labels)) + slices.Sort(keys) + + ret := make([]redshifttypes.Tag, 0, len(keys)) + for _, key := range keys { + key := key + value := labels[key] + + ret = append(ret, redshifttypes.Tag{ + Key: &key, + Value: &value, + }) + } + + return ret +} diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index 54b02d84dc400..99c2deb4001f0 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -53,8 +53,6 @@ import ( "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds/rdsiface" - "github.com/aws/aws-sdk-go/service/redshift" - "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/aws/aws-sdk-go/service/s3" @@ -115,8 +113,6 @@ type AWSClients interface { GetAWSSession(ctx context.Context, region string, opts ...AWSOptionsFn) (*awssession.Session, error) // GetAWSRDSClient returns AWS RDS client for the specified region. GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) - // GetAWSRedshiftClient returns AWS Redshift client for the specified region. - GetAWSRedshiftClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftiface.RedshiftAPI, error) // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) // GetAWSElastiCacheClient returns AWS ElastiCache client for the specified region. @@ -517,15 +513,6 @@ func (c *cloudClients) GetAWSRDSClient(ctx context.Context, region string, opts return rds.New(session), nil } -// GetAWSRedshiftClient returns AWS Redshift client for the specified region. -func (c *cloudClients) GetAWSRedshiftClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftiface.RedshiftAPI, error) { - session, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return redshift.New(session), nil -} - // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. func (c *cloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { session, err := c.GetAWSSession(ctx, region, opts...) @@ -1033,7 +1020,6 @@ var _ Clients = (*TestCloudClients)(nil) type TestCloudClients struct { RDS rdsiface.RDSAPI RDSPerRegion map[string]rdsiface.RDSAPI - Redshift redshiftiface.RedshiftAPI RedshiftServerless redshiftserverlessiface.RedshiftServerlessAPI ElastiCache elasticacheiface.ElastiCacheAPI OpenSearch opensearchserviceiface.OpenSearchServiceAPI @@ -1115,15 +1101,6 @@ func (c *TestCloudClients) GetAWSRDSClient(ctx context.Context, region string, o return c.RDS, nil } -// GetAWSRedshiftClient returns AWS Redshift client for the specified region. -func (c *TestCloudClients) GetAWSRedshiftClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftiface.RedshiftAPI, error) { - _, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return c.Redshift, nil -} - // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. func (c *TestCloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { _, err := c.GetAWSSession(ctx, region, opts...) diff --git a/lib/cloud/mocks/aws.go b/lib/cloud/mocks/aws.go index 016634a9e5529..ceb50bd822cc2 100644 --- a/lib/cloud/mocks/aws.go +++ b/lib/cloud/mocks/aws.go @@ -37,8 +37,8 @@ import ( "github.com/gravitational/trace" ) -// STSMock mocks AWS STS API. -type STSMock struct { +// STSClientV1 mocks AWS STS API for AWS SDK v1. +type STSClientV1 struct { stsiface.STSAPI ARN string URL *url.URL @@ -47,36 +47,36 @@ type STSMock struct { mu sync.Mutex } -func (m *STSMock) GetAssumedRoleARNs() []string { +func (m *STSClientV1) GetAssumedRoleARNs() []string { m.mu.Lock() defer m.mu.Unlock() return m.assumedRoleARNs } -func (m *STSMock) GetAssumedRoleExternalIDs() []string { +func (m *STSClientV1) GetAssumedRoleExternalIDs() []string { m.mu.Lock() defer m.mu.Unlock() return m.assumedRoleExternalIDs } -func (m *STSMock) ResetAssumeRoleHistory() { +func (m *STSClientV1) ResetAssumeRoleHistory() { m.mu.Lock() defer m.mu.Unlock() m.assumedRoleARNs = nil m.assumedRoleExternalIDs = nil } -func (m *STSMock) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) { +func (m *STSClientV1) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) { return &sts.GetCallerIdentityOutput{ Arn: aws.String(m.ARN), }, nil } -func (m *STSMock) AssumeRole(in *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { +func (m *STSClientV1) AssumeRole(in *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { return m.AssumeRoleWithContext(context.Background(), in) } -func (m *STSMock) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput, _ ...request.Option) (*sts.AssumeRoleOutput, error) { +func (m *STSClientV1) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput, _ ...request.Option) (*sts.AssumeRoleOutput, error) { m.mu.Lock() defer m.mu.Unlock() if !slices.Contains(m.assumedRoleARNs, aws.StringValue(in.RoleArn)) { @@ -94,7 +94,7 @@ func (m *STSMock) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput }, nil } -func (m *STSMock) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) { +func (m *STSClientV1) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) { return &request.Request{ HTTPRequest: &http.Request{ Header: http.Header{}, diff --git a/lib/cloud/mocks/aws_config.go b/lib/cloud/mocks/aws_config.go new file mode 100644 index 0000000000000..b78269413b0a1 --- /dev/null +++ b/lib/cloud/mocks/aws_config.go @@ -0,0 +1,54 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mocks + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + + "github.com/gravitational/teleport/lib/cloud/awsconfig" +) + +type AWSConfigProvider struct { + STSClient *STSClient +} + +func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) { + stsClt := f.STSClient + if stsClt == nil { + stsClt = &STSClient{} + } + optFns = append(optFns, awsconfig.WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient { + if cfg.Credentials != nil { + if _, ok := cfg.Credentials.(*stscreds.AssumeRoleProvider); ok { + // Create a new fake client linked to the old one. + // Only do this for AssumeRoleProvider, to avoid attempting to + // load the real credential chain. + return &STSClient{ + credentialProvider: cfg.Credentials, + root: stsClt, + } + } + } + return stsClt + })) + return awsconfig.GetConfig(ctx, region, optFns...) +} diff --git a/lib/cloud/mocks/aws_redshift.go b/lib/cloud/mocks/aws_redshift.go index 1e4855d249fd5..d485d1dcc5a5e 100644 --- a/lib/cloud/mocks/aws_redshift.go +++ b/lib/cloud/mocks/aws_redshift.go @@ -19,74 +19,62 @@ package mocks import ( + "context" "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/redshift" - "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" ) -// RedshiftMock mocks AWS Redshift API. -type RedshiftMock struct { - redshiftiface.RedshiftAPI - Clusters []*redshift.Cluster +type RedshiftClient struct { + Unauth bool + + Clusters []redshifttypes.Cluster GetClusterCredentialsOutput *redshift.GetClusterCredentialsOutput GetClusterCredentialsWithIAMOutput *redshift.GetClusterCredentialsWithIAMOutput } -func (m *RedshiftMock) GetClusterCredentialsWithContext(aws.Context, *redshift.GetClusterCredentialsInput, ...request.Option) (*redshift.GetClusterCredentialsOutput, error) { - if m.GetClusterCredentialsOutput == nil { - return nil, trace.AccessDenied("access denied") +func (m *RedshiftClient) DescribeClusters(_ context.Context, input *redshift.DescribeClustersInput, _ ...func(*redshift.Options)) (*redshift.DescribeClustersOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") } - return m.GetClusterCredentialsOutput, nil -} - -func (m *RedshiftMock) GetClusterCredentialsWithIAMWithContext(aws.Context, *redshift.GetClusterCredentialsWithIAMInput, ...request.Option) (*redshift.GetClusterCredentialsWithIAMOutput, error) { - if m.GetClusterCredentialsWithIAMOutput == nil { - return nil, trace.AccessDenied("access denied") - } - return m.GetClusterCredentialsWithIAMOutput, nil -} -func (m *RedshiftMock) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) { - if aws.StringValue(input.ClusterIdentifier) == "" { + if aws.ToString(input.ClusterIdentifier) == "" { return &redshift.DescribeClustersOutput{ Clusters: m.Clusters, }, nil } for _, cluster := range m.Clusters { - if aws.StringValue(cluster.ClusterIdentifier) == aws.StringValue(input.ClusterIdentifier) { + if aws.ToString(cluster.ClusterIdentifier) == aws.ToString(input.ClusterIdentifier) { return &redshift.DescribeClustersOutput{ - Clusters: []*redshift.Cluster{cluster}, + Clusters: []redshifttypes.Cluster{cluster}, }, nil } } - return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.ClusterIdentifier)) + return nil, trace.NotFound("cluster %v not found", aws.ToString(input.ClusterIdentifier)) } -func (m *RedshiftMock) DescribeClustersPagesWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, fn func(*redshift.DescribeClustersOutput, bool) bool, options ...request.Option) error { - fn(&redshift.DescribeClustersOutput{ - Clusters: m.Clusters, - }, true) - return nil -} - -// RedshiftMockUnauth is a mock Redshift client that returns access denied to each call. -type RedshiftMockUnauth struct { - redshiftiface.RedshiftAPI +func (m *RedshiftClient) GetClusterCredentials(context.Context, *redshift.GetClusterCredentialsInput, ...func(*redshift.Options)) (*redshift.GetClusterCredentialsOutput, error) { + if m.Unauth || m.GetClusterCredentialsOutput == nil { + return nil, trace.AccessDenied("access denied") + } + return m.GetClusterCredentialsOutput, nil } -func (m *RedshiftMockUnauth) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) { - return nil, trace.AccessDenied("unauthorized") +func (m *RedshiftClient) GetClusterCredentialsWithIAM(context.Context, *redshift.GetClusterCredentialsWithIAMInput, ...func(*redshift.Options)) (*redshift.GetClusterCredentialsWithIAMOutput, error) { + if m.Unauth || m.GetClusterCredentialsWithIAMOutput == nil { + return nil, trace.AccessDenied("access denied") + } + return m.GetClusterCredentialsWithIAMOutput, nil } -// RedshiftGetClusterCredentialsOutput return a sample redshift.GetClusterCredentialsOutput. +// RedshiftGetClusterCredentialsOutput return a sample [redshift.GetClusterCredentialsOutput]. func RedshiftGetClusterCredentialsOutput(user, password string, clock clockwork.Clock) *redshift.GetClusterCredentialsOutput { if clock == nil { clock = clockwork.NewRealClock() @@ -99,7 +87,7 @@ func RedshiftGetClusterCredentialsOutput(user, password string, clock clockwork. } // RedshiftGetClusterCredentialsWithIAMOutput return a sample -// redshift.GetClusterCredentialsWithIAMeOutput. +// [redshift.GetClusterCredentialsWithIAMOutput]. func RedshiftGetClusterCredentialsWithIAMOutput(user, password string, clock clockwork.Clock) *redshift.GetClusterCredentialsWithIAMOutput { if clock == nil { clock = clockwork.NewRealClock() @@ -111,20 +99,19 @@ func RedshiftGetClusterCredentialsWithIAMOutput(user, password string, clock clo } } -// RedshiftCluster returns a sample redshift.Cluster. -func RedshiftCluster(name, region string, labels map[string]string, opts ...func(*redshift.Cluster)) *redshift.Cluster { - cluster := &redshift.Cluster{ +func RedshiftCluster(name, region string, labels map[string]string, opts ...func(*redshifttypes.Cluster)) redshifttypes.Cluster { + cluster := redshifttypes.Cluster{ ClusterIdentifier: aws.String(name), ClusterNamespaceArn: aws.String(fmt.Sprintf("arn:aws:redshift:%s:123456789012:namespace:%s", region, name)), ClusterStatus: aws.String("available"), - Endpoint: &redshift.Endpoint{ + Endpoint: &redshifttypes.Endpoint{ Address: aws.String(fmt.Sprintf("%v.aabbccdd.%v.redshift.amazonaws.com", name, region)), - Port: aws.Int64(5439), + Port: aws.Int32(5439), }, - Tags: libcloudaws.LabelsToTags[redshift.Tag](labels), + Tags: awstesthelpers.LabelsToRedshiftTags(labels), } for _, opt := range opts { - opt(cluster) + opt(&cluster) } return cluster } diff --git a/lib/cloud/mocks/aws_sts.go b/lib/cloud/mocks/aws_sts.go new file mode 100644 index 0000000000000..67e180b70f6ac --- /dev/null +++ b/lib/cloud/mocks/aws_sts.go @@ -0,0 +1,81 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mocks + +import ( + "context" + "slices" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/gravitational/trace" +) + +// STSClient fakes AWS SDK v2 STSClient API. +// It also wraps the v1 STSClient mock client, so callers can use it in tests for both +// the v1 and v2 interfaces. +// This is useful when recording assumed roles and some services use v1 +// while others use a v2 STSClient client. +// For example: +// +// f := &STSClient{} +// a.stsClientV1 = &f.STSClientV1 +// b.stsClientV2 = f +// ... +// gotRoles := f.GetAssumedRoleARNs() // returns roles that were assumed with either v1 or v2 client. +type STSClient struct { + STSClientV1 + + root *STSClient + credentialProvider aws.CredentialsProvider +} + +func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + // Every fake client will retrieve its credentials if it has them, and then + // delegate the AssumeRole call to the root faked client. + // In this way, each role in a chain of roles will be assumed and recorded + // by the root fake STS client. + if m.credentialProvider != nil { + _, err := m.credentialProvider.Retrieve(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + } + if m.root != nil { + return m.root.AssumeRole(ctx, in, optFns...) + } + + m.STSClientV1.mu.Lock() + defer m.STSClientV1.mu.Unlock() + if !slices.Contains(m.assumedRoleARNs, aws.ToString(in.RoleArn)) { + m.assumedRoleARNs = append(m.assumedRoleARNs, aws.ToString(in.RoleArn)) + m.assumedRoleExternalIDs = append(m.assumedRoleExternalIDs, aws.ToString(in.ExternalId)) + } + expiry := time.Now().Add(60 * time.Minute) + return &sts.AssumeRoleOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: in.RoleArn, + SecretAccessKey: aws.String("secret"), + SessionToken: aws.String("token"), + Expiration: &expiry, + }, + }, nil +} diff --git a/lib/kube/proxy/kube_creds_test.go b/lib/kube/proxy/kube_creds_test.go index b032964021b73..ca4f1bd4b58e0 100644 --- a/lib/kube/proxy/kube_creds_test.go +++ b/lib/kube/proxy/kube_creds_test.go @@ -105,7 +105,7 @@ func Test_DynamicKubeCreds(t *testing.T) { Host: "sts.amazonaws.com", Path: "/?Action=GetCallerIdentity&Version=2011-06-15", } - sts := &mocks.STSMock{ + sts := &mocks.STSClientV1{ // u is used to presign the request // here we just verify the pre-signed request includes this url. URL: u, diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 8d415fc8953c0..5977ef629e0e8 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -61,6 +61,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" clients "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" @@ -2441,6 +2442,8 @@ type agentParams struct { CADownloader CADownloader // CloudClients is the cloud API clients for database service. CloudClients clients.Clients + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider // AWSMatchers is a list of AWS databases matchers. AWSMatchers []types.AWSMatcher // AzureMatchers is a list of Azure databases matchers. @@ -2481,9 +2484,8 @@ func (p *agentParams) setDefaults(c *testContext) { if p.CloudClients == nil { p.CloudClients = &clients.TestCloudClients{ - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, RDS: &mocks.RDSMock{}, - Redshift: &mocks.RedshiftMock{}, RedshiftServerless: &mocks.RedshiftServerlessMock{}, ElastiCache: p.ElastiCache, MemoryDB: p.MemoryDB, @@ -2492,6 +2494,9 @@ func (p *agentParams) setDefaults(c *testContext) { GCPSQL: p.GCPSQL, } } + if p.AWSConfigProvider == nil { + p.AWSConfigProvider = &mocks.AWSConfigProvider{} + } if p.DiscoveryResourceChecker == nil { p.DiscoveryResourceChecker = &fakeDiscoveryResourceChecker{} @@ -2524,10 +2529,11 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p a // Create test database auth tokens generator. testAuth, err := newTestAuth(common.AuthConfig{ - AuthClient: c.authClient, - AccessPoint: c.authClient, - Clients: &clients.TestCloudClients{}, - Clock: c.clock, + AuthClient: c.authClient, + AccessPoint: c.authClient, + Clients: &clients.TestCloudClients{}, + Clock: c.clock, + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -2596,6 +2602,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p a OnReconcile: p.OnReconcile, ConnectionMonitor: connMonitor, CloudClients: p.CloudClients, + AWSConfigProvider: p.AWSConfigProvider, AWSMatchers: p.AWSMatchers, AzureMatchers: p.AzureMatchers, ShutdownPollPeriod: 100 * time.Millisecond, diff --git a/lib/srv/db/cloud/aws.go b/lib/srv/db/cloud/aws.go index 706581952b9d2..8222599c318a7 100644 --- a/lib/srv/db/cloud/aws.go +++ b/lib/srv/db/cloud/aws.go @@ -75,7 +75,7 @@ func newAWS(ctx context.Context, config awsConfig) (*awsClient, error) { teleport.ComponentKey, "aws", "db", config.database.GetName(), ) - dbConfigurator, err := getDBConfigurator(ctx, logger, config.clients, config.database) + dbConfigurator, err := getDBConfigurator(logger, config.clients, config.database) if err != nil { return nil, trace.Wrap(err) } @@ -102,7 +102,7 @@ type dbIAMAuthConfigurator interface { } // getDBConfigurator returns a database IAM Auth configurator. -func getDBConfigurator(ctx context.Context, logger *slog.Logger, clients cloud.Clients, db types.Database) (dbIAMAuthConfigurator, error) { +func getDBConfigurator(logger *slog.Logger, clients cloud.Clients, db types.Database) (dbIAMAuthConfigurator, error) { if db.IsRDS() { // Only setting for RDS instances and Aurora clusters. return &rdsDBConfigurator{clients: clients, logger: logger}, nil diff --git a/lib/srv/db/cloud/iam_test.go b/lib/srv/db/cloud/iam_test.go index e55c94345fc33..d13d1fc74b86c 100644 --- a/lib/srv/db/cloud/iam_test.go +++ b/lib/srv/db/cloud/iam_test.go @@ -28,7 +28,6 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -59,13 +58,8 @@ func TestAWSIAM(t *testing.T) { DbClusterResourceId: aws.String("cluster-xyz"), } - redshiftCluster := &redshift.Cluster{ - ClusterNamespaceArn: aws.String("arn:aws:redshift:us-east-2:123456789012:namespace:namespace-xyz"), - ClusterIdentifier: aws.String("redshift-cluster-1"), - } - // Configure mocks. - stsClient := &mocks.STSMock{ + stsClient := &mocks.STSClientV1{ ARN: "arn:aws:iam::123456789012:role/test-role", } @@ -74,10 +68,6 @@ func TestAWSIAM(t *testing.T) { DBClusters: []*rds.DBCluster{auroraCluster}, } - redshiftClient := &mocks.RedshiftMock{ - Clusters: []*redshift.Cluster{redshiftCluster}, - } - iamClient := &mocks.IAMMock{} // Setup database resources. @@ -163,10 +153,9 @@ func TestAWSIAM(t *testing.T) { configurator, err := NewIAM(ctx, IAMConfig{ AccessPoint: &mockAccessPoint{}, Clients: &clients.TestCloudClients{ - RDS: rdsClient, - Redshift: redshiftClient, - STS: stsClient, - IAM: iamClient, + RDS: rdsClient, + STS: stsClient, + IAM: iamClient, }, HostID: "host-id", onProcessedTask: func(iamTask, error) { @@ -294,7 +283,7 @@ func TestAWSIAMNoPermissions(t *testing.T) { t.Cleanup(cancel) // Create unauthorized mocks for AWS services. - stsClient := &mocks.STSMock{ + stsClient := &mocks.STSClientV1{ ARN: "arn:aws:iam::123456789012:role/test-role", } // Make configurator. @@ -347,7 +336,6 @@ func TestAWSIAMNoPermissions(t *testing.T) { name: "Redshift cluster", meta: types.AWS{Region: "localhost", AccountID: "123456789012", Redshift: types.Redshift{ClusterID: "redshift-cluster-1"}}, clients: &clients.TestCloudClients{ - Redshift: &mocks.RedshiftMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, @@ -371,7 +359,6 @@ func TestAWSIAMNoPermissions(t *testing.T) { name: "IAM UnmodifiableEntityException", meta: types.AWS{Region: "localhost", AccountID: "123456789012", Redshift: types.Redshift{ClusterID: "redshift-cluster-1"}}, clients: &clients.TestCloudClients{ - Redshift: &mocks.RedshiftMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: awserr.New(iam.ErrCodeUnmodifiableEntityException, "unauthorized", fmt.Errorf("unauthorized")), }, diff --git a/lib/srv/db/cloud/meta.go b/lib/srv/db/cloud/meta.go index 515ff0d83ecbe..031f9fb9dae4c 100644 --- a/lib/srv/db/cloud/meta.go +++ b/lib/srv/db/cloud/meta.go @@ -23,15 +23,15 @@ import ( "log/slog" "strings" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds/rdsiface" - "github.com/aws/aws-sdk-go/service/redshift" - "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" @@ -39,15 +39,30 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/db/common" discoverycommon "github.com/gravitational/teleport/lib/srv/discovery/common" logutils "github.com/gravitational/teleport/lib/utils/log" ) +// redshiftClient defines a subset of the AWS Redshift client API. +type redshiftClient interface { + redshift.DescribeClustersAPIClient +} + +// redshiftClientProviderFunc provides a [redshiftClient]. +type redshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient + // MetadataConfig is the cloud metadata service config. type MetadataConfig struct { // Clients is an interface for retrieving cloud clients. Clients cloud.Clients + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider + + // redshiftClientProviderFn is an internal-only [redshiftClient] provider + // func that is only set in tests. + redshiftClientProviderFn redshiftClientProviderFunc } // Check validates the metadata service config. @@ -59,6 +74,15 @@ func (c *MetadataConfig) Check() error { } c.Clients = cloudClients } + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } + + if c.redshiftClientProviderFn == nil { + c.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { + return redshift.NewFromConfig(cfg, optFns...) + } + } return nil } @@ -177,13 +201,14 @@ func (m *Metadata) fetchRDSProxyMetadata(ctx context.Context, database types.Dat // fetchRedshiftMetadata fetches metadata for the provided Redshift database. func (m *Metadata) fetchRedshiftMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { meta := database.GetAWS() - redshift, err := m.cfg.Clients.GetAWSRedshiftClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return nil, trace.Wrap(err) } + redshift := m.cfg.redshiftClientProviderFn(awsCfg) cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID) if err != nil { return nil, trace.Wrap(err) @@ -296,8 +321,8 @@ func describeRDSCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterI } // describeRedshiftCluster returns AWS Redshift cluster for the specified ID. -func describeRedshiftCluster(ctx context.Context, redshiftClient redshiftiface.RedshiftAPI, clusterID string) (*redshift.Cluster, error) { - out, err := redshiftClient.DescribeClustersWithContext(ctx, &redshift.DescribeClustersInput{ +func describeRedshiftCluster(ctx context.Context, clt redshiftClient, clusterID string) (*redshifttypes.Cluster, error) { + out, err := clt.DescribeClusters(ctx, &redshift.DescribeClustersInput{ ClusterIdentifier: aws.String(clusterID), }) if err != nil { @@ -306,7 +331,7 @@ func describeRedshiftCluster(ctx context.Context, redshiftClient redshiftiface.R if len(out.Clusters) != 1 { return nil, trace.BadParameter("expected 1 Redshift cluster for %v, got %+v", clusterID, out.Clusters) } - return out.Clusters[0], nil + return &out.Clusters[0], nil } // describeElastiCacheCluster returns AWS ElastiCache Redis cluster for the @@ -369,7 +394,7 @@ func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface return nil, trace.Wrap(err) } - rdsProxy, err := describeRDSProxy(ctx, rdsClient, aws.StringValue(rdsProxyEndpoint.DBProxyName)) + rdsProxy, err := describeRDSProxy(ctx, rdsClient, aws.ToString(rdsProxyEndpoint.DBProxyName)) if err != nil { return nil, trace.Wrap(err) } @@ -389,7 +414,7 @@ func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, rdsClient rds for _, customEndpoint := range out.DBProxyEndpoints { // Double check if it has the same URI in case multiple custom // endpoints have the same name. - if strings.Contains(uri, aws.StringValue(customEndpoint.Endpoint)) { + if strings.Contains(uri, aws.ToString(customEndpoint.Endpoint)) { return customEndpoint, nil } } @@ -408,7 +433,7 @@ func fetchRedshiftServerlessVPCEndpointMetadata(ctx context.Context, client reds if err != nil { return nil, trace.Wrap(err) } - workgroup, err := describeRedshiftServerlessWorkgroup(ctx, client, aws.StringValue(endpoint.WorkgroupName)) + workgroup, err := describeRedshiftServerlessWorkgroup(ctx, client, aws.ToString(endpoint.WorkgroupName)) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/cloud/meta_test.go b/lib/srv/db/cloud/meta_test.go index c4eb033360f13..9e66a416a2ebb 100644 --- a/lib/srv/db/cloud/meta_test.go +++ b/lib/srv/db/cloud/meta_test.go @@ -22,11 +22,12 @@ import ( "context" "testing" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" @@ -78,8 +79,8 @@ func TestAWSMetadata(t *testing.T) { } // Configure Redshift API mock. - redshift := &mocks.RedshiftMock{ - Clusters: []*redshift.Cluster{ + redshiftClt := &mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{ { ClusterNamespaceArn: aws.String("arn:aws:redshift:us-west-1:123456789012:namespace:namespace-id"), ClusterIdentifier: aws.String("redshift-cluster-1"), @@ -116,7 +117,7 @@ func TestAWSMetadata(t *testing.T) { }, } - stsMock := &mocks.STSMock{} + fakeSTS := &mocks.STSClient{} // Configure Redshift Serverless API mock. redshiftServerlessWorkgroup := mocks.RedshiftServerlessWorkgroup("my-workgroup", "us-west-1") @@ -130,12 +131,15 @@ func TestAWSMetadata(t *testing.T) { metadata, err := NewMetadata(MetadataConfig{ Clients: &cloud.TestCloudClients{ RDS: rds, - Redshift: redshift, ElastiCache: elasticache, MemoryDB: memorydb, RedshiftServerless: redshiftServerless, - STS: stsMock, + STS: &fakeSTS.STSClientV1, }, + AWSConfigProvider: &mocks.AWSConfigProvider{ + STSClient: fakeSTS, + }, + redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt), }) require.NoError(t, err) @@ -392,9 +396,9 @@ func TestAWSMetadata(t *testing.T) { err = metadata.Update(ctx, database) require.NoError(t, err) require.Equal(t, test.outAWS, database.GetAWS()) - require.Equal(t, []string{test.inAWS.AssumeRoleARN}, stsMock.GetAssumedRoleARNs()) - require.Equal(t, []string{test.inAWS.ExternalID}, stsMock.GetAssumedRoleExternalIDs()) - stsMock.ResetAssumeRoleHistory() + require.Equal(t, []string{test.inAWS.AssumeRoleARN}, fakeSTS.GetAssumedRoleARNs()) + require.Equal(t, []string{test.inAWS.ExternalID}, fakeSTS.GetAssumedRoleExternalIDs()) + fakeSTS.ResetAssumeRoleHistory() }) } } @@ -404,17 +408,20 @@ func TestAWSMetadata(t *testing.T) { func TestAWSMetadataNoPermissions(t *testing.T) { // Create unauthorized mocks. rds := &mocks.RDSMockUnauth{} - redshift := &mocks.RedshiftMockUnauth{} + redshiftClt := &mocks.RedshiftClient{Unauth: true} - stsMock := &mocks.STSMock{} + fakeSTS := &mocks.STSClient{} // Create metadata fetcher. metadata, err := NewMetadata(MetadataConfig{ Clients: &cloud.TestCloudClients{ - RDS: rds, - Redshift: redshift, - STS: stsMock, + RDS: rds, + STS: &fakeSTS.STSClientV1, + }, + AWSConfigProvider: &mocks.AWSConfigProvider{ + STSClient: fakeSTS, }, + redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt), }) require.NoError(t, err) @@ -480,9 +487,15 @@ func TestAWSMetadataNoPermissions(t *testing.T) { err = metadata.Update(ctx, database) require.NoError(t, err) require.Equal(t, test.meta, database.GetAWS()) - require.Equal(t, []string{test.meta.AssumeRoleARN}, stsMock.GetAssumedRoleARNs()) - require.Equal(t, []string{test.meta.ExternalID}, stsMock.GetAssumedRoleExternalIDs()) - stsMock.ResetAssumeRoleHistory() + require.Equal(t, []string{test.meta.AssumeRoleARN}, fakeSTS.GetAssumedRoleARNs()) + require.Equal(t, []string{test.meta.ExternalID}, fakeSTS.GetAssumedRoleExternalIDs()) + fakeSTS.ResetAssumeRoleHistory() }) } } + +func newFakeRedshiftClientProvider(c redshiftClient) redshiftClientProviderFunc { + return func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { + return c + } +} diff --git a/lib/srv/db/cloud/resource_checker.go b/lib/srv/db/cloud/resource_checker.go index 14b5393cebcc9..ea38c69abecf4 100644 --- a/lib/srv/db/cloud/resource_checker.go +++ b/lib/srv/db/cloud/resource_checker.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/services" ) @@ -40,6 +41,8 @@ type DiscoveryResourceChecker interface { // DiscoveryResourceCheckerConfig is the config for DiscoveryResourceChecker. type DiscoveryResourceCheckerConfig struct { + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider // ResourceMatchers is a list of database resource matchers. ResourceMatchers []services.ResourceMatcher // Clients is an interface for retrieving cloud clients. @@ -59,6 +62,9 @@ func (c *DiscoveryResourceCheckerConfig) CheckAndSetDefaults() error { } c.Clients = cloudClients } + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } if c.Context == nil { c.Context = context.Background() } diff --git a/lib/srv/db/cloud/resource_checker_url.go b/lib/srv/db/cloud/resource_checker_url.go index 4ca6ed0b0b081..fdc4efdb65fe9 100644 --- a/lib/srv/db/cloud/resource_checker_url.go +++ b/lib/srv/db/cloud/resource_checker_url.go @@ -27,17 +27,25 @@ import ( "slices" "sync" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" ) // urlChecker validates the database has the correct URL. type urlChecker struct { + // awsConfigProvider provides [aws.Config] for AWS SDK service clients. + awsConfigProvider awsconfig.Provider + // redshiftClientProviderFn is an internal-only [redshiftClient] provider + // func that is only set in tests. + redshiftClientProviderFn redshiftClientProviderFunc + clients cloud.Clients logger *slog.Logger warnOnError bool @@ -52,6 +60,10 @@ type urlChecker struct { func newURLChecker(cfg DiscoveryResourceCheckerConfig) *urlChecker { return &urlChecker{ + awsConfigProvider: cfg.AWSConfigProvider, + redshiftClientProviderFn: func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { + return redshift.NewFromConfig(cfg, optFns...) + }, clients: cfg.Clients, logger: cfg.Logger, warnOnError: getWarnOnError(), @@ -121,8 +133,13 @@ func requireDatabaseIsEndpoint(ctx context.Context, database types.Database, isE return trace.Wrap(convIsEndpoint(isEndpoint)(ctx, database)) } -func requireDatabaseAddressPort(database types.Database, wantURLHost *string, wantURLPort *int64) error { - wantURL := fmt.Sprintf("%v:%v", aws.StringValue(wantURLHost), aws.Int64Value(wantURLPort)) +// TODO(gavin): remove the generic type parameter after all callers are migrated from AWS SDK v1 (uses *int64) to SDK v2 (uses *int32). +func requireDatabaseAddressPort[T ~int32 | ~int64](database types.Database, wantURLHost *string, wantURLPort *T) error { + var port int + if wantURLPort != nil { + port = int(*wantURLPort) + } + wantURL := fmt.Sprintf("%v:%v", aws.ToString(wantURLHost), port) if database.GetURI() != wantURL { return trace.BadParameter("expect database URL %q but got %q for database %q", wantURL, database.GetURI(), database.GetName()) } diff --git a/lib/srv/db/cloud/resource_checker_url_aws.go b/lib/srv/db/cloud/resource_checker_url_aws.go index 87a55c4d26f13..336ee197815fb 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws.go +++ b/lib/srv/db/cloud/resource_checker_url_aws.go @@ -32,6 +32,7 @@ import ( apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/cloud" cloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -60,13 +61,23 @@ func (c *urlChecker) checkAWS(describeCheck, basicEndpointCheck checkDatabaseFun } } +const awsPermissionsErrMsg = "" + + "No permissions to describe AWS resource metadata that is needed for validating databases created by Discovery Service. " + + "Basic AWS endpoint validation will be performed instead. For best security, please provide the Database Service with the proper IAM permissions. " + + "Enable --debug mode to see details on which databases require more IAM permissions. See Database Access documentation for more details." + func (c *urlChecker) logAWSAccessDeniedError(ctx context.Context, database types.Database, accessDeniedError error) { c.warnAWSOnce.Do(func() { // TODO(greedy52) add links to doc. - c.logger.WarnContext(ctx, "No permissions to describe AWS resource metadata that is needed for validating databases created by Discovery Service. Basic AWS endpoint validation will be performed instead. For best security, please provide the Database Service with the proper IAM permissions. Enable --debug mode to see details on which databases require more IAM permissions. See Database Access documentation for more details.") + c.logger.WarnContext(ctx, awsPermissionsErrMsg, + "error", accessDeniedError, + ) }) - c.logger.DebugContext(ctx, "No permissions to describe database for URL validation", "database", database.GetName()) + c.logger.DebugContext(ctx, "No permissions to describe database for URL validation", + "database", database.GetName(), + "error", accessDeniedError, + ) } func (c *urlChecker) checkRDS(ctx context.Context, database types.Database) error { @@ -149,13 +160,14 @@ func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database t func (c *urlChecker) checkRedshift(ctx context.Context, database types.Database) error { meta := database.GetAWS() - redshift, err := c.clients.GetAWSRedshiftClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + redshift := c.redshiftClientProviderFn(awsCfg) cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/cloud/resource_checker_url_aws_test.go b/lib/srv/db/cloud/resource_checker_url_aws_test.go index 81928cbd7902c..e8ba24f624c16 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws_test.go +++ b/lib/srv/db/cloud/resource_checker_url_aws_test.go @@ -22,17 +22,18 @@ import ( "context" "testing" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/srv/discovery/common" "github.com/gravitational/teleport/lib/utils" @@ -69,7 +70,7 @@ func TestURLChecker_AWS(t *testing.T) { // Redshift. redshiftCluster := mocks.RedshiftCluster("redshift-cluster", region, nil) - redshiftClusterDB, err := common.NewDatabaseFromRedshiftCluster(redshiftCluster) + redshiftClusterDB, err := common.NewDatabaseFromRedshiftCluster(&redshiftCluster) require.NoError(t, err) testCases = append(testCases, redshiftClusterDB) @@ -126,9 +127,6 @@ func TestURLChecker_AWS(t *testing.T) { DBProxies: []*rds.DBProxy{rdsProxy}, DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyCustomEndpoint}, }, - Redshift: &mocks.RedshiftMock{ - Clusters: []*redshift.Cluster{redshiftCluster}, - }, RedshiftServerless: &mocks.RedshiftServerlessMock{ Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, Endpoints: []*redshiftserverless.EndpointAccess{redshiftServerlessVPCEndpoint}, @@ -142,41 +140,50 @@ func TestURLChecker_AWS(t *testing.T) { OpenSearch: &mocks.OpenSearchMock{ Domains: []*opensearchservice.DomainStatus{openSearchDomain, openSearchVPCDomain}, }, - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, } mockClientsUnauth := &cloud.TestCloudClients{ RDS: &mocks.RDSMockUnauth{}, - Redshift: &mocks.RedshiftMockUnauth{}, RedshiftServerless: &mocks.RedshiftServerlessMock{Unauth: true}, ElastiCache: &mocks.ElastiCacheMock{Unauth: true}, MemoryDB: &mocks.MemoryDBMock{Unauth: true}, OpenSearch: &mocks.OpenSearchMock{Unauth: true}, - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, } // Test both check methods. // Note that "No permissions" logs should only be printed during the second // group ("basic endpoint check"). methods := []struct { - name string - clients cloud.Clients + name string + clients cloud.Clients + awsConfigProvider awsconfig.Provider + redshiftClient redshiftClient }{ { - name: "API check", - clients: mockClients, + name: "API check", + clients: mockClients, + awsConfigProvider: &mocks.AWSConfigProvider{}, + redshiftClient: &mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{redshiftCluster}, + }, }, { - name: "basic endpoint check", - clients: mockClientsUnauth, + name: "basic endpoint check", + clients: mockClientsUnauth, + awsConfigProvider: &mocks.AWSConfigProvider{}, + redshiftClient: &mocks.RedshiftClient{Unauth: true}, }, } for _, method := range methods { t.Run(method.name, func(t *testing.T) { c := newURLChecker(DiscoveryResourceCheckerConfig{ - Clients: method.clients, - Logger: utils.NewSlogLoggerForTests(), + Clients: method.clients, + AWSConfigProvider: method.awsConfigProvider, + Logger: utils.NewSlogLoggerForTests(), }) + c.redshiftClientProviderFn = newFakeRedshiftClientProvider(method.redshiftClient) for _, database := range testCases { t.Run(database.GetName(), func(t *testing.T) { diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index 43f5df408ab85..e567d82d402e0 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -34,13 +34,13 @@ import ( gcpcredentialspb "cloud.google.com/go/iam/credentials/apiv1/credentialspb" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/rds/rdsutils" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -55,6 +55,7 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/cloud" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" libazure "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/cryptosuites" @@ -124,6 +125,15 @@ type AccessPoint interface { GetAuthPreference(ctx context.Context) (types.AuthPreference, error) } +// redshiftClient defines a subset of the AWS Redshift client API. +type redshiftClient interface { + GetClusterCredentialsWithIAM(context.Context, *redshift.GetClusterCredentialsWithIAMInput, ...func(*redshift.Options)) (*redshift.GetClusterCredentialsWithIAMOutput, error) + GetClusterCredentials(context.Context, *redshift.GetClusterCredentialsInput, ...func(*redshift.Options)) (*redshift.GetClusterCredentialsOutput, error) +} + +// redshiftClientProviderFunc provides a [redshiftClient]. +type redshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient + // AuthConfig is the database access authenticator configuration. type AuthConfig struct { // AuthClient is the cluster auth client. @@ -136,6 +146,13 @@ type AuthConfig struct { Clock clockwork.Clock // Logger is used for logging. Logger *slog.Logger + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider + + // redshiftClientProviderFn is an internal-only [redshiftClient] provider + // func that defaults to a func that provides a real Redshift client. + // The default is only overridden in tests. + redshiftClientProviderFn redshiftClientProviderFunc } // CheckAndSetDefaults validates the config and sets defaults. @@ -149,23 +166,28 @@ func (c *AuthConfig) CheckAndSetDefaults() error { if c.Clients == nil { return trace.BadParameter("missing Clients") } + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } if c.Clock == nil { c.Clock = clockwork.NewRealClock() } if c.Logger == nil { c.Logger = slog.With(teleport.ComponentKey, "db:auth") } + + if c.redshiftClientProviderFn == nil { + c.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { + return redshift.NewFromConfig(cfg, optFns...) + } + } return nil } func (c *AuthConfig) withLogger(getUpdatedLogger func(*slog.Logger) *slog.Logger) AuthConfig { - return AuthConfig{ - AuthClient: c.AuthClient, - AccessPoint: c.AccessPoint, - Clients: c.Clients, - Clock: c.Clock, - Logger: getUpdatedLogger(c.Logger), - } + cfg := *c + cfg.Logger = getUpdatedLogger(c.Logger) + return cfg } // dbAuth provides utilities for creating TLS configurations and @@ -272,18 +294,12 @@ func (a *dbAuth) getRedshiftIAMRoleAuthToken(ctx context.Context, database types return "", "", trace.Wrap(err) } - baseSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), - ) - if err != nil { - return "", "", trace.Wrap(err) - } // Assume the configured AWS role before assuming the role we need to get the // auth token. This allows cross-account AWS access. - client, err := a.cfg.Clients.GetAWSRedshiftClient(ctx, meta.Region, - cloud.WithChainedAssumeRole(baseSession, roleARN, externalIDForChainedAssumeRole(meta)), - cloud.WithAmbientCredentials(), + awsCfg, err := a.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAssumeRole(roleARN, externalIDForChainedAssumeRole(meta)), + awsconfig.WithAmbientCredentials(), ) if err != nil { return "", "", trace.AccessDenied(`Could not generate Redshift IAM role auth token: @@ -300,7 +316,8 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent "database_user", databaseUser, "database_name", databaseName, ) - resp, err := client.GetClusterCredentialsWithIAMWithContext(ctx, &redshift.GetClusterCredentialsWithIAMInput{ + client := a.cfg.redshiftClientProviderFn(awsCfg) + resp, err := client.GetClusterCredentialsWithIAM(ctx, &redshift.GetClusterCredentialsWithIAMInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), DbName: aws.String(databaseName), }) @@ -318,14 +335,14 @@ Make sure that IAM role %q has permissions to generate credentials. Here is a sa %v `, err, roleARN, policy) } - return aws.StringValue(resp.DbUser), aws.StringValue(resp.DbPassword), nil + return aws.ToString(resp.DbUser), aws.ToString(resp.DbPassword), nil } func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { meta := database.GetAWS() - redshiftClient, err := a.cfg.Clients.GetAWSRedshiftClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := a.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return "", "", trace.Wrap(err) @@ -335,7 +352,8 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, database types. "database_user", databaseUser, "database_name", databaseName, ) - resp, err := redshiftClient.GetClusterCredentialsWithContext(ctx, &redshift.GetClusterCredentialsInput{ + clt := a.cfg.redshiftClientProviderFn(awsCfg) + resp, err := clt.GetClusterCredentials(ctx, &redshift.GetClusterCredentialsInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), DbUser: aws.String(databaseUser), DbName: aws.String(databaseName), @@ -344,7 +362,7 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, database types. AutoCreate: aws.Bool(false), // TODO(r0mant): List of additional groups DbUser will join for the // session. Do we need to let people control this? - DbGroups: []*string{}, + DbGroups: []string{}, }) if err != nil { policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database) @@ -362,7 +380,7 @@ propagate): %v `, err, policy) } - return aws.StringValue(resp.DbUser), aws.StringValue(resp.DbPassword), nil + return aws.ToString(resp.DbUser), aws.ToString(resp.DbPassword), nil } // GetRedshiftServerlessAuthToken generates Redshift Serverless auth token. @@ -422,7 +440,7 @@ Make sure that IAM role %q has permissions to generate credentials. Here is a sa %v `, err, roleARN, policy) } - return aws.StringValue(resp.DbUser), aws.StringValue(resp.DbPassword), nil + return aws.ToString(resp.DbUser), aws.ToString(resp.DbPassword), nil } // GetCloudSQLAuthToken returns authorization token that will be used as a diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index c02d8a4984d85..ae136b4d53c46 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -25,11 +25,14 @@ import ( "errors" "fmt" "net/url" + "os" "testing" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -43,8 +46,14 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" ) +func TestMain(m *testing.M) { + utils.InitLoggerForTests() + os.Exit(m.Run()) +} + func TestAuthGetAzureCacheForRedisToken(t *testing.T) { t.Parallel() @@ -59,6 +68,7 @@ func TestAuthGetAzureCacheForRedisToken(t *testing.T) { Token: "azure-redis-enterprise-token", }), }, + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -106,7 +116,7 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) { t.Parallel() // setup mock aws sessions. - stsMock := &mocks.STSMock{} + stsMock := &mocks.STSClientV1{} clock := clockwork.NewFakeClock() auth, err := NewAuth(AuthConfig{ Clock: clock, @@ -118,6 +128,7 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) { GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock), }, }, + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -137,9 +148,10 @@ func TestAuthGetTLSConfig(t *testing.T) { t.Parallel() auth, err := NewAuth(AuthConfig{ - AuthClient: new(authClientMock), - AccessPoint: new(accessPointMock), - Clients: &cloud.TestCloudClients{}, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), + Clients: &cloud.TestCloudClients{}, + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -346,9 +358,10 @@ func TestGetAzureIdentityResourceID(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { auth, err := NewAuth(AuthConfig{ - AuthClient: new(authClientMock), - AccessPoint: new(accessPointMock), - Clients: tc.clients, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), + Clients: tc.clients, + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -379,6 +392,7 @@ func TestGetAzureIdentityResourceIDCache(t *testing.T) { }, AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI(virtualMachinesMock), }, + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -466,7 +480,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { t.Cleanup(cancel) tests := map[string]struct { checkGetAuthFn func(t *testing.T, auth Auth) - checkSTS func(t *testing.T, stsMock *mocks.STSMock) + checkSTS func(t *testing.T, stsMock *mocks.STSClient) }{ "Redshift": { checkGetAuthFn: func(t *testing.T, auth Auth) { @@ -485,7 +499,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) }, - checkSTS: func(t *testing.T, stsMock *mocks.STSMock) { + checkSTS: func(t *testing.T, stsMock *mocks.STSClient) { t.Helper() require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedshiftRole") require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRedshift") @@ -508,9 +522,10 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { require.Equal(t, "IAM:some-role", dbUser) require.Equal(t, "some-password-for-some-role", dbPassword) }, - checkSTS: func(t *testing.T, stsMock *mocks.STSMock) { + checkSTS: func(t *testing.T, stsMock *mocks.STSClient) { t.Helper() require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedshiftRole") + require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/some-role") require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRedshift") }, }, @@ -530,7 +545,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) }, - checkSTS: func(t *testing.T, stsMock *mocks.STSMock) { + checkSTS: func(t *testing.T, stsMock *mocks.STSClient) { t.Helper() require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedshiftServerlessRole") require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRedshiftServerless") @@ -550,7 +565,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { require.NoError(t, err) require.Contains(t, token, "DBUser=some-user") }, - checkSTS: func(t *testing.T, stsMock *mocks.STSMock) { + checkSTS: func(t *testing.T, stsMock *mocks.STSClient) { t.Helper() require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RDSProxyRole") require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRDSProxy") @@ -578,7 +593,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { require.Equal(t, "arn:aws:iam::123456789012:role/RedisRole/20010203/ca-central-1/elasticache/aws4_request", query.Get("X-Amz-Credential")) }, - checkSTS: func(t *testing.T, stsMock *mocks.STSMock) { + checkSTS: func(t *testing.T, stsMock *mocks.STSClient) { t.Helper() require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedisRole") require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalElastiCacheRedis") @@ -586,23 +601,26 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, } - stsMock := &mocks.STSMock{} + fakeSTS := &mocks.STSClient{} clock := clockwork.NewFakeClockAt(time.Date(2001, time.February, 3, 0, 0, 0, 0, time.UTC)) auth, err := NewAuth(AuthConfig{ Clock: clock, AuthClient: new(authClientMock), AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ - STS: stsMock, + STS: &fakeSTS.STSClientV1, RDS: &mocks.RDSMock{}, - Redshift: &mocks.RedshiftMock{ - GetClusterCredentialsOutput: mocks.RedshiftGetClusterCredentialsOutput("IAM:some-user", "some-password", clock), - GetClusterCredentialsWithIAMOutput: mocks.RedshiftGetClusterCredentialsWithIAMOutput("IAM:some-role", "some-password-for-some-role", clock), - }, RedshiftServerless: &mocks.RedshiftServerlessMock{ GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock), }, }, + AWSConfigProvider: &mocks.AWSConfigProvider{ + STSClient: fakeSTS, + }, + redshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ + GetClusterCredentialsOutput: mocks.RedshiftGetClusterCredentialsOutput("IAM:some-user", "some-password", clock), + GetClusterCredentialsWithIAMOutput: mocks.RedshiftGetClusterCredentialsWithIAMOutput("IAM:some-role", "some-password-for-some-role", clock), + }), }) require.NoError(t, err) @@ -611,7 +629,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() tt.checkGetAuthFn(t, auth) - tt.checkSTS(t, stsMock) + tt.checkSTS(t, fakeSTS) }) } } @@ -623,7 +641,7 @@ func TestGetAWSIAMCreds(t *testing.T) { for name, tt := range map[string]struct { db types.Database - stsMock *mocks.STSMock + stsMock *mocks.STSClientV1 username string expectedKeyId string expectedAssumedRoles []string @@ -632,7 +650,7 @@ func TestGetAWSIAMCreds(t *testing.T) { }{ "username is full role ARN": { db: newMongoAtlasDatabase(t, types.AWS{}), - stsMock: &mocks.STSMock{}, + stsMock: &mocks.STSClientV1{}, username: "arn:aws:iam::123456789012:role/role-name", expectedKeyId: "arn:aws:iam::123456789012:role/role-name", expectedAssumedRoles: []string{"arn:aws:iam::123456789012:role/role-name"}, @@ -641,7 +659,7 @@ func TestGetAWSIAMCreds(t *testing.T) { }, "username is partial role ARN": { db: newMongoAtlasDatabase(t, types.AWS{}), - stsMock: &mocks.STSMock{ + stsMock: &mocks.STSClientV1{ // This is the role returned by the STS GetCallerIdentity. ARN: "arn:aws:iam::222222222222:role/teleport-service-role", }, @@ -653,7 +671,7 @@ func TestGetAWSIAMCreds(t *testing.T) { }, "unable to fetch account ID": { db: newMongoAtlasDatabase(t, types.AWS{}), - stsMock: &mocks.STSMock{ + stsMock: &mocks.STSClientV1{ ARN: "", }, username: "role/role-name", @@ -664,7 +682,7 @@ func TestGetAWSIAMCreds(t *testing.T) { ExternalID: "123123", AssumeRoleARN: "arn:aws:iam::222222222222:role/teleport-service-role-external", }), - stsMock: &mocks.STSMock{ + stsMock: &mocks.STSClientV1{ ARN: "arn:aws:iam::111111111111:role/teleport-service-role", }, username: "role/role-name", @@ -685,6 +703,7 @@ func TestGetAWSIAMCreds(t *testing.T) { Clients: &cloud.TestCloudClients{ STS: tt.stsMock, }, + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) @@ -1002,3 +1021,9 @@ func (m *imdsMock) GetID(_ context.Context) (string, error) { func (m *imdsMock) GetType() types.InstanceMetadataType { return m.instanceType } + +func newFakeRedshiftClientProvider(c redshiftClient) redshiftClientProviderFunc { + return func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { + return c + } +} diff --git a/lib/srv/db/common/errors.go b/lib/srv/db/common/errors.go index a86305d13b8ec..7b8456f5e5b35 100644 --- a/lib/srv/db/common/errors.go +++ b/lib/srv/db/common/errors.go @@ -25,6 +25,7 @@ import ( "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/go-mysql-org/go-mysql/mysql" "github.com/gravitational/trace" @@ -66,12 +67,15 @@ func ConvertError(err error) error { var googleAPIErr *googleapi.Error var awsRequestFailureErr awserr.RequestFailure + var awsRequestFailureErrV2 *awshttp.ResponseError var azResponseErr *azcore.ResponseError var pgError *pgconn.PgError var myError *mysql.MyError switch err := trace.Unwrap(err); { case errors.As(err, &googleAPIErr): return convertGCPError(googleAPIErr) + case errors.As(err, &awsRequestFailureErrV2): + return awslib.ConvertRequestFailureErrorV2(awsRequestFailureErrV2) case errors.As(err, &awsRequestFailureErr): return awslib.ConvertRequestFailureError(awsRequestFailureErr) case errors.As(err, &azResponseErr): diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index b91245d6898ef..28fcc486bf4db 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -41,6 +41,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" clients "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/inventory" @@ -68,6 +69,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/spanner" "github.com/gravitational/teleport/lib/srv/db/sqlserver" discoverycommon "github.com/gravitational/teleport/lib/srv/discovery/common" + "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" "github.com/gravitational/teleport/lib/utils" ) @@ -138,6 +140,10 @@ type Config struct { CADownloader CADownloader // CloudClients creates cloud API clients. CloudClients clients.Clients + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider + // AWSDatabaseFetcherFactory provides AWS database fetchers + AWSDatabaseFetcherFactory *db.AWSFetcherFactory // CloudMeta fetches cloud metadata for cloud hosted databases. CloudMeta *cloud.Metadata // CloudIAM configures IAM for cloud hosted databases. @@ -192,12 +198,30 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { } c.CloudClients = cloudClients } + if c.AWSConfigProvider == nil { + provider, err := awsconfig.NewCache() + if err != nil { + return trace.Wrap(err, "unable to create AWS config provider cache") + } + c.AWSConfigProvider = provider + } + if c.AWSDatabaseFetcherFactory == nil { + factory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + CloudClients: c.CloudClients, + AWSConfigProvider: c.AWSConfigProvider, + }) + if err != nil { + return trace.Wrap(err) + } + c.AWSDatabaseFetcherFactory = factory + } if c.Auth == nil { c.Auth, err = common.NewAuth(common.AuthConfig{ - AuthClient: c.AuthClient, - AccessPoint: c.AccessPoint, - Clock: c.Clock, - Clients: c.CloudClients, + AuthClient: c.AuthClient, + AccessPoint: c.AccessPoint, + Clock: c.Clock, + Clients: c.CloudClients, + AWSConfigProvider: c.AWSConfigProvider, }) if err != nil { return trace.Wrap(err) @@ -226,7 +250,8 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { } if c.CloudMeta == nil { c.CloudMeta, err = cloud.NewMetadata(cloud.MetadataConfig{ - Clients: c.CloudClients, + Clients: c.CloudClients, + AWSConfigProvider: c.AWSConfigProvider, }) if err != nil { return trace.Wrap(err) @@ -282,9 +307,10 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { if c.discoveryResourceChecker == nil { c.discoveryResourceChecker, err = cloud.NewDiscoveryResourceChecker(cloud.DiscoveryResourceCheckerConfig{ - ResourceMatchers: c.ResourceMatchers, - Clients: c.CloudClients, - Context: ctx, + ResourceMatchers: c.ResourceMatchers, + Clients: c.CloudClients, + AWSConfigProvider: c.AWSConfigProvider, + Context: ctx, }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/watcher.go b/lib/srv/db/watcher.go index f4d7fe86d99ec..2dc1dcb11d35c 100644 --- a/lib/srv/db/watcher.go +++ b/lib/srv/db/watcher.go @@ -110,7 +110,7 @@ func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWat // startCloudWatcher starts fetching cloud databases according to the // selectors and register/unregister them appropriately. func (s *Server) startCloudWatcher(ctx context.Context) error { - awsFetchers, err := dbfetchers.MakeAWSFetchers(ctx, s.cfg.CloudClients, s.cfg.AWSMatchers, "" /* discovery config */) + awsFetchers, err := s.cfg.AWSDatabaseFetcherFactory.MakeFetchers(ctx, s.cfg.AWSMatchers, "" /* discovery config */) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/discovery/common/database.go b/lib/srv/discovery/common/database.go index 237c72f2fc76d..8afe335f87fcb 100644 --- a/lib/srv/discovery/common/database.go +++ b/lib/srv/discovery/common/database.go @@ -28,14 +28,14 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql" + "github.com/aws/aws-sdk-go-v2/aws" rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" - "github.com/aws/aws-sdk-go/aws" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" @@ -295,7 +295,7 @@ func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error if err != nil { return nil, trace.Wrap(err) } - protocol, err := rdsEngineToProtocol(aws.StringValue(instance.Engine)) + protocol, err := rdsEngineToProtocol(aws.ToString(instance.Engine)) if err != nil { return nil, trace.Wrap(err) } @@ -304,10 +304,10 @@ func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error setAWSDBName(types.Metadata{ Description: fmt.Sprintf("RDS instance in %v", metadata.Region), Labels: labelsFromRDSInstance(instance, metadata), - }, aws.StringValue(instance.DBInstanceIdentifier)), + }, aws.ToString(instance.DBInstanceIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.StringValue(endpoint.Address), aws.Int64Value(endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt64(endpoint.Port)), AWS: *metadata, }) } @@ -323,7 +323,7 @@ func NewDatabaseFromRDSV2Instance(instance *rdstypes.DBInstance) (types.Database if err != nil { return nil, trace.Wrap(err) } - protocol, err := rdsEngineToProtocol(aws.StringValue(instance.Engine)) + protocol, err := rdsEngineToProtocol(aws.ToString(instance.Engine)) if err != nil { return nil, trace.Wrap(err) } @@ -331,9 +331,9 @@ func NewDatabaseFromRDSV2Instance(instance *rdstypes.DBInstance) (types.Database uri := "" if instance.Endpoint != nil && instance.Endpoint.Address != nil { if instance.Endpoint.Port != nil { - uri = fmt.Sprintf("%s:%d", aws.StringValue(instance.Endpoint.Address), *instance.Endpoint.Port) + uri = fmt.Sprintf("%s:%d", aws.ToString(instance.Endpoint.Address), *instance.Endpoint.Port) } else { - uri = aws.StringValue(instance.Endpoint.Address) + uri = aws.ToString(instance.Endpoint.Address) } } @@ -341,7 +341,7 @@ func NewDatabaseFromRDSV2Instance(instance *rdstypes.DBInstance) (types.Database setAWSDBName(types.Metadata{ Description: fmt.Sprintf("RDS instance in %v", metadata.Region), Labels: labelsFromRDSV2Instance(instance, metadata), - }, aws.StringValue(instance.DBInstanceIdentifier)), + }, aws.ToString(instance.DBInstanceIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, URI: uri, @@ -352,7 +352,7 @@ func NewDatabaseFromRDSV2Instance(instance *rdstypes.DBInstance) (types.Database // MetadataFromRDSInstance creates AWS metadata from the provided RDS instance. // It uses aws sdk v2. func MetadataFromRDSV2Instance(rdsInstance *rdstypes.DBInstance) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(rdsInstance.DBInstanceArn)) + parsedARN, err := arn.Parse(aws.ToString(rdsInstance.DBInstanceArn)) if err != nil { return nil, trace.Wrap(err) } @@ -363,10 +363,10 @@ func MetadataFromRDSV2Instance(rdsInstance *rdstypes.DBInstance) (*types.AWS, er Region: parsedARN.Region, AccountID: parsedARN.AccountID, RDS: types.RDS{ - InstanceID: aws.StringValue(rdsInstance.DBInstanceIdentifier), - ClusterID: aws.StringValue(rdsInstance.DBClusterIdentifier), - ResourceID: aws.StringValue(rdsInstance.DbiResourceId), - IAMAuth: aws.BoolValue(rdsInstance.IAMDatabaseAuthenticationEnabled), + InstanceID: aws.ToString(rdsInstance.DBInstanceIdentifier), + ClusterID: aws.ToString(rdsInstance.DBClusterIdentifier), + ResourceID: aws.ToString(rdsInstance.DbiResourceId), + IAMAuth: aws.ToBool(rdsInstance.IAMDatabaseAuthenticationEnabled), Subnets: subnets, VPCID: vpcID, SecurityGroups: rdsSecurityGroupInfo(rdsInstance.VpcSecurityGroups), @@ -378,12 +378,12 @@ func MetadataFromRDSV2Instance(rdsInstance *rdstypes.DBInstance) (*types.AWS, er // It uses aws sdk v2. func labelsFromRDSV2Instance(rdsInstance *rdstypes.DBInstance, meta *types.AWS) map[string]string { labels := labelsFromAWSMetadata(meta) - labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine) - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion) + labels[types.DiscoveryLabelEngine] = aws.ToString(rdsInstance.Engine) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsInstance.EngineVersion) labels[types.DiscoveryLabelEndpointType] = apiawsutils.RDSEndpointTypeInstance - labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsInstance.DBInstanceStatus) + labels[types.DiscoveryLabelStatus] = aws.ToString(rdsInstance.DBInstanceStatus) if rdsInstance.DBSubnetGroup != nil { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId) + labels[types.DiscoveryLabelVPCID] = aws.ToString(rdsInstance.DBSubnetGroup.VpcId) } return addLabels(labels, libcloudaws.TagsToLabels(rdsInstance.TagList)) } @@ -395,20 +395,20 @@ func NewDatabaseFromRDSV2Cluster(cluster *rdstypes.DBCluster, firstInstance *rds if err != nil { return nil, trace.Wrap(err) } - protocol, err := rdsEngineToProtocol(aws.StringValue(cluster.Engine)) + protocol, err := rdsEngineToProtocol(aws.ToString(cluster.Engine)) if err != nil { return nil, trace.Wrap(err) } uri := "" if cluster.Endpoint != nil && cluster.Port != nil { - uri = fmt.Sprintf("%v:%v", aws.StringValue(cluster.Endpoint), *cluster.Port) + uri = fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), *cluster.Port) } return types.NewDatabaseV3( setAWSDBName(types.Metadata{ Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region), Labels: labelsFromRDSV2Cluster(cluster, metadata, apiawsutils.RDSEndpointTypePrimary, firstInstance), - }, aws.StringValue(cluster.DBClusterIdentifier)), + }, aws.ToString(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, URI: uri, @@ -421,10 +421,10 @@ func rdsSubnetGroupToNetworkInfo(subnetGroup *rdstypes.DBSubnetGroup) (vpcID str return } - vpcID = aws.StringValue(subnetGroup.VpcId) + vpcID = aws.ToString(subnetGroup.VpcId) subnets = make([]string, 0, len(subnetGroup.Subnets)) for _, s := range subnetGroup.Subnets { - subnetID := aws.StringValue(s.SubnetIdentifier) + subnetID := aws.ToString(s.SubnetIdentifier) if subnetID != "" { subnets = append(subnets, subnetID) } @@ -439,7 +439,7 @@ func rdsSecurityGroupInfo(memberships []rdstypes.VpcSecurityGroupMembership) []s secGroups = make([]string, 0, len(memberships)) } for _, group := range memberships { - groupID := aws.StringValue(group.VpcSecurityGroupId) + groupID := aws.ToString(group.VpcSecurityGroupId) if groupID != "" { secGroups = append(secGroups, groupID) } @@ -451,7 +451,7 @@ func rdsSecurityGroupInfo(memberships []rdstypes.VpcSecurityGroupMembership) []s // It uses aws sdk v2. // An optional [rdstypes.DBInstance] can be passed to fill the network configuration of the Cluster. func MetadataFromRDSV2Cluster(rdsCluster *rdstypes.DBCluster, rdsInstance *rdstypes.DBInstance) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(rdsCluster.DBClusterArn)) + parsedARN, err := arn.Parse(aws.ToString(rdsCluster.DBClusterArn)) if err != nil { return nil, trace.Wrap(err) } @@ -467,9 +467,9 @@ func MetadataFromRDSV2Cluster(rdsCluster *rdstypes.DBCluster, rdsInstance *rdsty Region: parsedARN.Region, AccountID: parsedARN.AccountID, RDS: types.RDS{ - ClusterID: aws.StringValue(rdsCluster.DBClusterIdentifier), - ResourceID: aws.StringValue(rdsCluster.DbClusterResourceId), - IAMAuth: aws.BoolValue(rdsCluster.IAMDatabaseAuthenticationEnabled), + ClusterID: aws.ToString(rdsCluster.DBClusterIdentifier), + ResourceID: aws.ToString(rdsCluster.DbClusterResourceId), + IAMAuth: aws.ToBool(rdsCluster.IAMDatabaseAuthenticationEnabled), Subnets: subnets, VPCID: vpcID, SecurityGroups: rdsSecurityGroupInfo(rdsCluster.VpcSecurityGroups), @@ -481,12 +481,12 @@ func MetadataFromRDSV2Cluster(rdsCluster *rdstypes.DBCluster, rdsInstance *rdsty // It uses aws sdk v2. func labelsFromRDSV2Cluster(rdsCluster *rdstypes.DBCluster, meta *types.AWS, endpointType string, memberInstance *rdstypes.DBInstance) map[string]string { labels := labelsFromAWSMetadata(meta) - labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine) - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion) + labels[types.DiscoveryLabelEngine] = aws.ToString(rdsCluster.Engine) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsCluster.EngineVersion) labels[types.DiscoveryLabelEndpointType] = endpointType - labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsCluster.Status) + labels[types.DiscoveryLabelStatus] = aws.ToString(rdsCluster.Status) if memberInstance != nil && memberInstance.DBSubnetGroup != nil { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstance.DBSubnetGroup.VpcId) + labels[types.DiscoveryLabelVPCID] = aws.ToString(memberInstance.DBSubnetGroup.VpcId) } return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList)) } @@ -497,7 +497,7 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DB if err != nil { return nil, trace.Wrap(err) } - protocol, err := rdsEngineToProtocol(aws.StringValue(cluster.Engine)) + protocol, err := rdsEngineToProtocol(aws.ToString(cluster.Engine)) if err != nil { return nil, trace.Wrap(err) } @@ -505,10 +505,10 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DB setAWSDBName(types.Metadata{ Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region), Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypePrimary, memberInstances), - }, aws.StringValue(cluster.DBClusterIdentifier)), + }, aws.ToString(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.Endpoint), aws.Int64Value(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)), AWS: *metadata, }) } @@ -519,7 +519,7 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInsta if err != nil { return nil, trace.Wrap(err) } - protocol, err := rdsEngineToProtocol(aws.StringValue(cluster.Engine)) + protocol, err := rdsEngineToProtocol(aws.ToString(cluster.Engine)) if err != nil { return nil, trace.Wrap(err) } @@ -527,10 +527,10 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInsta setAWSDBName(types.Metadata{ Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, apiawsutils.RDSEndpointTypeReader), Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypeReader, memberInstances), - }, aws.StringValue(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeReader), + }, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeReader), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.ReaderEndpoint), aws.Int64Value(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)), AWS: *metadata, }) } @@ -541,7 +541,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns if err != nil { return nil, trace.Wrap(err) } - protocol, err := rdsEngineToProtocol(aws.StringValue(cluster.Engine)) + protocol, err := rdsEngineToProtocol(aws.ToString(cluster.Engine)) if err != nil { return nil, trace.Wrap(err) } @@ -551,7 +551,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns for _, endpoint := range cluster.CustomEndpoints { // RDS custom endpoint format: // .cluster-custom-. - endpointDetails, err := apiawsutils.ParseRDSEndpoint(aws.StringValue(endpoint)) + endpointDetails, err := apiawsutils.ParseRDSEndpoint(aws.ToString(endpoint)) if err != nil { errors = append(errors, trace.Wrap(err)) continue @@ -565,16 +565,16 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns setAWSDBName(types.Metadata{ Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, apiawsutils.RDSEndpointTypeCustom), Labels: labelsFromRDSCluster(cluster, metadata, apiawsutils.RDSEndpointTypeCustom, memberInstances), - }, aws.StringValue(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeCustom, endpointDetails.ClusterCustomEndpointName), + }, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeCustom, endpointDetails.ClusterCustomEndpointName), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.StringValue(endpoint), aws.Int64Value(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint), aws.ToInt64(cluster.Port)), AWS: *metadata, // Aurora instances update their certificates upon restart, and thus custom endpoint SAN may not be available right // away. Using primary endpoint instead as server name since it's always available. TLS: types.DatabaseTLS{ - ServerName: aws.StringValue(cluster.Endpoint), + ServerName: aws.ToString(cluster.Endpoint), }, }) if err != nil { @@ -591,7 +591,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns func checkRDSClusterMembers(cluster *rds.DBCluster) (hasWriterInstance, hasReaderInstance bool) { for _, clusterMember := range cluster.DBClusterMembers { if clusterMember != nil { - if aws.BoolValue(clusterMember.IsClusterWriter) { + if aws.ToBool(clusterMember.IsClusterWriter) { hasWriterInstance = true } else { hasReaderInstance = true @@ -692,10 +692,10 @@ func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rds.DBCluster) (types.Dat setAWSDBName(types.Metadata{ Description: fmt.Sprintf("DocumentDB cluster in %v", metadata.Region), Labels: labelsFromDocumentDBCluster(cluster, metadata, endpointType), - }, aws.StringValue(cluster.DBClusterIdentifier)), + }, aws.ToString(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: types.DatabaseProtocolMongoDB, - URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.Endpoint), aws.Int64Value(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)), AWS: *metadata, }) } @@ -712,10 +712,10 @@ func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rds.DBCluster) (types.Data setAWSDBName(types.Metadata{ Description: fmt.Sprintf("DocumentDB cluster in %v (%v endpoint)", metadata.Region, endpointType), Labels: labelsFromDocumentDBCluster(cluster, metadata, endpointType), - }, aws.StringValue(cluster.DBClusterIdentifier), endpointType), + }, aws.ToString(cluster.DBClusterIdentifier), endpointType), types.DatabaseSpecV3{ Protocol: types.DatabaseProtocolMongoDB, - URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.ReaderEndpoint), aws.Int64Value(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)), AWS: *metadata, }) } @@ -726,7 +726,7 @@ func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Datab if err != nil { return nil, trace.Wrap(err) } - protocol, port, err := rdsEngineFamilyToProtocolAndPort(aws.StringValue(dbProxy.EngineFamily)) + protocol, port, err := rdsEngineFamilyToProtocolAndPort(aws.ToString(dbProxy.EngineFamily)) if err != nil { return nil, trace.Wrap(err) } @@ -734,10 +734,10 @@ func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Datab setAWSDBName(types.Metadata{ Description: fmt.Sprintf("RDS Proxy in %v", metadata.Region), Labels: labelsFromRDSProxy(dbProxy, metadata, tags), - }, aws.StringValue(dbProxy.DBProxyName)), + }, aws.ToString(dbProxy.DBProxyName)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%s:%d", aws.StringValue(dbProxy.Endpoint), port), + URI: fmt.Sprintf("%s:%d", aws.ToString(dbProxy.Endpoint), port), AWS: *metadata, }) } @@ -749,7 +749,7 @@ func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rds.DBProxy, customEndpoint if err != nil { return nil, trace.Wrap(err) } - protocol, port, err := rdsEngineFamilyToProtocolAndPort(aws.StringValue(dbProxy.EngineFamily)) + protocol, port, err := rdsEngineFamilyToProtocolAndPort(aws.ToString(dbProxy.EngineFamily)) if err != nil { return nil, trace.Wrap(err) } @@ -757,10 +757,10 @@ func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rds.DBProxy, customEndpoint setAWSDBName(types.Metadata{ Description: fmt.Sprintf("RDS Proxy endpoint in %v", metadata.Region), Labels: labelsFromRDSProxyCustomEndpoint(dbProxy, customEndpoint, metadata, tags), - }, aws.StringValue(dbProxy.DBProxyName), aws.StringValue(customEndpoint.DBProxyEndpointName)), + }, aws.ToString(dbProxy.DBProxyName), aws.ToString(customEndpoint.DBProxyEndpointName)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%s:%d", aws.StringValue(customEndpoint.Endpoint), port), + URI: fmt.Sprintf("%s:%d", aws.ToString(customEndpoint.Endpoint), port), AWS: *metadata, // RDS proxies serve wildcard certificates like this: @@ -773,17 +773,17 @@ func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rds.DBProxy, customEndpoint // Using proxy's default endpoint as server name as it should always // succeed. TLS: types.DatabaseTLS{ - ServerName: aws.StringValue(dbProxy.Endpoint), + ServerName: aws.ToString(dbProxy.Endpoint), }, }) } // NewDatabaseFromRedshiftCluster creates a database resource from a Redshift cluster. -func NewDatabaseFromRedshiftCluster(cluster *redshift.Cluster) (types.Database, error) { +func NewDatabaseFromRedshiftCluster(cluster *redshifttypes.Cluster) (types.Database, error) { // Endpoint can be nil while the cluster is being created. Return an error // until the Endpoint is available. if cluster.Endpoint == nil { - return nil, trace.BadParameter("missing endpoint in Redshift cluster %v", aws.StringValue(cluster.ClusterIdentifier)) + return nil, trace.BadParameter("missing endpoint in Redshift cluster %v", aws.ToString(cluster.ClusterIdentifier)) } metadata, err := MetadataFromRedshiftCluster(cluster) @@ -795,10 +795,10 @@ func NewDatabaseFromRedshiftCluster(cluster *redshift.Cluster) (types.Database, setAWSDBName(types.Metadata{ Description: fmt.Sprintf("Redshift cluster in %v", metadata.Region), Labels: labelsFromRedshiftCluster(cluster, metadata), - }, aws.StringValue(cluster.ClusterIdentifier)), + }, aws.ToString(cluster.ClusterIdentifier)), types.DatabaseSpecV3{ Protocol: defaults.ProtocolPostgres, - URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.Endpoint.Address), aws.Int64Value(cluster.Endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint.Address), aws.ToInt32(cluster.Endpoint.Port)), AWS: *metadata, }) } @@ -842,7 +842,7 @@ func NewDatabasesFromElastiCacheNodeGroups(cluster *elasticache.ReplicationGroup func NewDatabasesFromElastiCacheReplicationGroup(cluster *elasticache.ReplicationGroup, extraLabels map[string]string) (types.Databases, error) { // Create database using configuration endpoint for Redis with cluster // mode enabled. - if aws.BoolValue(cluster.ClusterEnabled) { + if aws.ToBool(cluster.ClusterEnabled) { database, err := NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels) if err != nil { return nil, trace.Wrap(err) @@ -873,9 +873,9 @@ func newElastiCacheDatabase(cluster *elasticache.ReplicationGroup, endpoint *ela return types.NewDatabaseV3(setAWSDBName(types.Metadata{ Description: fmt.Sprintf("ElastiCache cluster in %v (%v endpoint)", metadata.Region, endpointType), Labels: labelsFromMetaAndEndpointType(metadata, endpointType, extraLabels), - }, aws.StringValue(cluster.ReplicationGroupId), suffix...), types.DatabaseSpecV3{ + }, aws.ToString(cluster.ReplicationGroupId), suffix...), types.DatabaseSpecV3{ Protocol: defaults.ProtocolRedis, - URI: fmt.Sprintf("%v:%v", aws.StringValue(endpoint.Address), aws.Int64Value(endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt64(endpoint.Port)), AWS: *metadata, }) } @@ -884,7 +884,7 @@ func newElastiCacheDatabase(cluster *elasticache.ReplicationGroup, endpoint *ela func NewDatabasesFromOpenSearchDomain(domain *opensearchservice.DomainStatus, tags []*opensearchservice.Tag) (types.Databases, error) { var databases types.Databases - if aws.StringValue(domain.Endpoint) != "" { + if aws.ToString(domain.Endpoint) != "" { metadata, err := MetadataFromOpenSearchDomain(domain, apiawsutils.OpenSearchDefaultEndpoint) if err != nil { return nil, trace.Wrap(err) @@ -895,10 +895,10 @@ func NewDatabasesFromOpenSearchDomain(domain *opensearchservice.DomainStatus, ta Labels: labelsFromOpenSearchDomain(domain, metadata, apiawsutils.OpenSearchDefaultEndpoint, tags), } - meta = setAWSDBName(meta, aws.StringValue(domain.DomainName)) + meta = setAWSDBName(meta, aws.ToString(domain.DomainName)) spec := types.DatabaseSpecV3{ Protocol: defaults.ProtocolOpenSearch, - URI: fmt.Sprintf("%v:443", aws.StringValue(domain.Endpoint)), + URI: fmt.Sprintf("%v:443", aws.ToString(domain.Endpoint)), AWS: *metadata, } @@ -910,7 +910,7 @@ func NewDatabasesFromOpenSearchDomain(domain *opensearchservice.DomainStatus, ta databases = append(databases, db) } - if domain.DomainEndpointOptions != nil && aws.StringValue(domain.DomainEndpointOptions.CustomEndpoint) != "" { + if domain.DomainEndpointOptions != nil && aws.ToString(domain.DomainEndpointOptions.CustomEndpoint) != "" { metadata, err := MetadataFromOpenSearchDomain(domain, apiawsutils.OpenSearchCustomEndpoint) if err != nil { return nil, trace.Wrap(err) @@ -921,10 +921,10 @@ func NewDatabasesFromOpenSearchDomain(domain *opensearchservice.DomainStatus, ta Labels: labelsFromOpenSearchDomain(domain, metadata, apiawsutils.OpenSearchCustomEndpoint, tags), } - meta = setAWSDBName(meta, aws.StringValue(domain.DomainName), "custom") + meta = setAWSDBName(meta, aws.ToString(domain.DomainName), "custom") spec := types.DatabaseSpecV3{ Protocol: defaults.ProtocolOpenSearch, - URI: fmt.Sprintf("%v:443", aws.StringValue(domain.DomainEndpointOptions.CustomEndpoint)), + URI: fmt.Sprintf("%v:443", aws.ToString(domain.DomainEndpointOptions.CustomEndpoint)), AWS: *metadata, } @@ -948,13 +948,13 @@ func NewDatabasesFromOpenSearchDomain(domain *opensearchservice.DomainStatus, ta } if domain.VPCOptions != nil { - meta.Labels[types.DiscoveryLabelVPCID] = aws.StringValue(domain.VPCOptions.VPCId) + meta.Labels[types.DiscoveryLabelVPCID] = aws.ToString(domain.VPCOptions.VPCId) } - meta = setAWSDBName(meta, aws.StringValue(domain.DomainName), name) + meta = setAWSDBName(meta, aws.ToString(domain.DomainName), name) spec := types.DatabaseSpecV3{ Protocol: defaults.ProtocolOpenSearch, - URI: fmt.Sprintf("%v:443", aws.StringValue(url)), + URI: fmt.Sprintf("%v:443", aws.ToString(url)), AWS: *metadata, } @@ -983,10 +983,10 @@ func NewDatabaseFromMemoryDBCluster(cluster *memorydb.Cluster, extraLabels map[s setAWSDBName(types.Metadata{ Description: fmt.Sprintf("MemoryDB cluster in %v", metadata.Region), Labels: labelsFromMetaAndEndpointType(metadata, endpointType, extraLabels), - }, aws.StringValue(cluster.Name)), + }, aws.ToString(cluster.Name)), types.DatabaseSpecV3{ Protocol: defaults.ProtocolRedis, - URI: fmt.Sprintf("%v:%v", aws.StringValue(cluster.ClusterEndpoint.Address), aws.Int64Value(cluster.ClusterEndpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ClusterEndpoint.Address), aws.ToInt64(cluster.ClusterEndpoint.Port)), AWS: *metadata, }) } @@ -1010,7 +1010,7 @@ func NewDatabaseFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Wo }, metadata.RedshiftServerless.WorkgroupName), types.DatabaseSpecV3{ Protocol: defaults.ProtocolPostgres, - URI: fmt.Sprintf("%v:%v", aws.StringValue(workgroup.Endpoint.Address), aws.Int64Value(workgroup.Endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(workgroup.Endpoint.Address), aws.ToInt64(workgroup.Endpoint.Port)), AWS: *metadata, }) } @@ -1034,19 +1034,19 @@ func NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.E }, metadata.RedshiftServerless.WorkgroupName, metadata.RedshiftServerless.EndpointName), types.DatabaseSpecV3{ Protocol: defaults.ProtocolPostgres, - URI: fmt.Sprintf("%v:%v", aws.StringValue(endpoint.Address), aws.Int64Value(endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt64(endpoint.Port)), AWS: *metadata, // Use workgroup's default address as the server name. TLS: types.DatabaseTLS{ - ServerName: aws.StringValue(workgroup.Endpoint.Address), + ServerName: aws.ToString(workgroup.Endpoint.Address), }, }) } // MetadataFromRDSInstance creates AWS metadata from the provided RDS instance. func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(rdsInstance.DBInstanceArn)) + parsedARN, err := arn.Parse(aws.ToString(rdsInstance.DBInstanceArn)) if err != nil { return nil, trace.Wrap(err) } @@ -1054,17 +1054,17 @@ func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) { Region: parsedARN.Region, AccountID: parsedARN.AccountID, RDS: types.RDS{ - InstanceID: aws.StringValue(rdsInstance.DBInstanceIdentifier), - ClusterID: aws.StringValue(rdsInstance.DBClusterIdentifier), - ResourceID: aws.StringValue(rdsInstance.DbiResourceId), - IAMAuth: aws.BoolValue(rdsInstance.IAMDatabaseAuthenticationEnabled), + InstanceID: aws.ToString(rdsInstance.DBInstanceIdentifier), + ClusterID: aws.ToString(rdsInstance.DBClusterIdentifier), + ResourceID: aws.ToString(rdsInstance.DbiResourceId), + IAMAuth: aws.ToBool(rdsInstance.IAMDatabaseAuthenticationEnabled), }, }, nil } // MetadataFromRDSCluster creates AWS metadata from the provided RDS cluster. func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(rdsCluster.DBClusterArn)) + parsedARN, err := arn.Parse(aws.ToString(rdsCluster.DBClusterArn)) if err != nil { return nil, trace.Wrap(err) } @@ -1072,9 +1072,9 @@ func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) { Region: parsedARN.Region, AccountID: parsedARN.AccountID, RDS: types.RDS{ - ClusterID: aws.StringValue(rdsCluster.DBClusterIdentifier), - ResourceID: aws.StringValue(rdsCluster.DbClusterResourceId), - IAMAuth: aws.BoolValue(rdsCluster.IAMDatabaseAuthenticationEnabled), + ClusterID: aws.ToString(rdsCluster.DBClusterIdentifier), + ResourceID: aws.ToString(rdsCluster.DbClusterResourceId), + IAMAuth: aws.ToBool(rdsCluster.IAMDatabaseAuthenticationEnabled), }, }, nil } @@ -1082,7 +1082,7 @@ func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) { // MetadataFromDocumentDBCluster creates AWS metadata from the provided // DocumentDB cluster. func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(cluster.DBClusterArn)) + parsedARN, err := arn.Parse(aws.ToString(cluster.DBClusterArn)) if err != nil { return nil, trace.Wrap(err) } @@ -1090,7 +1090,7 @@ func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) Region: parsedARN.Region, AccountID: parsedARN.AccountID, DocumentDB: types.DocumentDB{ - ClusterID: aws.StringValue(cluster.DBClusterIdentifier), + ClusterID: aws.ToString(cluster.DBClusterIdentifier), EndpointType: endpointType, }, }, nil @@ -1098,7 +1098,7 @@ func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) // MetadataFromRDSProxy creates AWS metadata from the provided RDS Proxy. func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(rdsProxy.DBProxyArn)) + parsedARN, err := arn.Parse(aws.ToString(rdsProxy.DBProxyArn)) if err != nil { return nil, trace.Wrap(err) } @@ -1112,14 +1112,14 @@ func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) { // resource type is "db-proxy" and the resource ID is "prx-xxxyyyzzz". _, resourceID, ok := strings.Cut(parsedARN.Resource, ":") if !ok { - return nil, trace.BadParameter("failed to find resource ID from %v", aws.StringValue(rdsProxy.DBProxyArn)) + return nil, trace.BadParameter("failed to find resource ID from %v", aws.ToString(rdsProxy.DBProxyArn)) } return &types.AWS{ Region: parsedARN.Region, AccountID: parsedARN.AccountID, RDSProxy: types.RDSProxy{ - Name: aws.StringValue(rdsProxy.DBProxyName), + Name: aws.ToString(rdsProxy.DBProxyName), ResourceID: resourceID, }, }, nil @@ -1135,13 +1135,13 @@ func MetadataFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *r return nil, trace.Wrap(err) } - metadata.RDSProxy.CustomEndpointName = aws.StringValue(customEndpoint.DBProxyEndpointName) + metadata.RDSProxy.CustomEndpointName = aws.ToString(customEndpoint.DBProxyEndpointName) return metadata, nil } // MetadataFromRedshiftCluster creates AWS metadata from the provided Redshift cluster. -func MetadataFromRedshiftCluster(cluster *redshift.Cluster) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(cluster.ClusterNamespaceArn)) +func MetadataFromRedshiftCluster(cluster *redshifttypes.Cluster) (*types.AWS, error) { + parsedARN, err := arn.Parse(aws.ToString(cluster.ClusterNamespaceArn)) if err != nil { return nil, trace.Wrap(err) } @@ -1149,7 +1149,7 @@ func MetadataFromRedshiftCluster(cluster *redshift.Cluster) (*types.AWS, error) Region: parsedARN.Region, AccountID: parsedARN.AccountID, Redshift: types.Redshift{ - ClusterID: aws.StringValue(cluster.ClusterIdentifier), + ClusterID: aws.ToString(cluster.ClusterIdentifier), }, }, nil } @@ -1157,7 +1157,7 @@ func MetadataFromRedshiftCluster(cluster *redshift.Cluster) (*types.AWS, error) // MetadataFromElastiCacheCluster creates AWS metadata for the provided // ElastiCache cluster. func MetadataFromElastiCacheCluster(cluster *elasticache.ReplicationGroup, endpointType string) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(cluster.ARN)) + parsedARN, err := arn.Parse(aws.ToString(cluster.ARN)) if err != nil { return nil, trace.Wrap(err) } @@ -1168,16 +1168,16 @@ func MetadataFromElastiCacheCluster(cluster *elasticache.ReplicationGroup, endpo // messages don't fail. var userGroupIDs []string if len(cluster.UserGroupIds) != 0 { - userGroupIDs = aws.StringValueSlice(cluster.UserGroupIds) + userGroupIDs = aws.ToStringSlice(cluster.UserGroupIds) } return &types.AWS{ Region: parsedARN.Region, AccountID: parsedARN.AccountID, ElastiCache: types.ElastiCache{ - ReplicationGroupID: aws.StringValue(cluster.ReplicationGroupId), + ReplicationGroupID: aws.ToString(cluster.ReplicationGroupId), UserGroupIDs: userGroupIDs, - TransitEncryptionEnabled: aws.BoolValue(cluster.TransitEncryptionEnabled), + TransitEncryptionEnabled: aws.ToBool(cluster.TransitEncryptionEnabled), EndpointType: endpointType, }, }, nil @@ -1185,7 +1185,7 @@ func MetadataFromElastiCacheCluster(cluster *elasticache.ReplicationGroup, endpo // MetadataFromOpenSearchDomain creates AWS metadata for the provided OpenSearch domain. func MetadataFromOpenSearchDomain(domain *opensearchservice.DomainStatus, endpointType string) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(domain.ARN)) + parsedARN, err := arn.Parse(aws.ToString(domain.ARN)) if err != nil { return nil, trace.Wrap(err) } @@ -1194,8 +1194,8 @@ func MetadataFromOpenSearchDomain(domain *opensearchservice.DomainStatus, endpoi Region: parsedARN.Region, AccountID: parsedARN.AccountID, OpenSearch: types.OpenSearch{ - DomainName: aws.StringValue(domain.DomainName), - DomainID: aws.StringValue(domain.DomainId), + DomainName: aws.ToString(domain.DomainName), + DomainID: aws.ToString(domain.DomainId), EndpointType: endpointType, }, }, nil @@ -1204,7 +1204,7 @@ func MetadataFromOpenSearchDomain(domain *opensearchservice.DomainStatus, endpoi // MetadataFromMemoryDBCluster creates AWS metadata for the provided MemoryDB // cluster. func MetadataFromMemoryDBCluster(cluster *memorydb.Cluster, endpointType string) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(cluster.ARN)) + parsedARN, err := arn.Parse(aws.ToString(cluster.ARN)) if err != nil { return nil, trace.Wrap(err) } @@ -1213,9 +1213,9 @@ func MetadataFromMemoryDBCluster(cluster *memorydb.Cluster, endpointType string) Region: parsedARN.Region, AccountID: parsedARN.AccountID, MemoryDB: types.MemoryDB{ - ClusterName: aws.StringValue(cluster.Name), - ACLName: aws.StringValue(cluster.ACLName), - TLSEnabled: aws.BoolValue(cluster.TLSEnabled), + ClusterName: aws.ToString(cluster.Name), + ACLName: aws.ToString(cluster.ACLName), + TLSEnabled: aws.ToBool(cluster.TLSEnabled), EndpointType: endpointType, }, }, nil @@ -1224,7 +1224,7 @@ func MetadataFromMemoryDBCluster(cluster *memorydb.Cluster, endpointType string) // MetadataFromRedshiftServerlessWorkgroup creates AWS metadata for the // provided Redshift Serverless Workgroup. func MetadataFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workgroup) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(workgroup.WorkgroupArn)) + parsedARN, err := arn.Parse(aws.ToString(workgroup.WorkgroupArn)) if err != nil { return nil, trace.Wrap(err) } @@ -1233,8 +1233,8 @@ func MetadataFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workg Region: parsedARN.Region, AccountID: parsedARN.AccountID, RedshiftServerless: types.RedshiftServerless{ - WorkgroupName: aws.StringValue(workgroup.WorkgroupName), - WorkgroupID: aws.StringValue(workgroup.WorkgroupId), + WorkgroupName: aws.ToString(workgroup.WorkgroupName), + WorkgroupID: aws.ToString(workgroup.WorkgroupId), }, }, nil } @@ -1242,7 +1242,7 @@ func MetadataFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workg // MetadataFromRedshiftServerlessVPCEndpoint creates AWS metadata for the // provided Redshift Serverless VPC endpoint. func MetadataFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.EndpointAccess, workgroup *redshiftserverless.Workgroup) (*types.AWS, error) { - parsedARN, err := arn.Parse(aws.StringValue(endpoint.EndpointArn)) + parsedARN, err := arn.Parse(aws.ToString(endpoint.EndpointArn)) if err != nil { return nil, trace.Wrap(err) } @@ -1251,9 +1251,9 @@ func MetadataFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.Endp Region: parsedARN.Region, AccountID: parsedARN.AccountID, RedshiftServerless: types.RedshiftServerless{ - WorkgroupName: aws.StringValue(endpoint.WorkgroupName), - EndpointName: aws.StringValue(endpoint.EndpointName), - WorkgroupID: aws.StringValue(workgroup.WorkgroupId), + WorkgroupName: aws.ToString(endpoint.WorkgroupName), + EndpointName: aws.ToString(endpoint.EndpointName), + WorkgroupID: aws.ToString(workgroup.WorkgroupId), }, }, nil } @@ -1261,15 +1261,15 @@ func MetadataFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.Endp // ExtraElastiCacheLabels returns a list of extra labels for provided // ElastiCache cluster. func ExtraElastiCacheLabels(cluster *elasticache.ReplicationGroup, tags []*elasticache.Tag, allNodes []*elasticache.CacheCluster, allSubnetGroups []*elasticache.CacheSubnetGroup) map[string]string { - replicationGroupID := aws.StringValue(cluster.ReplicationGroupId) + replicationGroupID := aws.ToString(cluster.ReplicationGroupId) subnetGroupName := "" labels := make(map[string]string) // Find any node belongs to this cluster and set engine version label. for _, node := range allNodes { - if aws.StringValue(node.ReplicationGroupId) == replicationGroupID { - subnetGroupName = aws.StringValue(node.CacheSubnetGroupName) - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(node.EngineVersion) + if aws.ToString(node.ReplicationGroupId) == replicationGroupID { + subnetGroupName = aws.ToString(node.CacheSubnetGroupName) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(node.EngineVersion) break } } @@ -1280,8 +1280,8 @@ func ExtraElastiCacheLabels(cluster *elasticache.ReplicationGroup, tags []*elast // accessible within the same VPC. Having a VPC ID label can be very useful // for filtering. for _, subnetGroup := range allSubnetGroups { - if aws.StringValue(subnetGroup.CacheSubnetGroupName) == subnetGroupName { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(subnetGroup.VpcId) + if aws.ToString(subnetGroup.CacheSubnetGroupName) == subnetGroupName { + labels[types.DiscoveryLabelVPCID] = aws.ToString(subnetGroup.VpcId) break } } @@ -1296,12 +1296,12 @@ func ExtraMemoryDBLabels(cluster *memorydb.Cluster, tags []*memorydb.Tag, allSub labels := make(map[string]string) // Engine version. - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(cluster.EngineVersion) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(cluster.EngineVersion) // VPC ID. for _, subnetGroup := range allSubnetGroups { - if aws.StringValue(subnetGroup.Name) == aws.StringValue(cluster.SubnetGroupName) { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(subnetGroup.VpcId) + if aws.ToString(subnetGroup.Name) == aws.ToString(cluster.SubnetGroupName) { + labels[types.DiscoveryLabelVPCID] = aws.ToString(subnetGroup.VpcId) break } } @@ -1423,11 +1423,11 @@ func labelsFromAzurePostgresFlexServer(server *armpostgresqlflexibleservers.Serv // labelsFromRDSInstance creates database labels for the provided RDS instance. func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[string]string { labels := labelsFromAWSMetadata(meta) - labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine) - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion) + labels[types.DiscoveryLabelEngine] = aws.ToString(rdsInstance.Engine) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsInstance.EngineVersion) labels[types.DiscoveryLabelEndpointType] = apiawsutils.RDSEndpointTypeInstance if rdsInstance.DBSubnetGroup != nil { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId) + labels[types.DiscoveryLabelVPCID] = aws.ToString(rdsInstance.DBSubnetGroup.VpcId) } return addLabels(labels, libcloudaws.TagsToLabels(rdsInstance.TagList)) } @@ -1435,19 +1435,19 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str // labelsFromRDSCluster creates database labels for the provided RDS cluster. func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType string, memberInstances []*rds.DBInstance) map[string]string { labels := labelsFromAWSMetadata(meta) - labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine) - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion) + labels[types.DiscoveryLabelEngine] = aws.ToString(rdsCluster.Engine) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsCluster.EngineVersion) labels[types.DiscoveryLabelEndpointType] = endpointType if len(memberInstances) > 0 && memberInstances[0].DBSubnetGroup != nil { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstances[0].DBSubnetGroup.VpcId) + labels[types.DiscoveryLabelVPCID] = aws.ToString(memberInstances[0].DBSubnetGroup.VpcId) } return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList)) } func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpointType string) map[string]string { labels := labelsFromAWSMetadata(meta) - labels[types.DiscoveryLabelEngine] = aws.StringValue(cluster.Engine) - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(cluster.EngineVersion) + labels[types.DiscoveryLabelEngine] = aws.ToString(cluster.Engine) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(cluster.EngineVersion) labels[types.DiscoveryLabelEndpointType] = endpointType return addLabels(labels, libcloudaws.TagsToLabels(cluster.TagList)) } @@ -1456,8 +1456,8 @@ func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpoi func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag) map[string]string { // rds.DBProxy has no TagList. labels := labelsFromAWSMetadata(meta) - labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsProxy.VpcId) - labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsProxy.EngineFamily) + labels[types.DiscoveryLabelVPCID] = aws.ToString(rdsProxy.VpcId) + labels[types.DiscoveryLabelEngine] = aws.ToString(rdsProxy.EngineFamily) return addLabels(labels, libcloudaws.TagsToLabels(tags)) } @@ -1465,12 +1465,12 @@ func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag) // RDS Proxy custom endpoint. func labelsFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint, meta *types.AWS, tags []*rds.Tag) map[string]string { labels := labelsFromRDSProxy(rdsProxy, meta, tags) - labels[types.DiscoveryLabelEndpointType] = aws.StringValue(customEndpoint.TargetRole) + labels[types.DiscoveryLabelEndpointType] = aws.ToString(customEndpoint.TargetRole) return labels } // labelsFromRedshiftCluster creates database labels for the provided Redshift cluster. -func labelsFromRedshiftCluster(cluster *redshift.Cluster, meta *types.AWS) map[string]string { +func labelsFromRedshiftCluster(cluster *redshifttypes.Cluster, meta *types.AWS) map[string]string { labels := labelsFromAWSMetadata(meta) return addLabels(labels, libcloudaws.TagsToLabels(cluster.Tags)) } @@ -1478,9 +1478,9 @@ func labelsFromRedshiftCluster(cluster *redshift.Cluster, meta *types.AWS) map[s func labelsFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workgroup, meta *types.AWS, tags []*redshiftserverless.Tag) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEndpointType] = services.RedshiftServerlessWorkgroupEndpoint - labels[types.DiscoveryLabelNamespace] = aws.StringValue(workgroup.NamespaceName) + labels[types.DiscoveryLabelNamespace] = aws.ToString(workgroup.NamespaceName) if workgroup.Endpoint != nil && len(workgroup.Endpoint.VpcEndpoints) > 0 { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(workgroup.Endpoint.VpcEndpoints[0].VpcId) + labels[types.DiscoveryLabelVPCID] = aws.ToString(workgroup.Endpoint.VpcEndpoints[0].VpcId) } return addLabels(labels, libcloudaws.TagsToLabels(tags)) } @@ -1488,10 +1488,10 @@ func labelsFromRedshiftServerlessWorkgroup(workgroup *redshiftserverless.Workgro func labelsFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.EndpointAccess, workgroup *redshiftserverless.Workgroup, meta *types.AWS, tags []*redshiftserverless.Tag) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEndpointType] = services.RedshiftServerlessVPCEndpoint - labels[types.DiscoveryLabelWorkgroup] = aws.StringValue(endpoint.WorkgroupName) - labels[types.DiscoveryLabelNamespace] = aws.StringValue(workgroup.NamespaceName) + labels[types.DiscoveryLabelWorkgroup] = aws.ToString(endpoint.WorkgroupName) + labels[types.DiscoveryLabelNamespace] = aws.ToString(workgroup.NamespaceName) if endpoint.VpcEndpoint != nil { - labels[types.DiscoveryLabelVPCID] = aws.StringValue(endpoint.VpcEndpoint.VpcId) + labels[types.DiscoveryLabelVPCID] = aws.ToString(endpoint.VpcEndpoint.VpcId) } return addLabels(labels, libcloudaws.TagsToLabels(tags)) } @@ -1509,7 +1509,7 @@ func labelsFromAWSMetadata(meta *types.AWS) map[string]string { func labelsFromOpenSearchDomain(domain *opensearchservice.DomainStatus, meta *types.AWS, endpointType string, tags []*opensearchservice.Tag) map[string]string { labels := labelsFromMetaAndEndpointType(meta, endpointType, libcloudaws.TagsToLabels(tags)) - labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(domain.EngineVersion) + labels[types.DiscoveryLabelEngineVersion] = aws.ToString(domain.EngineVersion) return labels } diff --git a/lib/srv/discovery/common/database_test.go b/lib/srv/discovery/common/database_test.go index 24cb2dacfd483..ab2b45fff24bc 100644 --- a/lib/srv/discovery/common/database_test.go +++ b/lib/srv/discovery/common/database_test.go @@ -27,12 +27,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql" + "github.com/aws/aws-sdk-go-v2/aws" rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" - "github.com/aws/aws-sdk-go/aws" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/google/go-cmp/cmp" "github.com/google/uuid" @@ -1182,14 +1182,14 @@ func TestAzureTagsToLabels(t *testing.T) { // TestDatabaseFromRedshiftCluster tests converting an Redshift cluster to a database resource. func TestDatabaseFromRedshiftCluster(t *testing.T) { t.Run("success", func(t *testing.T) { - cluster := &redshift.Cluster{ + cluster := &redshifttypes.Cluster{ ClusterIdentifier: aws.String("mycluster"), ClusterNamespaceArn: aws.String("arn:aws:redshift:us-east-1:123456789012:namespace:u-u-i-d"), - Endpoint: &redshift.Endpoint{ + Endpoint: &redshifttypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5439), + Port: aws.Int32(5439), }, - Tags: []*redshift.Tag{ + Tags: []redshifttypes.Tag{ { Key: aws.String("key"), Value: aws.String("val"), @@ -1231,14 +1231,14 @@ func TestDatabaseFromRedshiftCluster(t *testing.T) { for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels { t.Run("success with name override via"+overrideLabel, func(t *testing.T) { - cluster := &redshift.Cluster{ + cluster := &redshifttypes.Cluster{ ClusterIdentifier: aws.String("mycluster"), ClusterNamespaceArn: aws.String("arn:aws:redshift:us-east-1:123456789012:namespace:u-u-i-d"), - Endpoint: &redshift.Endpoint{ + Endpoint: &redshifttypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5439), + Port: aws.Int32(5439), }, - Tags: []*redshift.Tag{ + Tags: []redshifttypes.Tag{ { Key: aws.String("key"), Value: aws.String("val"), @@ -1284,7 +1284,7 @@ func TestDatabaseFromRedshiftCluster(t *testing.T) { }) t.Run("missing endpoint", func(t *testing.T) { - _, err := NewDatabaseFromRedshiftCluster(&redshift.Cluster{ + _, err := NewDatabaseFromRedshiftCluster(&redshifttypes.Cluster{ ClusterIdentifier: aws.String("still-creating"), }) require.Error(t, err) diff --git a/lib/srv/discovery/config_test.go b/lib/srv/discovery/config_test.go index 6954a2036a493..615c1852643a8 100644 --- a/lib/srv/discovery/config_test.go +++ b/lib/srv/discovery/config_test.go @@ -50,6 +50,8 @@ func TestConfigCheckAndSetDefaults(t *testing.T) { cfgChange: func(c *Config) {}, postCheckAndSetDefaultsFunc: func(t *testing.T, c *Config) { require.NotNil(t, c.CloudClients) + require.NotNil(t, c.AWSConfigProvider) + require.NotNil(t, c.AWSDatabaseFetcherFactory) require.NotNil(t, c.Log) require.NotNil(t, c.clock) require.NotNil(t, c.TriggerFetchC) diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index ffb4a76353f59..0d3d7304927d3 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -115,6 +115,10 @@ type gcpInstaller interface { type Config struct { // CloudClients is an interface for retrieving cloud clients. CloudClients cloud.Clients + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider + // AWSDatabaseFetcherFactory provides AWS database fetchers + AWSDatabaseFetcherFactory *db.AWSFetcherFactory // GetEC2Client gets an AWS EC2 client for the given region. GetEC2Client server.EC2ClientGetter // GetSSMClient gets an AWS SSM client for the given region. @@ -219,6 +223,24 @@ kubernetes matchers are present.`) } c.CloudClients = cloudClients } + if c.AWSConfigProvider == nil { + provider, err := awsconfig.NewCache() + if err != nil { + return trace.Wrap(err, "unable to create AWS config provider cache") + } + c.AWSConfigProvider = provider + } + if c.AWSDatabaseFetcherFactory == nil { + factory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + CloudClients: c.CloudClients, + AWSConfigProvider: c.AWSConfigProvider, + IntegrationCredentialProviderFn: c.getIntegrationCredentialProviderFn(), + }) + if err != nil { + return trace.Wrap(err) + } + c.AWSDatabaseFetcherFactory = factory + } if c.GetEC2Client == nil { c.GetEC2Client = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) { cfg, err := c.getAWSConfig(ctx, region, opts...) @@ -290,7 +312,13 @@ kubernetes matchers are present.`) } func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (aws.Config, error) { - opts = append(opts, awsconfig.WithIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) { + opts = append(opts, awsconfig.WithIntegrationCredentialProvider(c.getIntegrationCredentialProviderFn())) + cfg, err := c.AWSConfigProvider.GetConfig(ctx, region, opts...) + return cfg, trace.Wrap(err) +} + +func (c *Config) getIntegrationCredentialProviderFn() awsconfig.IntegrationCredentialProviderFunc { + return func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) { integration, err := c.AccessPoint.GetIntegration(ctx, integrationName) if err != nil { return nil, trace.Wrap(err) @@ -308,9 +336,7 @@ func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...awscon Region: region, }) return cred, trace.Wrap(err) - })) - cfg, err := awsconfig.GetConfig(ctx, region, opts...) - return cfg, trace.Wrap(err) + } } // Server is a discovery server, used to discover cloud resources for @@ -661,7 +687,7 @@ func (s *Server) databaseFetchersFromMatchers(matchers Matchers, discoveryConfig // AWS awsDatabaseMatchers, _ := splitMatchers(matchers.AWS, db.IsAWSMatcherType) if len(awsDatabaseMatchers) > 0 { - databaseFetchers, err := db.MakeAWSFetchers(s.ctx, s.CloudClients, awsDatabaseMatchers, discoveryConfigName) + databaseFetchers, err := s.AWSDatabaseFetcherFactory.MakeFetchers(s.ctx, awsDatabaseMatchers, discoveryConfigName) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 4bf7685e3cca5..0499412e149a6 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -37,8 +37,11 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/redshift" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/aws/aws-sdk-go/aws" @@ -46,7 +49,6 @@ import ( "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/redshift" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" @@ -85,6 +87,7 @@ import ( "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" + "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" "github.com/gravitational/teleport/lib/srv/server" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" libutils "github.com/gravitational/teleport/lib/utils" @@ -647,7 +650,7 @@ func TestDiscoveryServer(t *testing.T) { foundEC2Instances: []ec2types.Instance{}, ssm: &mockSSMClient{}, cloudClients: &cloud.TestCloudClients{ - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, EKS: &mocks.EKSMock{ Clusters: []*eks.Cluster{ { @@ -1396,7 +1399,7 @@ func TestDiscoveryInCloudKube(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - sts := &mocks.STSMock{} + sts := &mocks.STSClientV1{} testCloudClients := &cloud.TestCloudClients{ STS: sts, @@ -1542,7 +1545,7 @@ func TestDiscoveryServer_New(t *testing.T) { }{ { desc: "no matchers error", - cloudClients: &cloud.TestCloudClients{STS: &mocks.STSMock{}}, + cloudClients: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, matchers: Matchers{}, errAssertion: func(t require.TestingT, err error, i ...interface{}) { require.ErrorIs(t, err, &trace.BadParameterError{Message: "no matchers or discovery group configured for discovery"}) @@ -1551,7 +1554,7 @@ func TestDiscoveryServer_New(t *testing.T) { }, { desc: "success with EKS matcher", - cloudClients: &cloud.TestCloudClients{STS: &mocks.STSMock{}, EKS: &mocks.EKSMock{}}, + cloudClients: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}, EKS: &mocks.EKSMock{}}, matchers: Matchers{ AWS: []types.AWSMatcher{ { @@ -1576,7 +1579,7 @@ func TestDiscoveryServer_New(t *testing.T) { { desc: "EKS fetcher is skipped on initialization error (missing region)", cloudClients: &cloud.TestCloudClients{ - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, EKS: &mocks.EKSMock{}, }, matchers: Matchers{ @@ -1965,7 +1968,7 @@ func TestDiscoveryDatabase(t *testing.T) { } testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, RDS: &mocks.RDSMock{ DBInstances: []*rds.DBInstance{awsRDSInstance}, DBEngineVersions: []*rds.DBEngineVersion{ @@ -1973,9 +1976,6 @@ func TestDiscoveryDatabase(t *testing.T) { }, }, MemoryDB: &mocks.MemoryDBMock{}, - Redshift: &mocks.RedshiftMock{ - Clusters: []*redshift.Cluster{awsRedshiftResource}, - }, AzureRedis: azure.NewRedisClientByAPI(&azure.ARMRedisMock{ Servers: []*armredis.ResourceInfo{azRedisResource}, }), @@ -1987,6 +1987,18 @@ func TestDiscoveryDatabase(t *testing.T) { Clusters: []*eks.Cluster{eksAWSResource}, }, } + fakeConfigProvider := &mocks.AWSConfigProvider{} + dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + AWSConfigProvider: fakeConfigProvider, + CloudClients: testCloudClients, + IntegrationCredentialProviderFn: func(_ context.Context, _, _ string) (awsv2.CredentialsProvider, error) { + return credentials.NewStaticCredentialsProvider("key", "secret", "session"), nil + }, + RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, + }), + }) + require.NoError(t, err) tcs := []struct { name string @@ -2297,6 +2309,8 @@ func TestDiscoveryDatabase(t *testing.T) { &Config{ IntegrationOnlyCredentials: integrationOnlyCredential, CloudClients: testCloudClients, + AWSDatabaseFetcherFactory: dbFetcherFactory, + AWSConfigProvider: fakeConfigProvider, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), @@ -2377,7 +2391,7 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) { awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1", rewriteDiscoveryLabelsParams{discoveryConfigName: dc2Name, discoveryGroup: mainDiscoveryGroup}) testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, RDS: &mocks.RDSMock{ DBInstances: []*rds.DBInstance{awsRDSInstance}, DBEngineVersions: []*rds.DBEngineVersion{ @@ -2561,15 +2575,15 @@ func makeRDSInstance(t *testing.T, name, region string, discoveryParams rewriteD return instance, database } -func makeRedshiftCluster(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*redshift.Cluster, types.Database) { +func makeRedshiftCluster(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*redshifttypes.Cluster, types.Database) { t.Helper() - cluster := &redshift.Cluster{ + cluster := &redshifttypes.Cluster{ ClusterIdentifier: aws.String(name), ClusterNamespaceArn: aws.String(fmt.Sprintf("arn:aws:redshift:%s:123456789012:namespace:%s", region, name)), ClusterStatus: aws.String("available"), - Endpoint: &redshift.Endpoint{ + Endpoint: &redshifttypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5439), + Port: aws.Int32(5439), }, } @@ -3619,3 +3633,9 @@ func newPopulatedGCPProjectsMock() *mockProjectsAPI { }, } } + +func newFakeRedshiftClientProvider(c redshift.DescribeClustersAPIClient) db.RedshiftClientProviderFunc { + return func(cfg awsv2.Config, optFns ...func(*redshift.Options)) db.RedshiftClient { + return c + } +} diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go index 9ccf26f82b397..f87e0e9a6c443 100644 --- a/lib/srv/discovery/fetchers/db/aws.go +++ b/lib/srv/discovery/fetchers/db/aws.go @@ -23,14 +23,21 @@ import ( "fmt" "log/slog" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) +// maxAWSPages is the maximum number of pages to iterate over when fetching aws +// databases. +const maxAWSPages = 10 + // awsFetcherPlugin defines an interface that provides database type specific // functions for use by the common AWS database fetcher. type awsFetcherPlugin interface { @@ -46,6 +53,11 @@ type awsFetcherPlugin interface { type awsFetcherConfig struct { // AWSClients are the AWS API clients. AWSClients cloud.AWSClients + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider + // IntegrationCredentialProviderFn is a required function that provides + // credentials via AWS OIDC integration. + IntegrationCredentialProviderFn awsconfig.IntegrationCredentialProviderFunc // Type is the type of DB matcher, for example "rds", "redshift", etc. Type string // AssumeRole provides a role ARN and ExternalID to assume an AWS role @@ -64,6 +76,9 @@ type awsFetcherConfig struct { // Might be empty when the fetcher is using static matchers: // ie teleport.yaml/discovery_service.. DiscoveryConfigName string + + // redshiftClientProviderFn provides an AWS Redshift client. + redshiftClientProviderFn RedshiftClientProviderFunc } // CheckAndSetDefaults validates the config and sets defaults. @@ -71,6 +86,9 @@ func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error { if cfg.AWSClients == nil { return trace.BadParameter("missing parameter AWSClients") } + if cfg.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } if cfg.Type == "" { return trace.BadParameter("missing parameter Type") } @@ -93,6 +111,12 @@ func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error { "credentials", credentialsSource, ) } + + if cfg.redshiftClientProviderFn == nil { + cfg.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { + return redshift.NewFromConfig(cfg, optFns...) + } + } return nil } @@ -179,7 +203,3 @@ func (f *awsFetcher) String() string { return fmt.Sprintf("awsFetcher(Type: %v, Region=%v, Labels=%v)", f.cfg.Type, f.cfg.Region, f.cfg.Labels) } - -// maxAWSPages is the maximum number of pages to iterate over when fetching aws -// databases. -const maxAWSPages = 10 diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index 7b4b1bfb35315..ccf1e8a7146dc 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -21,17 +21,25 @@ package db import ( "context" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/redshift" - "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) +// RedshiftClientProviderFunc provides a [RedshiftClient]. +type RedshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient + +// RedshiftClient is a subset of the AWS Redshift API. +type RedshiftClient interface { + redshift.DescribeClustersAPIClient +} + // newRedshiftFetcher returns a new AWS fetcher for Redshift databases. func newRedshiftFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { return newAWSFetcher(cfg, &redshiftPlugin{}) @@ -42,32 +50,33 @@ type redshiftPlugin struct{} // GetDatabases returns Redshift databases matching the watcher's selectors. func (f *redshiftPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - redshiftClient, err := cfg.AWSClients.GetAWSRedshiftClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), + awsconfig.WithIntegrationCredentialProvider(cfg.IntegrationCredentialProviderFn), ) if err != nil { return nil, trace.Wrap(err) } - clusters, err := getRedshiftClusters(ctx, redshiftClient) + clusters, err := getRedshiftClusters(ctx, cfg.redshiftClientProviderFn(awsCfg)) if err != nil { return nil, trace.Wrap(err) } var databases types.Databases for _, cluster := range clusters { - if !libcloudaws.IsRedshiftClusterAvailable(cluster) { + if !libcloudaws.IsRedshiftClusterAvailable(&cluster) { cfg.Logger.DebugContext(ctx, "Skipping unavailable Redshift cluster", - "cluster", aws.StringValue(cluster.ClusterIdentifier), - "status", aws.StringValue(cluster.ClusterStatus), + "cluster", aws.ToString(cluster.ClusterIdentifier), + "status", aws.ToString(cluster.ClusterStatus), ) continue } - database, err := common.NewDatabaseFromRedshiftCluster(cluster) + database, err := common.NewDatabaseFromRedshiftCluster(&cluster) if err != nil { cfg.Logger.InfoContext(ctx, "Could not convert Redshift cluster to database resource", - "cluster", aws.StringValue(cluster.ClusterIdentifier), + "cluster", aws.ToString(cluster.ClusterIdentifier), "error", err, ) continue @@ -84,17 +93,21 @@ func (f *redshiftPlugin) ComponentShortName() string { // getRedshiftClusters fetches all Reshift clusters using the provided client, // up to the specified max number of pages -func getRedshiftClusters(ctx context.Context, redshiftClient redshiftiface.RedshiftAPI) ([]*redshift.Cluster, error) { - var clusters []*redshift.Cluster +func getRedshiftClusters(ctx context.Context, clt redshift.DescribeClustersAPIClient) ([]redshifttypes.Cluster, error) { + var clusters []redshifttypes.Cluster var pageNum int - err := redshiftClient.DescribeClustersPagesWithContext( - ctx, + pager := redshift.NewDescribeClustersPaginator(clt, &redshift.DescribeClustersInput{}, - func(page *redshift.DescribeClustersOutput, lastPage bool) bool { - pageNum++ - clusters = append(clusters, page.Clusters...) - return pageNum <= maxAWSPages + func(dcpo *redshift.DescribeClustersPaginatorOptions) { + dcpo.StopOnDuplicateToken = false }, ) - return clusters, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + for pageNum <= maxAWSPages && pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, libcloudaws.ConvertRequestFailureErrorV2(err) + } + clusters = append(clusters, page.Clusters...) + } + return clusters, nil } diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_test.go index 76fb7898db578..ded47035e96e3 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_test.go @@ -21,16 +21,21 @@ package db import ( "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/redshift" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" + redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/srv/discovery/common" ) +func newFakeRedshiftClientProvider(c RedshiftClient) RedshiftClientProviderFunc { + return func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { + return c + } +} func TestRedshiftFetcher(t *testing.T) { t.Parallel() @@ -42,30 +47,30 @@ func TestRedshiftFetcher(t *testing.T) { tests := []awsFetcherTest{ { name: "fetch all", - inputClients: &cloud.TestCloudClients{ - Redshift: &mocks.RedshiftMock{ - Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Dev}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev}, + }), }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", wildcardLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUse1Dev}, }, { name: "fetch prod", - inputClients: &cloud.TestCloudClients{ - Redshift: &mocks.RedshiftMock{ - Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Dev}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev}, + }), }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", envProdLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod}, }, { name: "skip unavailable", - inputClients: &cloud.TestCloudClients{ - Redshift: &mocks.RedshiftMock{ - Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Unavailable, redshiftUse1UnknownStatus}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Unavailable, *redshiftUse1UnknownStatus}, + }), }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", wildcardLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUnknownStatus}, @@ -74,18 +79,18 @@ func TestRedshiftFetcher(t *testing.T) { testAWSFetchers(t, tests...) } -func makeRedshiftCluster(t *testing.T, region, env string, opts ...func(*redshift.Cluster)) (*redshift.Cluster, types.Database) { +func makeRedshiftCluster(t *testing.T, region, env string, opts ...func(*redshifttypes.Cluster)) (*redshifttypes.Cluster, types.Database) { cluster := mocks.RedshiftCluster(env, region, map[string]string{"env": env}, opts...) - database, err := common.NewDatabaseFromRedshiftCluster(cluster) + database, err := common.NewDatabaseFromRedshiftCluster(&cluster) require.NoError(t, err) common.ApplyAWSDatabaseNameSuffix(database, types.AWSMatcherRedshift) - return cluster, database + return &cluster, database } // withRedshiftStatus returns an option function for makeRedshiftCluster to overwrite status. -func withRedshiftStatus(status string) func(*redshift.Cluster) { - return func(cluster *redshift.Cluster) { +func withRedshiftStatus(status string) func(*redshifttypes.Cluster) { + return func(cluster *redshifttypes.Cluster) { cluster.ClusterStatus = aws.String(status) } } diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index c3c3cbeec6dbe..3ef56532d90af 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -22,11 +22,14 @@ import ( "context" "log/slog" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "golang.org/x/exp/maps" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -64,8 +67,53 @@ func IsAzureMatcherType(matcherType string) bool { return len(makeAzureFetcherFuncs[matcherType]) > 0 } -// MakeAWSFetchers creates new AWS database fetchers. -func MakeAWSFetchers(ctx context.Context, clients cloud.AWSClients, matchers []types.AWSMatcher, discoveryConfigName string) (result []common.Fetcher, err error) { +// AWSFetcherFactoryConfig is the configuration for an [AWSFetcherFactory]. +type AWSFetcherFactoryConfig struct { + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider + // CloudClients is an interface for retrieving AWS SDK v1 cloud clients. + CloudClients cloud.AWSClients + // IntegrationCredentialProviderFn is an optional function that provides + // credentials via AWS OIDC integration. + IntegrationCredentialProviderFn awsconfig.IntegrationCredentialProviderFunc + // RedshiftClientProviderFn is an optional function that provides + RedshiftClientProviderFn RedshiftClientProviderFunc +} + +func (c *AWSFetcherFactoryConfig) checkAndSetDefaults() error { + if c.CloudClients == nil { + return trace.BadParameter("missing CloudClients") + } + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } + if c.RedshiftClientProviderFn == nil { + c.RedshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { + return redshift.NewFromConfig(cfg, optFns...) + } + } + return nil +} + +// AWSFetcherFactory makes AWS database fetchers. +type AWSFetcherFactory struct { + cfg AWSFetcherFactoryConfig +} + +// NewAWSFetcherFactory checks the given config and returns a new fetcher +// provider. +func NewAWSFetcherFactory(cfg AWSFetcherFactoryConfig) (*AWSFetcherFactory, error) { + if err := cfg.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &AWSFetcherFactory{ + cfg: cfg, + }, nil +} + +// MakeFetchers returns AWS database fetchers for each matcher given. +func (f *AWSFetcherFactory) MakeFetchers(ctx context.Context, matchers []types.AWSMatcher, discoveryConfigName string) ([]common.Fetcher, error) { + var result []common.Fetcher for _, matcher := range matchers { assumeRole := types.AssumeRole{} if matcher.AssumeRole != nil { @@ -80,13 +128,16 @@ func MakeAWSFetchers(ctx context.Context, clients cloud.AWSClients, matchers []t for _, makeFetcher := range makeFetchers { for _, region := range matcher.Regions { fetcher, err := makeFetcher(awsFetcherConfig{ - AWSClients: clients, - Type: matcherType, - AssumeRole: assumeRole, - Labels: matcher.Tags, - Region: region, - Integration: matcher.Integration, - DiscoveryConfigName: discoveryConfigName, + AWSClients: f.cfg.CloudClients, + Type: matcherType, + AssumeRole: assumeRole, + Labels: matcher.Tags, + Region: region, + Integration: matcher.Integration, + DiscoveryConfigName: discoveryConfigName, + AWSConfigProvider: f.cfg.AWSConfigProvider, + IntegrationCredentialProviderFn: f.cfg.IntegrationCredentialProviderFn, + redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/discovery/fetchers/db/helpers_test.go b/lib/srv/discovery/fetchers/db/helpers_test.go index 6063198b71e6d..5feae42c7b367 100644 --- a/lib/srv/discovery/fetchers/db/helpers_test.go +++ b/lib/srv/discovery/fetchers/db/helpers_test.go @@ -53,10 +53,12 @@ func makeAWSMatchersForType(matcherType, region string, tags map[string]string) }} } -func mustMakeAWSFetchers(t *testing.T, clients cloud.AWSClients, matchers []types.AWSMatcher, discoveryConfigName string) []common.Fetcher { +func mustMakeAWSFetchers(t *testing.T, cfg AWSFetcherFactoryConfig, matchers []types.AWSMatcher, discoveryConfigName string) []common.Fetcher { t.Helper() - fetchers, err := MakeAWSFetchers(context.Background(), clients, matchers, discoveryConfigName) + fetcherFactory, err := NewAWSFetcherFactory(cfg) + require.NoError(t, err) + fetchers, err := fetcherFactory.MakeFetchers(context.Background(), matchers, discoveryConfigName) require.NoError(t, err) require.NotEmpty(t, fetchers) @@ -111,6 +113,7 @@ var testAssumeRole = types.AssumeRole{ type awsFetcherTest struct { name string inputClients *cloud.TestCloudClients + fetcherCfg AWSFetcherFactoryConfig inputMatchers []types.AWSMatcher wantDatabases types.Databases } @@ -121,22 +124,30 @@ func testAWSFetchers(t *testing.T, tests ...awsFetcherTest) { t.Helper() for _, test := range tests { test := test - require.Nil(t, test.inputClients.STS, "testAWSFetchers injects an STS mock itself, but test input had already configured it. This is a test configuration error.") - stsMock := &mocks.STSMock{} - test.inputClients.STS = stsMock + fakeSTS := &mocks.STSClient{} + if test.inputClients != nil { + require.Nil(t, test.inputClients.STS, "testAWSFetchers injects an STS mock itself, but test input had already configured it. This is a test configuration error.") + test.inputClients.STS = &fakeSTS.STSClientV1 + } + test.fetcherCfg.CloudClients = test.inputClients + require.Nil(t, test.fetcherCfg.AWSConfigProvider, "testAWSFetchers injects a fake AWSConfigProvider, but the test input had already configured it. This is a test configuration error.") + test.fetcherCfg.AWSConfigProvider = &mocks.AWSConfigProvider{ + STSClient: fakeSTS, + } t.Run(test.name, func(t *testing.T) { t.Helper() - fetchers := mustMakeAWSFetchers(t, test.inputClients, test.inputMatchers, "" /* discovery config */) + fetchers := mustMakeAWSFetchers(t, test.fetcherCfg, test.inputMatchers, "" /* discovery config */) require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) }) t.Run(test.name+" with assume role", func(t *testing.T) { t.Helper() + fakeSTS.ResetAssumeRoleHistory() matchers := copyAWSMatchersWithAssumeRole(testAssumeRole, test.inputMatchers...) wantDBs := copyDatabasesWithAWSAssumeRole(testAssumeRole, test.wantDatabases...) - fetchers := mustMakeAWSFetchers(t, test.inputClients, matchers, "" /* discovery config */) + fetchers := mustMakeAWSFetchers(t, test.fetcherCfg, matchers, "" /* discovery config */) require.ElementsMatch(t, wantDBs, mustGetDatabases(t, fetchers)) - require.Equal(t, []string{testAssumeRole.RoleARN}, stsMock.GetAssumedRoleARNs()) - require.Equal(t, []string{testAssumeRole.ExternalID}, stsMock.GetAssumedRoleExternalIDs()) + require.Equal(t, []string{testAssumeRole.RoleARN}, fakeSTS.GetAssumedRoleARNs()) + require.Equal(t, []string{testAssumeRole.ExternalID}, fakeSTS.GetAssumedRoleExternalIDs()) }) } } diff --git a/lib/srv/discovery/kube_integration_watcher_test.go b/lib/srv/discovery/kube_integration_watcher_test.go index bde45665b2a28..423339678ae8d 100644 --- a/lib/srv/discovery/kube_integration_watcher_test.go +++ b/lib/srv/discovery/kube_integration_watcher_test.go @@ -56,19 +56,19 @@ import ( func TestServer_getKubeFetchers(t *testing.T) { eks1, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{ - ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSMock{}}, + ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, FilterLabels: types.Labels{"l1": []string{"v1"}}, Region: "region1", }) require.NoError(t, err) eks2, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{ - ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSMock{}}, + ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, FilterLabels: types.Labels{"l1": []string{"v1"}}, Region: "region1", Integration: "aws1"}) require.NoError(t, err) eks3, err := fetchers.NewEKSFetcher(fetchers.EKSFetcherConfig{ - ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSMock{}}, + ClientGetter: &cloud.TestCloudClients{STS: &mocks.STSClientV1{}}, FilterLabels: types.Labels{"l1": []string{"v1"}}, Region: "region1", Integration: "aws1"}) @@ -314,7 +314,7 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { t.Parallel() testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSMock{}, + STS: &mocks.STSClientV1{}, EKS: &mockEKSAPI{ clusters: eksMockClusters[:2], },