Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for mirror registries behind mutual TLS #35658

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 112 additions & 9 deletions internal/command/cliconfig/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,69 @@ func collectCredentialsFromEnv() map[svchost.Hostname]string {
return ret
}

func collectMTLSCredentialsFromEnv() map[svchost.Hostname]*svcauth.HostCredentialsMTLS {
// Prefixes for environment variables
const (
prefixCert = "TF_CLIENT_CERT_"
prefixKey = "TF_CLIENT_KEY_"
prefixCACert = "TF_CA_CERT_"
)

ret := make(map[svchost.Hostname]*svcauth.HostCredentialsMTLS)

for _, ev := range os.Environ() {
eqIdx := strings.Index(ev, "=")
if eqIdx < 0 {
continue
}
name := ev[:eqIdx]
value := ev[eqIdx+1:]

var rawHost, cert, key, caCert string

// Determine the type of credential and parse accordingly
if strings.HasPrefix(name, prefixCert) {
rawHost = name[len(prefixCert):]
cert = value
} else if strings.HasPrefix(name, prefixKey) {
rawHost = name[len(prefixKey):]
key = value
} else if strings.HasPrefix(name, prefixCACert) {
rawHost = name[len(prefixCACert):]
caCert = value
} else {
continue
}

// Normalize the hostname
rawHost = strings.ReplaceAll(rawHost, "__", "-")
rawHost = strings.ReplaceAll(rawHost, "_", ".")
dispHost := svchost.ForDisplay(rawHost)
hostname, err := svchost.ForComparison(dispHost)
if err != nil {
continue
}

// Retrieve or create the entry in the map
if ret[hostname] == nil {
ret[hostname] = &svcauth.HostCredentialsMTLS{}
}

// Update the corresponding fields
if cert != "" {
ret[hostname].ClientCert = cert
}
if key != "" {
ret[hostname].ClientKey = key
}
if caCert != "" {
ret[hostname].CACertificate = caCert
}
}

return ret
}

// hostCredentialsFromEnv returns a token credential by searching for a hostname-specific
// environment variable. The host parameter is expected to be in the "comparison" form,
// for example, hostnames containing non-ASCII characters like "café.fr"
Expand All @@ -183,11 +246,26 @@ func collectCredentialsFromEnv() map[svchost.Hostname]string {
// For the example "café.fr", you may use the variable names "TF_TOKEN_xn____caf__dma_fr",
// "TF_TOKEN_xn--caf-dma_fr", or "TF_TOKEN_xn--caf-dma.fr"
func hostCredentialsFromEnv(host svchost.Hostname) svcauth.HostCredentials {
token, ok := collectCredentialsFromEnv()[host]
if !ok {
return nil
token, tokenOk := collectCredentialsFromEnv()[host]
mtlsCreds, mtlsOk := collectMTLSCredentialsFromEnv()[host]

// If both mTLS and token are found, combine them
if mtlsOk && tokenOk {
mtlsCreds.TokenValue = token
return mtlsCreds
}

// If only mTLS credentials are found
if mtlsOk {
return mtlsCreds
}
return svcauth.HostCredentialsToken(token)

// If only token credentials are found
if tokenOk {
return svcauth.HostCredentialsToken(token)
}

return nil
}

// CredentialsSource is an implementation of svcauth.CredentialsSource
Expand Down Expand Up @@ -230,20 +308,45 @@ type CredentialsSource struct {
var _ svcauth.CredentialsSource = (*CredentialsSource)(nil)

func (s *CredentialsSource) ForHost(host svchost.Hostname) (svcauth.HostCredentials, error) {
var token string

// The first order of precedence for credentials is a host-specific environment variable
if envCreds := hostCredentialsFromEnv(host); envCreds != nil {
return envCreds, nil
token = envCreds.Token()
if mtls, ok := envCreds.(svcauth.HostCredentialsExtended); ok {
mtls.SetToken(token)
return mtls, nil
}
return svcauth.HostCredentialsToken(token), nil
}

// Then, any credentials block present in the CLI config
v, ok := s.configured[host]
if ok {
return svcauth.HostCredentialsFromObject(v), nil
if v, ok := s.configured[host]; ok {
creds := svcauth.HostCredentialsFromObject(v)
if creds != nil {
token = creds.Token()
if mtls, ok := creds.(svcauth.HostCredentialsExtended); ok {
mtls.SetToken(token)
return mtls, nil
}
return svcauth.HostCredentialsToken(token), nil
}
}

// And finally, the credentials helper
if s.helper != nil {
return s.helper.ForHost(host)
creds, err := s.helper.ForHost(host)
if err != nil {
return nil, err
}
if creds != nil {
token = creds.Token()
if mtls, ok := creds.(svcauth.HostCredentialsExtended); ok {
mtls.SetToken(token)
return mtls, nil
}
return svcauth.HostCredentialsToken(token), nil
}
}

return nil, nil
Expand Down
214 changes: 214 additions & 0 deletions internal/command/cliconfig/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@
package cliconfig

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http"
"os"
"path/filepath"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/zclconf/go-cty/cty"
Expand Down Expand Up @@ -473,3 +481,209 @@ func (s *mockCredentialsHelper) ForgetForHost(hostname svchost.Hostname) error {
delete(s.current, hostname)
return nil
}

func TestMTLSCredentialsForHost(t *testing.T) {
certFile, keyFile, caCertFile, err := generateSelfSignedCert(t)
if err != nil {
t.Fatalf("failed to generate self-signed certs: %v", err)
}

credSrc := &CredentialsSource{
configured: map[svchost.Hostname]cty.Value{
"configured.example.com": cty.ObjectVal(map[string]cty.Value{}),
"only-mtls.example.com": cty.ObjectVal(map[string]cty.Value{}),
},
}

testReqMTLSAuthHeader := func(t *testing.T, creds svcauth.HostCredentials) *http.Request {
t.Helper()

if creds == nil {
t.Fatal("No credentials found")
}

req, err := http.NewRequest("GET", "http://example.com/", nil)
if err != nil {
t.Fatalf("cannot construct HTTP request: %s", err)
}
creds.PrepareRequest(req)
return req
}

t.Run("mtls credentials from environment", func(t *testing.T) {
t.Setenv("TF_CLIENT_CERT_configured_example_com", certFile)
t.Setenv("TF_CLIENT_KEY_configured_example_com", keyFile)
t.Setenv("TF_CA_CERT_configured_example_com", caCertFile)
t.Setenv("TF_TOKEN_configured_example_com", "configured-token")

creds, err := credSrc.ForHost(svchost.Hostname("configured.example.com"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if creds == nil {
t.Fatal("no credentials found")
}

mtlsCreds, ok := creds.(svcauth.HostCredentialsExtended)
if !ok {
t.Fatal("expected mTLS credentials")
}

// Verify that the TLS configuration is correct
tlsConfig, err := mtlsCreds.GetTLSConfig()
if err != nil {
t.Fatalf("failed to get TLS config: %s", err)
}

if len(tlsConfig.Certificates) == 0 {
t.Fatal("expected at least one certificate in TLS config")
}

// Check if CA certificate is loaded correctly
if tlsConfig.RootCAs == nil {
t.Fatal("expected RootCAs to be set")
}

// Check if the token is correctly set as an authorization header
req := testReqMTLSAuthHeader(t, mtlsCreds)
if got, want := req.Header.Get("Authorization"), "Bearer configured-token"; got != want {
t.Errorf("wrong token header\ngot: %s\nwant: %s", got, want)
}
})

t.Run("mtls credentials without token", func(t *testing.T) {
t.Setenv("TF_CLIENT_CERT_only__mtls_example_com", certFile)
t.Setenv("TF_CLIENT_KEY_only__mtls_example_com", keyFile)
t.Setenv("TF_CA_CERT_only__mtls_example_com", caCertFile)
creds, err := credSrc.ForHost(svchost.Hostname("only-mtls.example.com"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if creds == nil {
t.Fatal("no credentials found")
}

mtlsCreds, ok := creds.(svcauth.HostCredentialsExtended)
if !ok {
t.Fatal("expected mTLS credentials")
}

// Verify that the TLS configuration is correct
tlsConfig, err := mtlsCreds.GetTLSConfig()
if err != nil {
t.Fatalf("failed to get TLS config: %s", err)
}

if len(tlsConfig.Certificates) == 0 {
t.Fatal("expected at least one certificate in TLS config")
}

// Check if CA certificate is loaded correctly
if tlsConfig.RootCAs == nil {
t.Fatal("expected RootCAs to be set")
}

// Since there's no token, the Authorization header should be empty
req := testReqMTLSAuthHeader(t, mtlsCreds)
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("expected empty authorization header, got: %s", got)
}
})
}

// generateSelfSignedCert generates a self-signed certificate and private key
// and writes them to temporary files. It also generates a CA certificate.
func generateSelfSignedCert(t *testing.T) (certFile, keyFile, caCertFile string, err error) {
t.Helper()

// Generate a private key
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate private key: %v", err)
}

// Create a CA certificate template
caTemplate := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
IsCA: true,
}

// Self-sign the CA certificate
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("failed to create CA certificate: %v", err)
}

// Create a client certificate template
clientTemplate := x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Client Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
IsCA: false,
}

// Sign the client certificate with the CA certificate
clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("failed to create client certificate: %v", err)
}

// Encode CA certificate to PEM
caCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})

// Encode client certificate to PEM
clientCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})

// Encode private key to PEM
keyPEM, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
t.Fatalf("failed to marshal private key: %v", err)
}
keyPEMBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyPEM})

// Write the CA certificate to a temporary file
caCertTempFile, err := os.CreateTemp("", "ca_cert.pem")
if err != nil {
t.Fatalf("failed to create CA cert temp file: %v", err)
}
defer caCertTempFile.Close()
if _, err := caCertTempFile.Write(caCertPEM); err != nil {
t.Fatalf("failed to write CA cert to file: %v", err)
}

// Write the client certificate to a temporary file
certTempFile, err := os.CreateTemp("", "cert.pem")
if err != nil {
t.Fatalf("failed to create cert temp file: %v", err)
}
defer certTempFile.Close()
if _, err := certTempFile.Write(clientCertPEM); err != nil {
t.Fatalf("failed to write cert to file: %v", err)
}

// Write the private key to a temporary file
keyTempFile, err := os.CreateTemp("", "key.pem")
if err != nil {
t.Fatalf("failed to create key temp file: %v", err)
}
defer keyTempFile.Close()
if _, err := keyTempFile.Write(keyPEMBytes); err != nil {
t.Fatalf("failed to write key to file: %v", err)
}

return certTempFile.Name(), keyTempFile.Name(), caCertTempFile.Name(), nil
}
Loading