Skip to content

Commit

Permalink
Support empty OAuth2 inline secrets (prometheus#547)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Hrabovcak <[email protected]>
  • Loading branch information
TheSpiritXIII authored Jan 30, 2024
1 parent bd0376d commit a3bdb9e
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 66 deletions.
7 changes: 3 additions & 4 deletions config/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ func SerialNumber() *big.Int {
serialNumber.Add(&serial, big.NewInt(1))

return &serial

}

func GenerateCertificateAuthority(commonName string, parentCert *x509.Certificate, parentKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) {
Expand Down Expand Up @@ -170,7 +169,7 @@ func writeCertificateAndKey(path string, cert *x509.Certificate, key *rsa.Privat
return err
}

if err := os.WriteFile(fmt.Sprintf("%s.crt", path), b.Bytes(), 0644); err != nil {
if err := os.WriteFile(fmt.Sprintf("%s.crt", path), b.Bytes(), 0o644); err != nil {
return err
}

Expand All @@ -179,7 +178,7 @@ func writeCertificateAndKey(path string, cert *x509.Certificate, key *rsa.Privat
return err
}

if err := os.WriteFile(fmt.Sprintf("%s.key", path), b.Bytes(), 0644); err != nil {
if err := os.WriteFile(fmt.Sprintf("%s.key", path), b.Bytes(), 0o644); err != nil {
return err
}

Expand Down Expand Up @@ -239,7 +238,7 @@ func main() {
log.Fatal(err)
}

if err := os.WriteFile("testdata/tls-ca-chain.pem", b.Bytes(), 0644); err != nil {
if err := os.WriteFile("testdata/tls-ca-chain.pem", b.Bytes(), 0o644); err != nil {
log.Fatal(err)
}
}
10 changes: 3 additions & 7 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,6 @@ func (c *HTTPClientConfig) Validate() error {
if len(c.OAuth2.ClientID) == 0 {
return fmt.Errorf("oauth2 client_id must be configured")
}
if len(c.OAuth2.ClientSecret) == 0 && len(c.OAuth2.ClientSecretFile) == 0 {
return fmt.Errorf("either oauth2 client_secret or client_secret_file must be configured")
}
if len(c.OAuth2.TokenURL) == 0 {
return fmt.Errorf("oauth2 token_url must be configured")
}
Expand Down Expand Up @@ -729,13 +726,12 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
rt.mtx.RLock()
changed = secret != rt.secret
rt.mtx.RUnlock()
} else {
// Either an inline secret or nothing (use an empty string) was provided.
secret = string(rt.config.ClientSecret)
}

if changed || rt.rt == nil {
if rt.config.ClientSecret != "" {
secret = string(rt.config.ClientSecret)
}

config := &clientcredentials.Config{
ClientID: rt.config.ClientID,
ClientSecret: secret,
Expand Down
226 changes: 174 additions & 52 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ var invalidHTTPClientConfigs = []struct {
httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml",
errMsg: "oauth2 client_id must be configured",
},
{
httpClientConfigFile: "testdata/http.conf.oauth2-no-client-secret.bad.yaml",
errMsg: "either oauth2 client_secret or client_secret_file must be configured",
},
{
httpClientConfigFile: "testdata/http.conf.oauth2-no-token-url.bad.yaml",
errMsg: "oauth2 token_url must be configured",
Expand Down Expand Up @@ -423,6 +419,46 @@ func TestNewClientFromConfig(t *testing.T) {
}
},
},
{
clientConfig: HTTPClientConfig{
OAuth2: &OAuth2{
ClientID: "ExpectedUsername",
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false,
},
},
TLSConfig: TLSConfig{
CAFile: TLSCAChainPath,
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false,
},
},
handler: func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: ExpectedAccessToken,
TokenType: "Bearer",
})
w.Header().Add("Content-Type", "application/json")
_, _ = w.Write(res)

default:
authorization := r.Header.Get("Authorization")
if authorization != "Bearer "+ExpectedAccessToken {
fmt.Fprintf(w, "Expected Authorization header %q, got %q", "Bearer "+ExpectedAccessToken, authorization)
} else {
fmt.Fprint(w, ExpectedMessage)
}
}
},
},
{
clientConfig: HTTPClientConfig{
OAuth2: &OAuth2{
Expand Down Expand Up @@ -1448,38 +1484,81 @@ type oauth2TestServerResponse struct {
TokenType string `json:"token_type"`
}

func TestOAuth2(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
type testOAuthServer struct {
tokenTS *httptest.Server
ts *httptest.Server
}

// newTestOAuthServer returns a new test server with the expected base64 encoded client ID and secret.
func newTestOAuthServer(t testing.TB, expectedAuth *string) testOAuthServer {
var previousAuth string
tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth != *expectedAuth {
t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth)
}
if auth == previousAuth {
t.Fatal("token endpoint called twice")
}
previousAuth = auth
res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
})
w.Header().Add("Content-Type", "application/json")
_, _ = w.Write(res)
}))
defer ts.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth != "Bearer 12345" {
t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth)
}
fmt.Fprintln(w, "Hello, client")
}))
return testOAuthServer{
tokenTS: tokenTS,
ts: ts,
}
}

func (s *testOAuthServer) url() string {
return s.ts.URL
}

func (s *testOAuthServer) tokenURL() string {
return s.tokenTS.URL
}

func (s *testOAuthServer) close() {
s.tokenTS.Close()
s.ts.Close()
}

func TestOAuth2(t *testing.T) {
var expectedAuth string
ts := newTestOAuthServer(t, &expectedAuth)
defer ts.close()

yamlConfig := fmt.Sprintf(`
client_id: 1
client_secret: 2
scopes:
- A
- B
token_url: %s/token
token_url: %s
endpoint_params:
hi: hello
`, ts.URL)
`, ts.tokenURL())
expectedConfig := OAuth2{
ClientID: "1",
ClientSecret: "2",
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: fmt.Sprintf("%s/token", ts.URL),
TokenURL: ts.tokenURL(),
}

var unmarshalledConfig OAuth2
err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig)
if err != nil {
if err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig); err != nil {
t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
}
if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
Expand All @@ -1491,9 +1570,59 @@ endpoint_params:
client := http.Client{
Transport: rt,
}
resp, _ := client.Get(ts.URL)

// Default secret.
expectedAuth = "Basic MToy"
resp, err := client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

authorization := resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer', got '%s'", authorization)
}

// Making a second request with the same secret should not re-call the token API.
_, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

// Empty secret.
expectedAuth = "Basic MTo="
expectedConfig.ClientSecret = ""
resp, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

authorization = resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}

// Making a second request with the same secret should not re-call the token API.
resp, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

// Update secret.
expectedAuth = "Basic MToxMjM0NTY3"
expectedConfig.ClientSecret = "1234567"
_, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

// Making a second request with the same secret should not re-call the token API.
_, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

authorization = resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}
Expand Down Expand Up @@ -1543,33 +1672,9 @@ func TestOAuth2UserAgent(t *testing.T) {
}

func TestOAuth2WithFile(t *testing.T) {
var expectedAuth *string
var previousAuth string
tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth != *expectedAuth {
t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth)
}
if auth == previousAuth {
t.Fatal("token endpoint called twice")
}
previousAuth = auth
res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
})
w.Header().Add("Content-Type", "application/json")
_, _ = w.Write(res)
}))
defer tokenTS.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth != "Bearer 12345" {
t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth)
}
fmt.Fprintln(w, "Hello, client")
}))
defer ts.Close()
var expectedAuth string
ts := newTestOAuthServer(t, &expectedAuth)
defer ts.close()

secretFile, err := os.CreateTemp("", "oauth2_secret")
if err != nil {
Expand All @@ -1586,13 +1691,13 @@ scopes:
token_url: %s
endpoint_params:
hi: hello
`, secretFile.Name(), tokenTS.URL)
`, secretFile.Name(), ts.tokenURL())
expectedConfig := OAuth2{
ClientID: "1",
ClientSecretFile: secretFile.Name(),
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: tokenTS.URL,
TokenURL: ts.tokenURL(),
}

var unmarshalledConfig OAuth2
Expand All @@ -1610,40 +1715,57 @@ endpoint_params:
Transport: rt,
}

tk := "Basic MToxMjM0NTY="
expectedAuth = &tk
// Empty secret file.
expectedAuth = "Basic MTo="
resp, err := client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

authorization := resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer', got '%s'", authorization)
}

// Making a second request with the same file content should not re-call the token API.
_, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

// File populated.
expectedAuth = "Basic MToxMjM0NTY="
if _, err := secretFile.Write([]byte("123456")); err != nil {
t.Fatal(err)
}
resp, err := client.Get(ts.URL)
resp, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

authorization := resp.Request.Header.Get("Authorization")
authorization = resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}

// Making a second request with the same file content should not re-call the token API.
resp, err = client.Get(ts.URL)
resp, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

tk = "Basic MToxMjM0NTY3"
expectedAuth = &tk
// Update file.
expectedAuth = "Basic MToxMjM0NTY3"
if _, err := secretFile.Write([]byte("7")); err != nil {
t.Fatal(err)
}

_, err = client.Get(ts.URL)
_, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}

// Making a second request with the same file content should not re-call the token API.
_, err = client.Get(ts.URL)
_, err = client.Get(ts.url())
if err != nil {
t.Fatal(err)
}
Expand Down
3 changes: 0 additions & 3 deletions config/testdata/http.conf.oauth2-no-client-secret.bad.yaml

This file was deleted.

0 comments on commit a3bdb9e

Please sign in to comment.