Skip to content

Commit

Permalink
feat: allow configuring default cloud environment (Azure#1555)
Browse files Browse the repository at this point in the history
Signed-off-by: Anish Ramasekar <[email protected]>
  • Loading branch information
aramase authored May 29, 2024
1 parent 8296abd commit 4ef5092
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
16 changes: 13 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"syscall"
"time"

"github.com/Azure/go-autorest/autorest/azure"

"github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/metrics"
"github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/server"
"github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/utils"
Expand Down Expand Up @@ -49,6 +51,9 @@ var (
constructPEMChain = flag.Bool("construct-pem-chain", true, "explicitly reconstruct the pem chain in the order: SERVER, INTERMEDIATE, ROOT")
writeCertAndKeyInSeparateFiles = flag.Bool("write-cert-and-key-in-separate-files", false,
"Write cert and key in separate files. The individual files will be named as <secret-name>.crt and <secret-name>.key. These files will be created in addition to the single file.")

cloudName = flag.String("cloud-name", "AzurePublicCloud", "default cloud environment to use for Azure SDK if not provided in the SecretProviderClass. "+
"Allowed values: AzurePublicCloud, AzureUSGovernmentCloud, AzureChinaCloud, AzureGermanCloud or AzureStackCloud")
)

func main() {
Expand All @@ -75,6 +80,12 @@ func main() {
}
klog.InfoS("Starting Azure Key Vault Provider", "version", version.BuildVersion)

cloudEnv, err := azure.EnvironmentFromName(*cloudName)
if err != nil {
klog.ErrorS(err, "failed validating default cloud environment", "cloudName", *cloudName)
os.Exit(1)
}

if *enableProfile {
klog.InfoS("Starting profiling", "port", *profilePort)
go func() {
Expand All @@ -86,8 +97,7 @@ func main() {
}()
}
// initialize metrics exporter before creating measurements
err := metrics.InitMetricsExporter(*metricsBackend, *prometheusPort)
if err != nil {
if err = metrics.InitMetricsExporter(*metricsBackend, *prometheusPort); err != nil {
klog.ErrorS(err, "failed to initialize metrics exporter")
os.Exit(1)
}
Expand Down Expand Up @@ -130,7 +140,7 @@ func main() {
grpc.UnaryInterceptor(utils.LogInterceptor()),
}
s := grpc.NewServer(opts...)
csiDriverProviderServer := server.New(*constructPEMChain, *writeCertAndKeyInSeparateFiles)
csiDriverProviderServer := server.New(*constructPEMChain, *writeCertAndKeyInSeparateFiles, cloudEnv)
k8spb.RegisterCSIDriverProviderServer(s, csiDriverProviderServer)
// Register the health service.
grpc_health_v1.RegisterHealthServer(s, csiDriverProviderServer)
Expand Down
19 changes: 9 additions & 10 deletions pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ type provider struct {

constructPEMChain bool
writeCertAndKeyInSeparateFiles bool

defaultCloudEnvironment azure.Environment
}

// mountConfig holds the information for the mount event
type mountConfig struct {
// the name of the Azure Key Vault instance
keyvaultName string
// the type of azure cloud based on azure go sdk
azureCloudEnvironment *azure.Environment
azureCloudEnvironment azure.Environment
// authConfig is the config parameters for accessing Key Vault
authConfig auth.Config
// tenantID in AAD
Expand All @@ -67,24 +69,21 @@ type keyvaultObject struct {
}

// NewProvider creates a new provider
func NewProvider(constructPEMChain, writeCertAndKeyInSeparateFiles bool) Interface {
func NewProvider(constructPEMChain, writeCertAndKeyInSeparateFiles bool, defaultCloudEnvironment azure.Environment) Interface {
return &provider{
reporter: metrics.NewStatsReporter(),
constructPEMChain: constructPEMChain,
writeCertAndKeyInSeparateFiles: writeCertAndKeyInSeparateFiles,
defaultCloudEnvironment: defaultCloudEnvironment,
}
}

// parseAzureEnvironment returns azure environment by name
func parseAzureEnvironment(cloudName string) (*azure.Environment, error) {
var env azure.Environment
var err error
func (p *provider) parseAzureEnvironment(cloudName string) (azure.Environment, error) {
if cloudName == "" {
env = azure.PublicCloud
} else {
env, err = azure.EnvironmentFromName(cloudName)
return p.defaultCloudEnvironment, nil
}
return &env, err
return azure.EnvironmentFromName(cloudName)
}

func (mc *mountConfig) initializeKvClient(vaultURI string) (KeyVault, error) {
Expand Down Expand Up @@ -148,7 +147,7 @@ func (p *provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec
if err != nil {
return nil, fmt.Errorf("failed to set AZURE_ENVIRONMENT_FILEPATH env to %s, error %w", cloudEnvFileName, err)
}
azureCloudEnv, err := parseAzureEnvironment(cloudName)
azureCloudEnv, err := p.parseAzureEnvironment(cloudName)
if err != nil {
return nil, fmt.Errorf("cloudName %s is not valid, error: %w", cloudName, err)
}
Expand Down
16 changes: 10 additions & 6 deletions pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
func TestGetVaultURL(t *testing.T) {
testEnvs := []string{"", "AZUREPUBLICCLOUD", "AZURECHINACLOUD", "AZUREGERMANCLOUD", "AZUREUSGOVERNMENTCLOUD"}
vaultDNSSuffix := []string{"vault.azure.net", "vault.azure.net", "vault.azure.cn", "vault.microsoftazure.de", "vault.usgovcloudapi.net"}
testProvider := provider{defaultCloudEnvironment: azure.PublicCloud}

cases := []struct {
desc string
Expand Down Expand Up @@ -69,7 +70,7 @@ func TestGetVaultURL(t *testing.T) {
}

for idx := range testEnvs {
azCloudEnv, err := parseAzureEnvironment(testEnvs[idx])
azCloudEnv, err := testProvider.parseAzureEnvironment(testEnvs[idx])
if err != nil {
t.Fatalf("Error parsing cloud environment %v", err)
}
Expand All @@ -88,8 +89,10 @@ func TestGetVaultURL(t *testing.T) {

func TestParseAzureEnvironment(t *testing.T) {
envNamesArray := []string{"AZURECHINACLOUD", "AZUREGERMANCLOUD", "AZUREPUBLICCLOUD", "AZUREUSGOVERNMENTCLOUD", ""}
testProvider := provider{defaultCloudEnvironment: azure.PublicCloud}

for _, envName := range envNamesArray {
azureEnv, err := parseAzureEnvironment(envName)
azureEnv, err := testProvider.parseAzureEnvironment(envName)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
Expand All @@ -101,7 +104,7 @@ func TestParseAzureEnvironment(t *testing.T) {
}

wrongEnvName := "AZUREWRONGCLOUD"
_, err := parseAzureEnvironment(wrongEnvName)
_, err := testProvider.parseAzureEnvironment(wrongEnvName)
if err == nil {
t.Fatalf("expected error for wrong azure environment name")
}
Expand Down Expand Up @@ -226,6 +229,7 @@ lKn75l/9h0PwiiPaI0TGKN2O8AwvhGGwDElmFhYtXedbbaST6rbVRDUj
}

func TestParseAzureEnvironmentAzureStackCloud(t *testing.T) {
testProvider := provider{defaultCloudEnvironment: azure.PublicCloud}
azureStackCloudEnvName := "AZURESTACKCLOUD"
file, err := os.CreateTemp("", "ut")
defer os.Remove(file.Name())
Expand All @@ -236,7 +240,7 @@ func TestParseAzureEnvironmentAzureStackCloud(t *testing.T) {
if err != nil {
t.Fatalf("expected error to be nil, got: %+v", err)
}
_, err = parseAzureEnvironment(azureStackCloudEnvName)
_, err = testProvider.parseAzureEnvironment(azureStackCloudEnvName)
if err == nil {
t.Fatalf("expected error to be not nil as AZURE_ENVIRONMENT_FILEPATH is not set")
}
Expand All @@ -246,7 +250,7 @@ func TestParseAzureEnvironmentAzureStackCloud(t *testing.T) {
if err != nil {
t.Fatalf("expected error to be nil, got: %+v", err)
}
env, err := parseAzureEnvironment(azureStackCloudEnvName)
env, err := testProvider.parseAzureEnvironment(azureStackCloudEnvName)
if err != nil {
t.Fatalf("expected error to be nil, got: %+v", err)
}
Expand Down Expand Up @@ -1250,7 +1254,7 @@ func TestGetSecretsStoreObjectContent(t *testing.T) {

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
p := NewProvider(false, false)
p := NewProvider(false, false, azure.PublicCloud)

_, err := p.GetSecretsStoreObjectContent(testContext(t), tc.parameters, tc.secrets, 0420)
if tc.expectedErr {
Expand Down
6 changes: 4 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"os"

"github.com/Azure/go-autorest/autorest/azure"

"github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/provider"
"github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/version"

Expand All @@ -24,9 +26,9 @@ type CSIDriverProviderServer struct {
}

// New returns an instance of CSIDriverProviderServer
func New(constructPEMChain, writeCertAndKeyInSeparateFiles bool) *CSIDriverProviderServer {
func New(constructPEMChain, writeCertAndKeyInSeparateFiles bool, defaultCloudEnvironment azure.Environment) *CSIDriverProviderServer {
return &CSIDriverProviderServer{
provider: provider.NewProvider(constructPEMChain, writeCertAndKeyInSeparateFiles),
provider: provider.NewProvider(constructPEMChain, writeCertAndKeyInSeparateFiles, defaultCloudEnvironment),
}
}

Expand Down

0 comments on commit 4ef5092

Please sign in to comment.