Skip to content

Commit

Permalink
Add support for mirror registries behind mutual TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
416e64726579 committed Aug 30, 2024
1 parent 2770cb3 commit e1dba72
Show file tree
Hide file tree
Showing 7 changed files with 583 additions and 9 deletions.
121 changes: 112 additions & 9 deletions internal/command/cliconfig/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,69 @@ func collectCredentialsFromEnv() map[svchost.Hostname]string {
return ret
}

func collectMTLSCredentialsFromEnv() map[svchost.Hostname]*svcauth.HostCredentialsMTLS {
// Prefixes for environment variables
const (
prefixCert = "TF_CLIENT_CERT_"
prefixKey = "TF_CLIENT_KEY_"
prefixCACert = "TF_CA_CERT_"
)

ret := make(map[svchost.Hostname]*svcauth.HostCredentialsMTLS)

for _, ev := range os.Environ() {
eqIdx := strings.Index(ev, "=")
if eqIdx < 0 {
continue
}
name := ev[:eqIdx]
value := ev[eqIdx+1:]

var rawHost, cert, key, caCert string

// Determine the type of credential and parse accordingly
if strings.HasPrefix(name, prefixCert) {
rawHost = name[len(prefixCert):]
cert = value
} else if strings.HasPrefix(name, prefixKey) {
rawHost = name[len(prefixKey):]
key = value
} else if strings.HasPrefix(name, prefixCACert) {
rawHost = name[len(prefixCACert):]
caCert = value
} else {
continue
}

// Normalize the hostname
rawHost = strings.ReplaceAll(rawHost, "__", "-")
rawHost = strings.ReplaceAll(rawHost, "_", ".")
dispHost := svchost.ForDisplay(rawHost)
hostname, err := svchost.ForComparison(dispHost)
if err != nil {
continue
}

// Retrieve or create the entry in the map
if ret[hostname] == nil {
ret[hostname] = &svcauth.HostCredentialsMTLS{}
}

// Update the corresponding fields
if cert != "" {
ret[hostname].ClientCert = cert
}
if key != "" {
ret[hostname].ClientKey = key
}
if caCert != "" {
ret[hostname].CACertificate = caCert
}
}

return ret
}

// hostCredentialsFromEnv returns a token credential by searching for a hostname-specific
// environment variable. The host parameter is expected to be in the "comparison" form,
// for example, hostnames containing non-ASCII characters like "café.fr"
Expand All @@ -183,11 +246,26 @@ func collectCredentialsFromEnv() map[svchost.Hostname]string {
// For the example "café.fr", you may use the variable names "TF_TOKEN_xn____caf__dma_fr",
// "TF_TOKEN_xn--caf-dma_fr", or "TF_TOKEN_xn--caf-dma.fr"
func hostCredentialsFromEnv(host svchost.Hostname) svcauth.HostCredentials {
token, ok := collectCredentialsFromEnv()[host]
if !ok {
return nil
token, tokenOk := collectCredentialsFromEnv()[host]
mtlsCreds, mtlsOk := collectMTLSCredentialsFromEnv()[host]

// If both mTLS and token are found, combine them
if mtlsOk && tokenOk {
mtlsCreds.TokenValue = token
return mtlsCreds
}

// If only mTLS credentials are found
if mtlsOk {
return mtlsCreds
}
return svcauth.HostCredentialsToken(token)

// If only token credentials are found
if tokenOk {
return svcauth.HostCredentialsToken(token)
}

return nil
}

// CredentialsSource is an implementation of svcauth.CredentialsSource
Expand Down Expand Up @@ -230,20 +308,45 @@ type CredentialsSource struct {
var _ svcauth.CredentialsSource = (*CredentialsSource)(nil)

func (s *CredentialsSource) ForHost(host svchost.Hostname) (svcauth.HostCredentials, error) {
var token string

// The first order of precedence for credentials is a host-specific environment variable
if envCreds := hostCredentialsFromEnv(host); envCreds != nil {
return envCreds, nil
token = envCreds.Token()
if mtls, ok := envCreds.(svcauth.HostCredentialsExtended); ok {
mtls.SetToken(token)
return mtls, nil
}
return svcauth.HostCredentialsToken(token), nil
}

// Then, any credentials block present in the CLI config
v, ok := s.configured[host]
if ok {
return svcauth.HostCredentialsFromObject(v), nil
if v, ok := s.configured[host]; ok {
creds := svcauth.HostCredentialsFromObject(v)
if creds != nil {
token = creds.Token()
if mtls, ok := creds.(svcauth.HostCredentialsExtended); ok {
mtls.SetToken(token)
return mtls, nil
}
return svcauth.HostCredentialsToken(token), nil
}
}

// And finally, the credentials helper
if s.helper != nil {
return s.helper.ForHost(host)
creds, err := s.helper.ForHost(host)
if err != nil {
return nil, err
}
if creds != nil {
token = creds.Token()
if mtls, ok := creds.(svcauth.HostCredentialsExtended); ok {
mtls.SetToken(token)
return mtls, nil
}
return svcauth.HostCredentialsToken(token), nil
}
}

return nil, nil
Expand Down
214 changes: 214 additions & 0 deletions internal/command/cliconfig/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@
package cliconfig

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http"
"os"
"path/filepath"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/zclconf/go-cty/cty"
Expand Down Expand Up @@ -473,3 +481,209 @@ func (s *mockCredentialsHelper) ForgetForHost(hostname svchost.Hostname) error {
delete(s.current, hostname)
return nil
}

func TestMTLSCredentialsForHost(t *testing.T) {
certFile, keyFile, caCertFile, err := generateSelfSignedCert(t)
if err != nil {
t.Fatalf("failed to generate self-signed certs: %v", err)
}

credSrc := &CredentialsSource{
configured: map[svchost.Hostname]cty.Value{
"configured.example.com": cty.ObjectVal(map[string]cty.Value{}),
"only-mtls.example.com": cty.ObjectVal(map[string]cty.Value{}),
},
}

testReqMTLSAuthHeader := func(t *testing.T, creds svcauth.HostCredentials) *http.Request {
t.Helper()

if creds == nil {
t.Fatal("No credentials found")
}

req, err := http.NewRequest("GET", "http://example.com/", nil)
if err != nil {
t.Fatalf("cannot construct HTTP request: %s", err)
}
creds.PrepareRequest(req)
return req
}

t.Run("mtls credentials from environment", func(t *testing.T) {
t.Setenv("TF_CLIENT_CERT_configured_example_com", certFile)
t.Setenv("TF_CLIENT_KEY_configured_example_com", keyFile)
t.Setenv("TF_CA_CERT_configured_example_com", caCertFile)
t.Setenv("TF_TOKEN_configured_example_com", "configured-token")

creds, err := credSrc.ForHost(svchost.Hostname("configured.example.com"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if creds == nil {
t.Fatal("no credentials found")
}

mtlsCreds, ok := creds.(svcauth.HostCredentialsExtended)
if !ok {
t.Fatal("expected mTLS credentials")
}

// Verify that the TLS configuration is correct
tlsConfig, err := mtlsCreds.GetTLSConfig()
if err != nil {
t.Fatalf("failed to get TLS config: %s", err)
}

if len(tlsConfig.Certificates) == 0 {
t.Fatal("expected at least one certificate in TLS config")
}

// Check if CA certificate is loaded correctly
if tlsConfig.RootCAs == nil {
t.Fatal("expected RootCAs to be set")
}

// Check if the token is correctly set as an authorization header
req := testReqMTLSAuthHeader(t, mtlsCreds)
if got, want := req.Header.Get("Authorization"), "Bearer configured-token"; got != want {
t.Errorf("wrong token header\ngot: %s\nwant: %s", got, want)
}
})

t.Run("mtls credentials without token", func(t *testing.T) {
t.Setenv("TF_CLIENT_CERT_only__mtls_example_com", certFile)
t.Setenv("TF_CLIENT_KEY_only__mtls_example_com", keyFile)
t.Setenv("TF_CA_CERT_only__mtls_example_com", caCertFile)
creds, err := credSrc.ForHost(svchost.Hostname("only-mtls.example.com"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if creds == nil {
t.Fatal("no credentials found")
}

mtlsCreds, ok := creds.(svcauth.HostCredentialsExtended)
if !ok {
t.Fatal("expected mTLS credentials")
}

// Verify that the TLS configuration is correct
tlsConfig, err := mtlsCreds.GetTLSConfig()
if err != nil {
t.Fatalf("failed to get TLS config: %s", err)
}

if len(tlsConfig.Certificates) == 0 {
t.Fatal("expected at least one certificate in TLS config")
}

// Check if CA certificate is loaded correctly
if tlsConfig.RootCAs == nil {
t.Fatal("expected RootCAs to be set")
}

// Since there's no token, the Authorization header should be empty
req := testReqMTLSAuthHeader(t, mtlsCreds)
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("expected empty authorization header, got: %s", got)
}
})
}

// generateSelfSignedCert generates a self-signed certificate and private key
// and writes them to temporary files. It also generates a CA certificate.
func generateSelfSignedCert(t *testing.T) (certFile, keyFile, caCertFile string, err error) {
t.Helper()

// Generate a private key
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate private key: %v", err)
}

// Create a CA certificate template
caTemplate := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
IsCA: true,
}

// Self-sign the CA certificate
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("failed to create CA certificate: %v", err)
}

// Create a client certificate template
clientTemplate := x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Client Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
IsCA: false,
}

// Sign the client certificate with the CA certificate
clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("failed to create client certificate: %v", err)
}

// Encode CA certificate to PEM
caCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})

// Encode client certificate to PEM
clientCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})

// Encode private key to PEM
keyPEM, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
t.Fatalf("failed to marshal private key: %v", err)
}
keyPEMBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyPEM})

// Write the CA certificate to a temporary file
caCertTempFile, err := os.CreateTemp("", "ca_cert.pem")
if err != nil {
t.Fatalf("failed to create CA cert temp file: %v", err)
}
defer caCertTempFile.Close()
if _, err := caCertTempFile.Write(caCertPEM); err != nil {
t.Fatalf("failed to write CA cert to file: %v", err)
}

// Write the client certificate to a temporary file
certTempFile, err := os.CreateTemp("", "cert.pem")
if err != nil {
t.Fatalf("failed to create cert temp file: %v", err)
}
defer certTempFile.Close()
if _, err := certTempFile.Write(clientCertPEM); err != nil {
t.Fatalf("failed to write cert to file: %v", err)
}

// Write the private key to a temporary file
keyTempFile, err := os.CreateTemp("", "key.pem")
if err != nil {
t.Fatalf("failed to create key temp file: %v", err)
}
defer keyTempFile.Close()
if _, err := keyTempFile.Write(keyPEMBytes); err != nil {
t.Fatalf("failed to write key to file: %v", err)
}

return certTempFile.Name(), keyTempFile.Name(), caCertTempFile.Name(), nil
}
Loading

0 comments on commit e1dba72

Please sign in to comment.