Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] GitHub in house OIDC #977

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
28 changes: 1 addition & 27 deletions config/auth_azure_github_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
)

Expand All @@ -28,7 +27,7 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config)
return nil, nil
}

idToken, err := requestIDToken(ctx, cfg)
idToken, err := cfg.getAllOIDCSuppliers().GetOIDCToken(ctx, "api://AzureADTokenExchange")
if err != nil {
return nil, err
}
Expand All @@ -47,31 +46,6 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config)
return credentials.NewOAuthCredentialsProvider(refreshableVisitor(ts), ts.Token), nil
}

// requestIDToken requests an ID token from the Github Action.
func requestIDToken(ctx context.Context, cfg *Config) (string, error) {
if cfg.ActionsIDTokenRequestURL == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestURL, likely not calling from a Github action")
return "", nil
}
if cfg.ActionsIDTokenRequestToken == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestToken, likely not calling from a Github action")
return "", nil
}

resp := struct { // anonymous struct to parse the response
Value string `json:"value"`
}{}
err := cfg.refreshClient.Do(ctx, "GET", fmt.Sprintf("%s&audience=api://AzureADTokenExchange", cfg.ActionsIDTokenRequestURL),
httpclient.WithRequestHeader("Authorization", fmt.Sprintf("Bearer %s", cfg.ActionsIDTokenRequestToken)),
httpclient.WithResponseUnmarshal(&resp),
)
if err != nil {
return "", fmt.Errorf("failed to request ID token from %s: %w", cfg.ActionsIDTokenRequestURL, err)
}

return resp.Value, nil
}

// azureOIDCTokenSource implements [oauth2.TokenSource] to obtain Azure auth
// tokens from an ID token.
type azureOIDCTokenSource struct {
Expand Down
65 changes: 65 additions & 0 deletions config/auth_databricks_oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package config

import (
"context"
"net/url"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

const jwtBearerGrantTypeURN = "urn:ietf:params:oauth:grant-type:jwt-bearer"

type DatabricksOIDCCredentials struct{}

// Configure implements CredentialsStrategy.
func (d DatabricksOIDCCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.Host == "" || cfg.ClientID == "" {
return nil, nil
}

// Get the OIDC token from the environment.
// TODO: trim the first 8 characters (https://) from the host
audience := cfg.CanonicalHostName()
if cfg.IsAccountClient() {
audience = cfg.AccountID
}
idToken, err := cfg.getAllOIDCSuppliers().GetOIDCToken(ctx, audience)
if err != nil {
return nil, err
}
if idToken == "" {
logger.Debugf(ctx, "No OIDC token found")
return nil, nil
}

endpoints, err := oidcEndpoints(ctx, cfg)
if err != nil {
return nil, err
}

tsConfig := clientcredentials.Config{
ClientID: cfg.ClientID,
ClientSecret: "",
AuthStyle: oauth2.AuthStyleInParams,
TokenURL: endpoints.TokenEndpoint,
Scopes: []string{"all-apis"},
EndpointParams: url.Values{
"grant_type": {jwtBearerGrantTypeURN},
"assertion": {idToken},
},
}
ts := tsConfig.TokenSource(httpclient.WithDebug(ctx, true))
visitor := refreshableVisitor(ts)
return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil
}

// Name implements CredentialsStrategy.
func (d DatabricksOIDCCredentials) Name() string {
return "inhouse-oidc"
}

var _ CredentialsStrategy = DatabricksOIDCCredentials{}
1 change: 1 addition & 0 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ var authProviders = []CredentialsStrategy{
PatCredentials{},
BasicCredentials{},
M2mCredentials{},
DatabricksOIDCCredentials{},
DatabricksCliCredentials{},
MetadataServiceCredentials{},

Expand Down
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (c *Config) IsAccountClient() bool {
return true
}
}
return false
return strings.HasPrefix(c.Host, "https://accounts-")
}

func (c *Config) EnsureResolved() error {
Expand Down
99 changes: 99 additions & 0 deletions config/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package config

import (
"context"
"fmt"
"os"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)

type oidcTokenSupplier interface {
Name() string

// GetOIDCToken returns an OIDC token for the given audience.
GetOIDCToken(ctx context.Context, audience string) (string, error)
}

type githubOIDCTokenSupplier struct {
idTokenRequestURL string
idTokenRequestToken string
client *httpclient.ApiClient
}

func githubOIDCTokenSupplierFromConfig(cfg *Config) githubOIDCTokenSupplier {
return githubOIDCTokenSupplier{
idTokenRequestURL: cfg.ActionsIDTokenRequestURL,
idTokenRequestToken: cfg.ActionsIDTokenRequestToken,
client: cfg.refreshClient,
}
}

func (g githubOIDCTokenSupplier) Name() string {
return "github"
}

// requestIDToken requests an ID token from the Github Action.
func (g githubOIDCTokenSupplier) GetOIDCToken(ctx context.Context, audience string) (string, error) {
if g.idTokenRequestURL == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestURL, likely not calling from a Github action")
return "", nil
}
if g.idTokenRequestToken == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestToken, likely not calling from a Github action")
return "", nil
}
url := g.idTokenRequestURL
if audience != "" {
url = fmt.Sprintf("%s&audience=%s", url, audience)
}
resp := struct { // anonymous struct to parse the response
Value string `json:"value"`
}{}
err := g.client.Do(ctx, "GET", url,
httpclient.WithRequestHeader("Authorization", fmt.Sprintf("Bearer %s", g.idTokenRequestToken)),
httpclient.WithResponseUnmarshal(&resp),
)
if err != nil {
return "", fmt.Errorf("failed to request ID token from %s: %w", g.idTokenRequestURL, err)
}

return resp.Value, nil
}

var _ oidcTokenSupplier = githubOIDCTokenSupplier{}

type azureDevOpsOIDCTokenSupplier struct{}

func (a azureDevOpsOIDCTokenSupplier) Name() string {
return "azure-devops"
}

func (a azureDevOpsOIDCTokenSupplier) GetOIDCToken(ctx context.Context, audience string) (string, error) {
return os.Getenv("idToken"), nil
}

type oidcTokenSuppliers []oidcTokenSupplier

func (c *Config) getAllOIDCSuppliers() oidcTokenSuppliers {
return []oidcTokenSupplier{
githubOIDCTokenSupplierFromConfig(c),
azureDevOpsOIDCTokenSupplier{},
}
}

func (o oidcTokenSuppliers) GetOIDCToken(ctx context.Context, audience string) (string, error) {
for _, s := range o {
token, err := s.GetOIDCToken(ctx, audience)
if err != nil {
return "", err
}
if token != "" {
logger.Debugf(ctx, "OIDC token found from %s", s.Name())
return token, nil
}
logger.Debugf(ctx, "No OIDC token found from %s", s.Name())
}
return "", nil
}
29 changes: 25 additions & 4 deletions httpclient/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -327,16 +328,36 @@ func (c *ApiClient) recordRequestLog(
logger.Debugf(ctx, "%s", message)
}

type debugKeyType int

const debugKey debugKeyType = 1

func WithDebug(ctx context.Context, debug bool) context.Context {
return context.WithValue(ctx, debugKey, debug)
}

func IsDebug(ctx context.Context) bool {
debug, ok := ctx.Value(debugKey).(bool)
return ok && debug
}

func getDebugBody(ctx context.Context, body io.Reader) (io.Reader, []byte) {
if IsDebug(ctx) {
debugBytes, _ := io.ReadAll(body)
return strings.NewReader(string(debugBytes)), debugBytes
}
return body, []byte("<http.RoundTripper>")
}

// RoundTrip implements http.RoundTripper to integrate with golang.org/x/oauth2
func (c *ApiClient) RoundTrip(request *http.Request) (*http.Response, error) {
ctx := request.Context()
requestURL := request.URL.String()
body, debugBytes := getDebugBody(ctx, request.Body)
resp, err := retries.Poll(ctx, c.config.RetryTimeout,
c.attempt(ctx, request.Method, requestURL, common.RequestBody{
Reader: request.Body,
// DO NOT DECODE BODY, because it may contain sensitive payload,
// like Azure Service Principal in a multipart/form-data body.
DebugBytes: []byte("<http.RoundTripper>"),
Reader: body,
DebugBytes: debugBytes,
}, func(r *http.Request) error {
r.Header = request.Header
return nil
Expand Down
Loading