Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyucht committed May 6, 2024
1 parent 959774b commit 441c2e8
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 11 deletions.
64 changes: 61 additions & 3 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"context"
"errors"
"net/http"
"os"
"path/filepath"
Expand All @@ -13,9 +14,53 @@ import (
"github.com/stretchr/testify/require"
)

var azDummy = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/"}
var azDummyWithResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123"}
var azDummyWitInvalidResourceId = &Config{Host: "https://adb-xyz.c.azuredatabricks.net/", AzureResourceID: "invalidResourceId"}
type mockTransport struct {
resp *http.Response
err error
}

func (m mockTransport) RoundTrip(*http.Request) (*http.Response, error) {
if m.err != nil {
return nil, m.err
}
return m.resp, nil
}

func makeClient(r *http.Response) *http.Client {
return &http.Client{
Transport: mockTransport{resp: r},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}

func makeFailingClient(err error) *http.Client {
return &http.Client{
Transport: mockTransport{err: err},
}
}

var redirectResponse = &http.Response{
StatusCode: 302,
Header: http.Header{"Location": []string{"https://login.microsoftonline.com/123-abc/oauth2/token"}},
}
var errDummy = errors.New("failed to get login endpoint")

var azDummy = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeClient(redirectResponse),
}
var azDummyWithResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123",
azureTenantIdFetchClient: makeClient(redirectResponse),
}
var azDummyWitInvalidResourceId = &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureResourceID: "invalidResourceId",
azureTenantIdFetchClient: makeClient(redirectResponse),
}

// testdataPath returns the PATH to use for the duration of a test.
// It must only return absolute directories because Go refuses to run
Expand Down Expand Up @@ -187,6 +232,19 @@ func TestAzureCliCredentials_CorruptExpire(t *testing.T) {
assert.EqualError(t, err, "cannot parse expiry: parsing time \"\" as \"2006-01-02 15:04:05.999999\": cannot parse \"\" as \"2006\"")
}

func TestAzureCliCredentials_DoNotFetchIfTenantIdAlreadySet(t *testing.T) {
env.CleanupEnvironment(t)
os.Setenv("PATH", testdataPath())
aa := AzureCliCredentials{}
_, err := aa.Configure(context.Background(), &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
AzureTenantID: "123",
AzureResourceID: "/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123",
azureTenantIdFetchClient: makeFailingClient(errDummy),
})
assert.NoError(t, err)
}

// TODO: this test should rather be on sequencing
// func TestConfigureWithAzureCLI_SP(t *testing.T) {
// aa := DatabricksClient{
Expand Down
4 changes: 4 additions & 0 deletions config/auth_permutations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func (cf configFixture) configureProviderAndReturnConfig(t *testing.T) (*Config,
AzureTenantID: cf.AzureTenantID,
AzureResourceID: cf.AzureResourceID,
AuthType: cf.AuthType,
azureTenantIdFetchClient: makeClient(&http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: http.Header{"Location": []string{"https://login.microsoftonline.com/tenant_id/abc"}},
}),
}
if client.IsAzure() {
client.DatabricksEnvironment = &DatabricksEnvironment{
Expand Down
11 changes: 3 additions & 8 deletions config/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,14 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol
// Azure Databricks login endpoint). Here, the redirect is not followed, but the
// tenant ID is extracted from the URL.
func (c *Config) loadAzureTenantId(ctx context.Context) error {
if !c.IsAzure() || c.AzureTenantID != "" {
if !c.IsAzure() || c.AzureTenantID != "" || c.Host == "" {
return nil
}
req, err := http.NewRequestWithContext(ctx, "GET", c.Host+"/aad/auth", nil)
req, err := http.NewRequestWithContext(ctx, "GET", c.CanonicalHostName()+"/aad/auth", nil)
if err != nil {
return err
}
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
res, err := client.Do(req)
res, err := c.azureTenantIdFetchClient.Do(req)
if err != nil && !errors.Is(err, http.ErrUseLastResponse) {
return err
}
Expand Down
44 changes: 44 additions & 0 deletions config/azure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package config

import (
"context"
"errors"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestLoadAzureTenantId(t *testing.T) {
c := &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeClient(&http.Response{
StatusCode: 302,
Header: http.Header{"Location": []string{"https://login.microsoftonline.com/123-abc/oauth2/token"}},
}),
}
err := c.loadAzureTenantId(context.Background())
assert.NoError(t, err)
assert.Equal(t, c.AzureTenantID, "123-abc")
}

func TestLoadAzureTenantId_Failure(t *testing.T) {
testErr := errors.New("Failed to fetch login page")
c := &Config{
Host: "https://adb-xyz.c.azuredatabricks.net/",
azureTenantIdFetchClient: makeFailingClient(testErr),
}
err := c.loadAzureTenantId(context.Background())
assert.ErrorIs(t, err, testErr)
}

func TestLoadAzureTenantId_SkipNotInAzure(t *testing.T) {
testErr := errors.New("Failed to fetch login page")
c := &Config{
Host: "https://test.cloud.databricks.com/",
azureTenantIdFetchClient: makeFailingClient(testErr),
}
err := c.loadAzureTenantId(context.Background())
assert.NoError(t, err)
assert.Empty(t, c.AzureTenantID)
}
11 changes: 11 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ type Config struct {
// internal background context used for authentication purposes together with refreshClient
refreshCtx context.Context

// internal client used to fetch Azure Tenant ID from Databricks Login endpoint
azureTenantIdFetchClient *http.Client

// marker for testing fixture
isTesting bool

Expand Down Expand Up @@ -288,6 +291,14 @@ func (c *Config) EnsureResolved() error {
"rate limit",
},
})
if c.azureTenantIdFetchClient == nil {
c.azureTenantIdFetchClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Do not follow redirects
return http.ErrUseLastResponse
},
}
}
c.resolved = true
return nil
}
Expand Down

0 comments on commit 441c2e8

Please sign in to comment.