diff --git a/config/generate.go b/config/generate.go index 0033dd75..dbcf9be5 100644 --- a/config/generate.go +++ b/config/generate.go @@ -71,7 +71,6 @@ func SerialNumber() *big.Int { serialNumber.Add(&serial, big.NewInt(1)) return &serial - } func GenerateCertificateAuthority(commonName string, parentCert *x509.Certificate, parentKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) { @@ -170,7 +169,7 @@ func writeCertificateAndKey(path string, cert *x509.Certificate, key *rsa.Privat return err } - if err := os.WriteFile(fmt.Sprintf("%s.crt", path), b.Bytes(), 0644); err != nil { + if err := os.WriteFile(fmt.Sprintf("%s.crt", path), b.Bytes(), 0o644); err != nil { return err } @@ -179,7 +178,7 @@ func writeCertificateAndKey(path string, cert *x509.Certificate, key *rsa.Privat return err } - if err := os.WriteFile(fmt.Sprintf("%s.key", path), b.Bytes(), 0644); err != nil { + if err := os.WriteFile(fmt.Sprintf("%s.key", path), b.Bytes(), 0o644); err != nil { return err } @@ -239,7 +238,7 @@ func main() { log.Fatal(err) } - if err := os.WriteFile("testdata/tls-ca-chain.pem", b.Bytes(), 0644); err != nil { + if err := os.WriteFile("testdata/tls-ca-chain.pem", b.Bytes(), 0o644); err != nil { log.Fatal(err) } } diff --git a/config/http_config.go b/config/http_config.go index 7a67a0a6..f295e917 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -378,9 +378,6 @@ func (c *HTTPClientConfig) Validate() error { if len(c.OAuth2.ClientID) == 0 { return fmt.Errorf("oauth2 client_id must be configured") } - if len(c.OAuth2.ClientSecret) == 0 && len(c.OAuth2.ClientSecretFile) == 0 { - return fmt.Errorf("either oauth2 client_secret or client_secret_file must be configured") - } if len(c.OAuth2.TokenURL) == 0 { return fmt.Errorf("oauth2 token_url must be configured") } @@ -729,13 +726,12 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro rt.mtx.RLock() changed = secret != rt.secret rt.mtx.RUnlock() + } else { + // Either an inline secret or nothing (use an empty string) was provided. + secret = string(rt.config.ClientSecret) } if changed || rt.rt == nil { - if rt.config.ClientSecret != "" { - secret = string(rt.config.ClientSecret) - } - config := &clientcredentials.Config{ ClientID: rt.config.ClientID, ClientSecret: secret, diff --git a/config/http_config_test.go b/config/http_config_test.go index 7eeedfed..b0d3939f 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -115,10 +115,6 @@ var invalidHTTPClientConfigs = []struct { httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml", errMsg: "oauth2 client_id must be configured", }, - { - httpClientConfigFile: "testdata/http.conf.oauth2-no-client-secret.bad.yaml", - errMsg: "either oauth2 client_secret or client_secret_file must be configured", - }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-token-url.bad.yaml", errMsg: "oauth2 token_url must be configured", @@ -423,6 +419,46 @@ func TestNewClientFromConfig(t *testing.T) { } }, }, + { + clientConfig: HTTPClientConfig{ + OAuth2: &OAuth2{ + ClientID: "ExpectedUsername", + TLSConfig: TLSConfig{ + CAFile: TLSCAChainPath, + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false, + }, + }, + TLSConfig: TLSConfig{ + CAFile: TLSCAChainPath, + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false, + }, + }, + handler: func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + res, _ := json.Marshal(oauth2TestServerResponse{ + AccessToken: ExpectedAccessToken, + TokenType: "Bearer", + }) + w.Header().Add("Content-Type", "application/json") + _, _ = w.Write(res) + + default: + authorization := r.Header.Get("Authorization") + if authorization != "Bearer "+ExpectedAccessToken { + fmt.Fprintf(w, "Expected Authorization header %q, got %q", "Bearer "+ExpectedAccessToken, authorization) + } else { + fmt.Fprint(w, ExpectedMessage) + } + } + }, + }, { clientConfig: HTTPClientConfig{ OAuth2: &OAuth2{ @@ -1448,8 +1484,23 @@ type oauth2TestServerResponse struct { TokenType string `json:"token_type"` } -func TestOAuth2(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +type testOAuthServer struct { + tokenTS *httptest.Server + ts *httptest.Server +} + +// newTestOAuthServer returns a new test server with the expected base64 encoded client ID and secret. +func newTestOAuthServer(t testing.TB, expectedAuth *string) testOAuthServer { + var previousAuth string + tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != *expectedAuth { + t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth) + } + if auth == previousAuth { + t.Fatal("token endpoint called twice") + } + previousAuth = auth res, _ := json.Marshal(oauth2TestServerResponse{ AccessToken: "12345", TokenType: "Bearer", @@ -1457,7 +1508,36 @@ func TestOAuth2(t *testing.T) { w.Header().Add("Content-Type", "application/json") _, _ = w.Write(res) })) - defer ts.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer 12345" { + t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth) + } + fmt.Fprintln(w, "Hello, client") + })) + return testOAuthServer{ + tokenTS: tokenTS, + ts: ts, + } +} + +func (s *testOAuthServer) url() string { + return s.ts.URL +} + +func (s *testOAuthServer) tokenURL() string { + return s.tokenTS.URL +} + +func (s *testOAuthServer) close() { + s.tokenTS.Close() + s.ts.Close() +} + +func TestOAuth2(t *testing.T) { + var expectedAuth string + ts := newTestOAuthServer(t, &expectedAuth) + defer ts.close() yamlConfig := fmt.Sprintf(` client_id: 1 @@ -1465,21 +1545,20 @@ client_secret: 2 scopes: - A - B -token_url: %s/token +token_url: %s endpoint_params: hi: hello -`, ts.URL) +`, ts.tokenURL()) expectedConfig := OAuth2{ ClientID: "1", ClientSecret: "2", Scopes: []string{"A", "B"}, EndpointParams: map[string]string{"hi": "hello"}, - TokenURL: fmt.Sprintf("%s/token", ts.URL), + TokenURL: ts.tokenURL(), } var unmarshalledConfig OAuth2 - err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig) - if err != nil { + if err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig); err != nil { t.Fatalf("Expected no error unmarshalling yaml, got %v", err) } if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { @@ -1491,9 +1570,59 @@ endpoint_params: client := http.Client{ Transport: rt, } - resp, _ := client.Get(ts.URL) + + // Default secret. + expectedAuth = "Basic MToy" + resp, err := client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } authorization := resp.Request.Header.Get("Authorization") + if authorization != "Bearer 12345" { + t.Fatalf("Expected authorization header to be 'Bearer', got '%s'", authorization) + } + + // Making a second request with the same secret should not re-call the token API. + _, err = client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } + + // Empty secret. + expectedAuth = "Basic MTo=" + expectedConfig.ClientSecret = "" + resp, err = client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } + + authorization = resp.Request.Header.Get("Authorization") + if authorization != "Bearer 12345" { + t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) + } + + // Making a second request with the same secret should not re-call the token API. + resp, err = client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } + + // Update secret. + expectedAuth = "Basic MToxMjM0NTY3" + expectedConfig.ClientSecret = "1234567" + _, err = client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } + + // Making a second request with the same secret should not re-call the token API. + _, err = client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } + + authorization = resp.Request.Header.Get("Authorization") if authorization != "Bearer 12345" { t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) } @@ -1543,33 +1672,9 @@ func TestOAuth2UserAgent(t *testing.T) { } func TestOAuth2WithFile(t *testing.T) { - var expectedAuth *string - var previousAuth string - tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - if auth != *expectedAuth { - t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth) - } - if auth == previousAuth { - t.Fatal("token endpoint called twice") - } - previousAuth = auth - res, _ := json.Marshal(oauth2TestServerResponse{ - AccessToken: "12345", - TokenType: "Bearer", - }) - w.Header().Add("Content-Type", "application/json") - _, _ = w.Write(res) - })) - defer tokenTS.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - if auth != "Bearer 12345" { - t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth) - } - fmt.Fprintln(w, "Hello, client") - })) - defer ts.Close() + var expectedAuth string + ts := newTestOAuthServer(t, &expectedAuth) + defer ts.close() secretFile, err := os.CreateTemp("", "oauth2_secret") if err != nil { @@ -1586,13 +1691,13 @@ scopes: token_url: %s endpoint_params: hi: hello -`, secretFile.Name(), tokenTS.URL) +`, secretFile.Name(), ts.tokenURL()) expectedConfig := OAuth2{ ClientID: "1", ClientSecretFile: secretFile.Name(), Scopes: []string{"A", "B"}, EndpointParams: map[string]string{"hi": "hello"}, - TokenURL: tokenTS.URL, + TokenURL: ts.tokenURL(), } var unmarshalledConfig OAuth2 @@ -1610,40 +1715,57 @@ endpoint_params: Transport: rt, } - tk := "Basic MToxMjM0NTY=" - expectedAuth = &tk + // Empty secret file. + expectedAuth = "Basic MTo=" + resp, err := client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } + + authorization := resp.Request.Header.Get("Authorization") + if authorization != "Bearer 12345" { + t.Fatalf("Expected authorization header to be 'Bearer', got '%s'", authorization) + } + + // Making a second request with the same file content should not re-call the token API. + _, err = client.Get(ts.url()) + if err != nil { + t.Fatal(err) + } + + // File populated. + expectedAuth = "Basic MToxMjM0NTY=" if _, err := secretFile.Write([]byte("123456")); err != nil { t.Fatal(err) } - resp, err := client.Get(ts.URL) + resp, err = client.Get(ts.url()) if err != nil { t.Fatal(err) } - authorization := resp.Request.Header.Get("Authorization") + authorization = resp.Request.Header.Get("Authorization") if authorization != "Bearer 12345" { t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) } // Making a second request with the same file content should not re-call the token API. - resp, err = client.Get(ts.URL) + resp, err = client.Get(ts.url()) if err != nil { t.Fatal(err) } - tk = "Basic MToxMjM0NTY3" - expectedAuth = &tk + // Update file. + expectedAuth = "Basic MToxMjM0NTY3" if _, err := secretFile.Write([]byte("7")); err != nil { t.Fatal(err) } - - _, err = client.Get(ts.URL) + _, err = client.Get(ts.url()) if err != nil { t.Fatal(err) } // Making a second request with the same file content should not re-call the token API. - _, err = client.Get(ts.URL) + _, err = client.Get(ts.url()) if err != nil { t.Fatal(err) } diff --git a/config/testdata/http.conf.oauth2-no-client-secret.bad.yaml b/config/testdata/http.conf.oauth2-no-client-secret.bad.yaml deleted file mode 100644 index 774a5998..00000000 --- a/config/testdata/http.conf.oauth2-no-client-secret.bad.yaml +++ /dev/null @@ -1,3 +0,0 @@ -oauth2: - client_id: "myclientid" - token_url: "http://auth"