diff --git a/auth.go b/auth.go deleted file mode 100644 index a1d2938..0000000 --- a/auth.go +++ /dev/null @@ -1,221 +0,0 @@ -package crproxy - -import ( - "bytes" - "crypto/rand" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "hash" - "io" - "net/http" - "net/url" - "strings" - "time" - - "github.com/docker/distribution/registry/api/errcode" -) - -func (c *CRProxy) AuthToken(rw http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) - return - } - if !c.simpleAuth { - errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) - return - } - query := r.URL.Query() - scope := query.Get("scope") - service := query.Get("service") - - if c.simpleAuthUserpassFunc != nil { - authorization := r.Header.Get("Authorization") - auth := strings.SplitN(authorization, " ", 2) - if len(auth) != 2 { - if c.logger != nil { - c.logger.Println("Login failed", authorization) - } - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) - return - } - switch auth[0] { - case "Basic": - user, pass, ok := parseBasicAuth(auth[1]) - if user == "" || pass == "" { - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) - return - } - - var u *url.Userinfo - if ok { - u = url.UserPassword(user, pass) - } else { - u = url.User(user) - } - if !c.simpleAuthUserpassFunc(r, u) { - if c.logger != nil { - c.logger.Println("Login failed user and password", u) - } - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) - return - } - - if c.logger != nil { - c.logger.Println("Login succeed user", u.Username()) - } - default: - if c.logger != nil { - c.logger.Println("Unsupported authorization", authorization) - } - errcode.ServeJSON(rw, errcode.ErrorCodeDenied) - return - } - } - - rw.Header().Set("Content-Type", "application/json") - - now := time.Now() - expiresIn := 60 - token := defaultTokenManager.Encode(Token{ - Service: service, - Scope: scope, - ExpiresAt: now.Add(time.Duration(expiresIn) * time.Second), - }) - - json.NewEncoder(rw).Encode(tokenInfo{ - Token: token, - ExpiresIn: int64(expiresIn), - IssuedAt: now, - }) -} - -func (c *CRProxy) authenticate(rw http.ResponseWriter, r *http.Request) { - tokenURL := c.tokenURL - if tokenURL == "" { - var scheme = "http" - if c.tokenAuthForceTLS || r.TLS != nil || r.URL.Scheme == "https" { - scheme = "https" - } - tokenURL = scheme + "://" + r.Host + "/auth/token" - } - header := fmt.Sprintf("Bearer realm=%q,service=%q", tokenURL, r.Host) - rw.Header().Set("WWW-Authenticate", header) - c.errorResponse(rw, r, errcode.ErrorCodeUnauthorized) -} - -func (c *CRProxy) authorization(rw http.ResponseWriter, r *http.Request) bool { - if c.privilegedNoAuth && c.isPrivileged(r, nil) { - r.Header.Del("Authorization") - return true - } - - auth := r.Header.Get("Authorization") - if auth == "" { - return false - } - - if !strings.HasPrefix(auth, "Bearer ") { - return false - } - - token, ok := defaultTokenManager.Decode(auth[7:]) - if !ok { - return false - } - - if token.ExpiresAt.Before(time.Now()) { - return false - } - - r.Header.Del("Authorization") - return true -} - -type tokenInfo struct { - Token string `json:"token,omitempty"` - ExpiresIn int64 `json:"expires_in,omitempty"` - IssuedAt time.Time `json:"issued_at,omitempty"` -} - -var defaultTokenManager = &tokenManager{ - NewHash: sha256.New, - RandReader: rand.Reader, - HashSize: sha256.Size, - RandSize: 16, - EncodeToString: base64.RawURLEncoding.EncodeToString, - DecodeString: base64.RawURLEncoding.DecodeString, -} - -type tokenManager struct { - NewHash func() hash.Hash - RandReader io.Reader - HashSize int - RandSize int - EncodeToString func([]byte) string - DecodeString func(string) ([]byte, error) -} - -type Token struct { - ExpiresAt time.Time `json:"expires_at,omitempty"` - Scope string `json:"scope,omitempty"` - Service string `json:"service,omitempty"` -} - -func (p *tokenManager) Encode(t Token) (code string) { - sum := make([]byte, p.RandSize+p.HashSize) - io.ReadFull(p.RandReader, sum[:p.RandSize]) - hashSum := p.NewHash() - data, _ := json.Marshal(t) - hashSum.Write(data) - hashSum.Write(sum[:p.RandSize]) - sum = hashSum.Sum(sum[:p.RandSize]) - return p.EncodeToString(sum) + "." + p.EncodeToString(data) -} - -func (p *tokenManager) Decode(code string) (t Token, b bool) { - cs := strings.Split(code, ".") - if len(cs) != 2 { - return t, false - } - - sum, err := p.DecodeString(cs[0]) - if err != nil { - return t, false - } - if len(sum) != p.HashSize+p.RandSize { - return t, false - } - data, err := p.DecodeString(cs[1]) - if err != nil { - return t, false - } - hashSum := p.NewHash() - hashSum.Write(data) - hashSum.Write(sum[:p.RandSize]) - newSum := hashSum.Sum(nil) - if !bytes.Equal(sum[p.RandSize:], newSum) { - return t, false - } - - err = json.Unmarshal(data, &t) - if err != nil { - return t, false - } - - return t, true -} - -func parseBasicAuth(auth string) (username, password string, ok bool) { - c, err := base64.StdEncoding.DecodeString(auth) - if err != nil { - return "", "", false - } - cs := string(c) - username, password, ok = strings.Cut(cs, ":") - if !ok { - return "", "", false - } - return username, password, true -} diff --git a/cmd/crproxy/main.go b/cmd/crproxy/main.go index 926edfd..3159a0f 100644 --- a/cmd/crproxy/main.go +++ b/cmd/crproxy/main.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "crypto/rsa" "encoding/csv" "fmt" "io" @@ -31,7 +32,10 @@ import ( _ "github.com/docker/distribution/registry/storage/driver/s3-aws" "github.com/daocloud/crproxy" + "github.com/daocloud/crproxy/internal/pki" "github.com/daocloud/crproxy/internal/server" + "github.com/daocloud/crproxy/signing" + "github.com/daocloud/crproxy/token" ) var ( @@ -63,8 +67,8 @@ var ( overrideDefaultRegistry map[string]string simpleAuth bool simpleAuthUserpass map[string]string + simpleAuthAllowAnonymous bool tokenURL string - tokenAuthForceTLS bool redirectOriginBlobLinks bool @@ -80,6 +84,9 @@ var ( allowHeadMethod bool manifestCacheDuration time.Duration + + tokenPrivateKeyFile string + tokenPublicKeyFile string ) func init() { @@ -111,8 +118,9 @@ func init() { pflag.StringToStringVar(&overrideDefaultRegistry, "override-default-registry", nil, "override default registry") pflag.BoolVar(&simpleAuth, "simple-auth", false, "enable simple auth") pflag.StringToStringVar(&simpleAuthUserpass, "simple-auth-user", nil, "simple auth user and password") - pflag.StringVar(&tokenURL, "token-url", "", "token url (deprecated)") - pflag.BoolVar(&tokenAuthForceTLS, "token-auth-force-tls", false, "token auth force TLS (deprecated)") + pflag.BoolVar(&simpleAuthAllowAnonymous, "simple-auth-allow-anonymous", false, "simple auth allow anonymous") + + pflag.StringVar(&tokenURL, "token-url", "", "token url") pflag.BoolVar(&redirectOriginBlobLinks, "redirect-origin-blob-links", false, "redirect origin blob links") @@ -126,6 +134,9 @@ func init() { pflag.BoolVar(&allowHeadMethod, "allow-head-method", false, "allow head method") pflag.DurationVar(&manifestCacheDuration, "manifest-cache-duration", 0, "manifest cache duration") + + pflag.StringVar(&tokenPrivateKeyFile, "token-private-key-file", "", "private key file") + pflag.StringVar(&tokenPublicKeyFile, "token-public-key-file", "", "public key file") pflag.Parse() } @@ -459,25 +470,82 @@ func main() { opts = append(opts, crproxy.WithOverrideDefaultRegistry(overrideDefaultRegistry)) } - if simpleAuth { - opts = append(opts, crproxy.WithSimpleAuth(true, tokenURL, tokenAuthForceTLS)) - } - if len(simpleAuthUserpass) != 0 { + var auth func(r *http.Request, userinfo *url.Userinfo) (token.Attribute, bool) - opts = append(opts, crproxy.WithSimpleAuthUserFunc(func(r *http.Request, userinfo *url.Userinfo) bool { + if len(simpleAuthUserpass) != 0 { + auth = func(r *http.Request, userinfo *url.Userinfo) (token.Attribute, bool) { + if userinfo == nil { + return token.Attribute{}, simpleAuthAllowAnonymous + } pass, ok := simpleAuthUserpass[userinfo.Username()] if !ok { - return false + return token.Attribute{}, false } upass, ok := userinfo.Password() if !ok { - return false + return token.Attribute{}, false } if upass != pass { - return false + return token.Attribute{}, false } - return true - })) + return token.Attribute{ + NoRateLimit: true, + NoAllowlist: true, + AllowTagsList: true, + }, true + } + } + + if simpleAuth { + var privateKey *rsa.PrivateKey + var publicKey *rsa.PublicKey + if tokenPrivateKeyFile == "" && tokenPublicKeyFile == "" { + k, err := pki.GenerateKey() + if err != nil { + logger.Println("failed to GenerateKey:", err) + os.Exit(1) + } + privateKey = k + publicKey = &k.PublicKey + } else { + if tokenPrivateKeyFile != "" { + privateKeyData, err := os.ReadFile(tokenPrivateKeyFile) + if err != nil { + logger.Println("failed to ReadFile:", err) + os.Exit(1) + } + k, err := pki.DecodePrivateKey(privateKeyData) + if err != nil { + logger.Println("failed to DecodePrivateKey:", err) + os.Exit(1) + } + privateKey = k + } + if tokenPublicKeyFile != "" { + publicKeyData, err := os.ReadFile(tokenPublicKeyFile) + if err != nil { + logger.Println("failed to ReadFile:", err) + os.Exit(1) + } + k, err := pki.DecodePublicKey(publicKeyData) + if err != nil { + logger.Println("failed to DecodePublicKey:", err) + os.Exit(1) + } + publicKey = k + } else if privateKey != nil { + publicKey = &privateKey.PublicKey + } + } + opts = append(opts, crproxy.WithSimpleAuth(true)) + + authenticator := token.NewAuthenticator(token.NewDecoder(signing.NewVerifier(publicKey)), tokenURL) + opts = append(opts, crproxy.WithAuthenticator(authenticator)) + + if privateKey != nil { + gen := token.NewGenerator(token.NewEncoder(signing.NewSigner(privateKey)), auth, logger) + mux.Handle("/auth/token", gen) + } } if redirectOriginBlobLinks { @@ -501,9 +569,10 @@ func main() { } mux.Handle("/v2/", crp) - mux.HandleFunc("/auth/token", crp.AuthToken) - mux.HandleFunc("/internal/api/image/sync", crp.Sync) + if enableInternalAPI { + mux.HandleFunc("/internal/api/image/sync", crp.Sync) + } if enablePprof { mux.HandleFunc("/debug/pprof/", pprof.Index) diff --git a/crproxy.go b/crproxy.go index 726b00e..ca07b5e 100644 --- a/crproxy.go +++ b/crproxy.go @@ -14,6 +14,8 @@ import ( "time" "github.com/daocloud/crproxy/internal/maps" + "github.com/daocloud/crproxy/logger" + "github.com/daocloud/crproxy/token" "github.com/docker/distribution/registry/api/errcode" "github.com/docker/distribution/registry/client/auth" "github.com/docker/distribution/registry/client/auth/challenge" @@ -30,10 +32,6 @@ var ( catalog = prefix + "_catalog" ) -type Logger interface { - Println(v ...interface{}) -} - type ImageInfo struct { Host string Name string @@ -58,7 +56,7 @@ type CRProxy struct { basicCredentials *basicCredentials mutClientset sync.Mutex bytesPool sync.Pool - logger Logger + logger logger.Logger totalBlobsSpeedLimit *geario.Gear speedLimitRecord maps.SyncMap[string, *geario.BPS] blobsSpeedLimit *geario.B @@ -76,9 +74,6 @@ type CRProxy struct { privilegedNoAuth bool disableTagsList bool simpleAuth bool - simpleAuthUserpassFunc func(r *http.Request, userinfo *url.Userinfo) bool - tokenURL string - tokenAuthForceTLS bool matcher hostmatcher.Matcher defaultRegistry string @@ -90,6 +85,8 @@ type CRProxy struct { manifestCache maps.SyncMap[string, time.Time] manifestCacheDuration time.Duration + + authenticator *token.Authenticator } type Option func(c *CRProxy) @@ -112,17 +109,9 @@ func WithRedirectToOriginBlobFunc(f func(r *http.Request, info *ImageInfo) bool) } } -func WithSimpleAuth(b bool, tokenURL string, forceTLS bool) Option { +func WithSimpleAuth(b bool) Option { return func(c *CRProxy) { c.simpleAuth = b - c.tokenURL = tokenURL - c.tokenAuthForceTLS = forceTLS - } -} - -func WithSimpleAuthUserFunc(f func(r *http.Request, userinfo *url.Userinfo) bool) Option { - return func(c *CRProxy) { - c.simpleAuthUserpassFunc = f } } @@ -200,7 +189,7 @@ func WithBaseClient(baseClient *http.Client) Option { } } -func WithLogger(logger Logger) Option { +func WithLogger(logger logger.Logger) Option { return func(c *CRProxy) { c.logger = logger } @@ -258,6 +247,12 @@ func WithAllowHeadMethod(allowHeadMethod bool) Option { } } +func WithAuthenticator(authenticator *token.Authenticator) Option { + return func(c *CRProxy) { + c.authenticator = authenticator + } +} + func NewCRProxy(opts ...Option) (*CRProxy, error) { c := &CRProxy{ challengeManager: challenge.NewSimpleManager(), @@ -269,6 +264,7 @@ func NewCRProxy(opts ...Option) (*CRProxy, error) { }, }, } + for _, opt := range opts { opt(c) } @@ -279,6 +275,12 @@ func NewCRProxy(opts ...Option) (*CRProxy, error) { } c.basicCredentials = bc } + + if c.simpleAuth { + if c.authenticator == nil { + return nil, fmt.Errorf("no authenticator provided") + } + } return c, nil } @@ -505,25 +507,38 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) return } + oriPath := r.URL.Path + if oriPath == catalog { + errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + return + } r.RemoteAddr = getIP(r.RemoteAddr) - - if c.simpleAuth && !c.authorization(rw, r) { - c.authenticate(rw, r) - return + var t *token.Token + if c.simpleAuth { + gt, err := c.authenticator.Authorization(r) + if err != nil { + if c.logger != nil { + c.logger.Println("failed to authorize", r.RemoteAddr, err) + } + c.authenticator.Authenticate(rw, r) + return + } + t = > + } else { + t = &token.Token{} } - oriPath := r.URL.Path if oriPath == prefix { apiBase(rw, r) return } - if !strings.HasPrefix(oriPath, prefix) { - c.notFoundResponse(rw, r) + if c.simpleAuth && (t.Scope == "") { + c.authenticator.Authenticate(rw, r) return } - if oriPath == catalog { - errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + if !strings.HasPrefix(oriPath, prefix) { + c.notFoundResponse(rw, r) return } @@ -554,8 +569,18 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { Name: info.Image, } - if c.blockFunc != nil && !c.isPrivileged(r, nil) { + if c.isPrivileged(r, imageInfo) { + t.NoRateLimit = true + t.NoAllowlist = true + t.AllowTagsList = true + } + + if c.disableTagsList && info.TagsList && !t.AllowTagsList { + emptyTagsList(rw, r) + return + } + if c.blockFunc != nil && !c.isPrivileged(r, nil) { blockMessage, block := c.block(&BlockInfo{ IP: r.RemoteAddr, Host: info.Host, @@ -573,11 +598,6 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { info.Host = c.getDomainAlias(info.Host) - if info.TagsList && !c.isPrivileged(r, nil) && c.disableTagsList { - emptyTagsList(rw, r) - return - } - path, err := info.Path() if err != nil { if c.logger != nil { @@ -599,7 +619,7 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { return } - if !c.isPrivileged(r, imageInfo) { + if !t.NoRateLimit { if !c.checkLimit(rw, r, info) { return } @@ -607,17 +627,19 @@ func (c *CRProxy) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if c.storageDriver != nil { if info.Blobs != "" { - c.cacheBlobResponse(rw, r, info) + c.cacheBlobResponse(rw, r, info, t) return - } else if info.Manifests != "" { - c.cacheManifestResponse(rw, r, info) + } + + if info.Manifests != "" { + c.cacheManifestResponse(rw, r, info, t) return } } - c.directResponse(rw, r, info) + c.directResponse(rw, r, info, t) } -func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) { +func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) { cli := c.getClientset(info.Host, info.Image) resp, err := c.doWithAuth(cli, r, info.Host) if err != nil { @@ -661,10 +683,7 @@ func (c *CRProxy) directResponse(rw http.ResponseWriter, r *http.Request, info * defer c.bytesPool.Put(buf) var body io.Reader = resp.Body - if !c.isPrivileged(r, &ImageInfo{ - Host: info.Host, - Name: info.Image, - }) { + if !t.NoRateLimit { c.accumulativeLimit(r, info, resp.ContentLength) if c.totalBlobsSpeedLimit != nil && info.Blobs != "" { diff --git a/crproxy_blob.go b/crproxy_blob.go index fe2bf4a..8f130c3 100644 --- a/crproxy_blob.go +++ b/crproxy_blob.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + "github.com/daocloud/crproxy/token" "github.com/docker/distribution/registry/api/errcode" ) @@ -19,7 +20,7 @@ func blobCachePath(blob string) string { return path.Join("/docker/registry/v2/blobs/sha256", blob[:2], blob, "data") } -func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) { +func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) { ctx := r.Context() blobPath := blobCachePath(info.Blobs) @@ -57,10 +58,7 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf return } - if !c.isPrivileged(r, &ImageInfo{ - Host: info.Host, - Name: info.Image, - }) { + if !t.NoRateLimit { c.accumulativeLimit(r, info, size) if !c.waitForLimit(r, info, size) { c.errorResponse(rw, r, nil) @@ -109,10 +107,7 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf return } - if !c.isPrivileged(r, &ImageInfo{ - Host: info.Host, - Name: info.Image, - }) { + if !t.NoRateLimit { c.accumulativeLimit(r, info, signal.size) if !c.waitForLimit(r, info, signal.size) { c.errorResponse(rw, r, nil) diff --git a/crproxy_manifest.go b/crproxy_manifest.go index 317eeac..16c8fef 100644 --- a/crproxy_manifest.go +++ b/crproxy_manifest.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/daocloud/crproxy/token" "github.com/docker/distribution/registry/api/errcode" ) @@ -25,7 +26,7 @@ func manifestTagCachePath(host, image, tagOrBlob string) string { return path.Join("/docker/registry/v2/repositories", host, image, "_manifests/tags", tagOrBlob, "current/link") } -func (c *CRProxy) cacheManifestResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo) { +func (c *CRProxy) cacheManifestResponse(rw http.ResponseWriter, r *http.Request, info *PathInfo, t *token.Token) { if c.cachedManifest(rw, r, info, true) { return } diff --git a/go.mod b/go.mod index 2a15afe..7f74a9e 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/s2a-go v0.1.4 // indirect github.com/google/uuid v1.3.1 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect diff --git a/go.sum b/go.sum index f059cfb..958827d 100644 --- a/go.sum +++ b/go.sum @@ -105,8 +105,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc= github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= diff --git a/internal/pki/pem.go b/internal/pki/pem.go new file mode 100644 index 0000000..d62567a --- /dev/null +++ b/internal/pki/pem.go @@ -0,0 +1,48 @@ +package pki + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" +) + +func EncodePrivateKey(privateKey *rsa.PrivateKey) ([]byte, error) { + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + return pem.EncodeToMemory(block), nil +} + +func DecodePrivateKey(data []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("failed to decode PEM block") + } + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + return privateKey, nil +} + +func EncodePublicKey(publicKey *rsa.PublicKey) ([]byte, error) { + block := &pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: x509.MarshalPKCS1PublicKey(publicKey), + } + return pem.EncodeToMemory(block), nil +} + +func DecodePublicKey(data []byte) (*rsa.PublicKey, error) { + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("failed to decode PEM block") + } + publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return nil, err + } + return publicKey, nil +} diff --git a/internal/pki/pki.go b/internal/pki/pki.go new file mode 100644 index 0000000..d66bef3 --- /dev/null +++ b/internal/pki/pki.go @@ -0,0 +1,10 @@ +package pki + +import ( + "crypto/rand" + "crypto/rsa" +) + +func GenerateKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 1024) +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..a52dc24 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,5 @@ +package logger + +type Logger interface { + Println(v ...interface{}) +} diff --git a/signing/signing.go b/signing/signing.go new file mode 100644 index 0000000..421ab35 --- /dev/null +++ b/signing/signing.go @@ -0,0 +1,76 @@ +package signing + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "fmt" + "strings" +) + +// The signing format is as follows +// +// base64(signature(data)) + "," + base64(data) +// +// Don't store private data like passwords. + +var base = base64.RawURLEncoding + +type Signer struct { + PrivateKey *rsa.PrivateKey +} + +func NewSigner(privateKey *rsa.PrivateKey) *Signer { + return &Signer{ + PrivateKey: privateKey, + } +} + +func (e *Signer) Sign(data []byte) (code string, err error) { + encodedData := base.EncodeToString(data) + digest := sha256.Sum256([]byte(encodedData)) + signature, err := rsa.SignPSS(rand.Reader, e.PrivateKey, crypto.SHA256, digest[:], nil) + if err != nil { + return "", err + } + + encodedSignature := base.EncodeToString(signature) + return encodedSignature + "," + encodedData, nil +} + +type Verifier struct { + PublicKey *rsa.PublicKey +} + +func NewVerifier(publicKey *rsa.PublicKey) *Verifier { + return &Verifier{ + PublicKey: publicKey, + } +} + +func (d *Verifier) Verify(code string) ([]byte, error) { + cs := strings.SplitN(code, ",", 3) + if len(cs) != 2 { + return nil, fmt.Errorf("invalid token code: %s", code) + } + encodedSignature := cs[0] + encodedData := cs[1] + signature, err := base.DecodeString(encodedSignature) + if err != nil { + return nil, err + } + + digest := sha256.Sum256([]byte(encodedData)) + err = rsa.VerifyPSS(d.PublicKey, crypto.SHA256, digest[:], signature, nil) + if err != nil { + return nil, err + } + + data, err := base.DecodeString(encodedData) + if err != nil { + return nil, err + } + return data, nil +} diff --git a/signing/signing_test.go b/signing/signing_test.go new file mode 100644 index 0000000..af3f69d --- /dev/null +++ b/signing/signing_test.go @@ -0,0 +1,77 @@ +package signing + +import ( + "reflect" + "testing" + + "github.com/daocloud/crproxy/internal/pki" +) + +func TestSigning(t *testing.T) { + privateKey, err := pki.GenerateKey() + if err != nil { + t.Fatalf("failed to generate private key: %s", err) + } + + raw := []byte("Hello world") + encoder := NewSigner(privateKey) + code, err := encoder.Sign(raw) + if err != nil { + t.Fatalf("failed to sign token: %s", err) + } + + t.Logf("encoded token: %s", code) + t.Logf("encoded size: %d", len(code)) + + decoder := NewVerifier(&privateKey.PublicKey) + decoded, err := decoder.Verify(code) + if err != nil { + t.Fatalf("failed to verify token: %s", err) + } + if !reflect.DeepEqual(decoded, raw) { + t.Fatalf("decoded token does not match original") + } +} + +func BenchmarkSign(b *testing.B) { + privateKey, err := pki.GenerateKey() + if err != nil { + b.Fatalf("failed to generate private key: %s", err) + } + encoder := NewSigner(privateKey) + raw := []byte("Hello world") + + b.StartTimer() + defer b.StopTimer() + for i := 0; i < b.N; i++ { + _, err := encoder.Sign(raw) + if err != nil { + b.Fatalf("failed to sign token: %s", err) + } + } +} + +func BenchmarkVerify(b *testing.B) { + privateKey, err := pki.GenerateKey() + if err != nil { + b.Fatalf("failed to generate private key: %s", err) + } + + raw := []byte("Hello world") + encoder := NewSigner(privateKey) + code, err := encoder.Sign(raw) + if err != nil { + b.Fatalf("failed to sign token: %s", err) + } + + decoder := NewVerifier(&privateKey.PublicKey) + + b.StartTimer() + defer b.StopTimer() + for i := 0; i < b.N; i++ { + _, err := decoder.Verify(code) + if err != nil { + b.Fatalf("failed to verify token: %s", err) + } + } +} diff --git a/token/authenticator.go b/token/authenticator.go new file mode 100644 index 0000000..632652d --- /dev/null +++ b/token/authenticator.go @@ -0,0 +1,62 @@ +package token + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/docker/distribution/registry/api/errcode" +) + +type Authenticator struct { + tokenDecoder *Decoder + tokenURL string +} + +func NewAuthenticator( + tokenDecoder *Decoder, + tokenURL string, +) *Authenticator { + return &Authenticator{ + tokenDecoder: tokenDecoder, + tokenURL: tokenURL, + } +} + +func (c *Authenticator) Authenticate(rw http.ResponseWriter, r *http.Request) { + tokenURL := c.tokenURL + if tokenURL == "" { + var scheme = "http" + if r.TLS != nil || r.URL.Scheme == "https" { + scheme = "https" + } + tokenURL = scheme + "://" + r.Host + "/auth/token" + } + header := fmt.Sprintf("Bearer realm=%q,service=%q", tokenURL, r.Host) + rw.Header().Set("WWW-Authenticate", header) + errcode.ServeJSON(rw, errcode.ErrorCodeUnauthorized) +} + +func (c *Authenticator) Authorization(r *http.Request) (Token, error) { + auth := r.Header.Get("Authorization") + if auth == "" { + return Token{}, fmt.Errorf("no authorization header found") + } + + if !strings.HasPrefix(auth, "Bearer ") { + return Token{}, fmt.Errorf("invalid authorization header: %q", auth) + } + + t, err := c.tokenDecoder.Decode(auth[7:]) + if err != nil { + return Token{}, err + } + + if t.ExpiresAt.Before(time.Now()) { + return Token{}, fmt.Errorf("%s token expired", t.Account) + } + + r.Header.Del("Authorization") + return t, nil +} diff --git a/token/encoding.go b/token/encoding.go new file mode 100644 index 0000000..9077fd8 --- /dev/null +++ b/token/encoding.go @@ -0,0 +1,68 @@ +package token + +import ( + "encoding/json" + "time" + + "github.com/daocloud/crproxy/signing" +) + +type Encoder struct { + signer *signing.Signer +} + +func NewEncoder(signer *signing.Signer) *Encoder { + return &Encoder{ + signer: signer, + } +} + +type Decoder struct { + verifier *signing.Verifier +} + +func NewDecoder(verifier *signing.Verifier) *Decoder { + return &Decoder{ + verifier: verifier, + } +} + +type Token struct { + ExpiresAt time.Time `json:"expires_at,omitempty"` + Scope string `json:"scope,omitempty"` + Service string `json:"service,omitempty"` + + Account string `json:"account,omitempty"` + Image string `json:"image,omitempty"` + + Attribute `json:"attribute,omitempty"` +} + +type Attribute struct { + NoRateLimit bool `json:"no_rate_limit,omitempty"` + NoAllowlist bool `json:"no_allowlist,omitempty"` + AllowTagsList bool `json:"allow_tags_list,omitempty"` +} + +func (p *Encoder) Encode(t Token) (code string, err error) { + data, err := json.Marshal(t) + if err != nil { + return "", err + } + + return p.signer.Sign(data) +} + +func (p *Decoder) Decode(code string) (t Token, err error) { + data, err := p.verifier.Verify(code) + if err != nil { + return t, err + } + + err = json.Unmarshal(data, &t) + if err != nil { + return t, err + } + + return t, nil +} diff --git a/token/generator.go b/token/generator.go new file mode 100644 index 0000000..7d526be --- /dev/null +++ b/token/generator.go @@ -0,0 +1,171 @@ +package token + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/url" + "strings" + "time" + + "github.com/daocloud/crproxy/logger" + "github.com/docker/distribution/registry/api/errcode" +) + +type Generator struct { + authFunc func(r *http.Request, userinfo *url.Userinfo) (Attribute, bool) + logger logger.Logger + tokenEncoder *Encoder +} + +func NewGenerator( + tokenEncoder *Encoder, + authFunc func(r *http.Request, userinfo *url.Userinfo) (Attribute, bool), + logger logger.Logger, +) *Generator { + return &Generator{ + authFunc: authFunc, + logger: logger, + tokenEncoder: tokenEncoder, + } +} + +func (g *Generator) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errcode.ServeJSON(rw, errcode.ErrorCodeUnsupported) + return + } + + t, err := g.getToken(r) + if err != nil { + errcode.ServeJSON(rw, err) + return + } + + rw.Header().Set("Content-Type", "application/json") + + now := time.Now() + expiresIn := 60 + + t.ExpiresAt = now.Add((time.Duration(expiresIn) + 10) * time.Second) + + code, err := g.tokenEncoder.Encode(*t) + if err != nil { + if g.logger != nil { + g.logger.Println("Error encoding token", err) + } + errcode.ServeJSON(rw, errcode.ErrorCodeUnknown) + return + } + + json.NewEncoder(rw).Encode(tokenInfo{ + Token: code, + ExpiresIn: int64(expiresIn), + IssuedAt: now, + }) +} + +func (g *Generator) getToken(r *http.Request) (*Token, error) { + query := r.URL.Query() + account := query.Get("account") + scope := query.Get("scope") + service := query.Get("service") + + t := Token{ + Service: service, + Scope: scope, + Account: account, + } + + if scope != "" { + scopeSlice := strings.SplitN(scope, ":", 4) + if len(scopeSlice) != 3 { + return nil, errcode.ErrorCodeDenied + } + + if scopeSlice[2] != "pull" { + return nil, errcode.ErrorCodeDenied + } + + t.Image = scopeSlice[1] + } + + if g.authFunc == nil { + t.Account = "" + return &t, nil + } + + authorization := r.Header.Get("Authorization") + if authorization == "" { + attribute, login := g.authFunc(r, nil) + if !login { + return nil, errcode.ErrorCodeDenied + } + t.Attribute = attribute + return &t, nil + } + auth := strings.SplitN(authorization, " ", 2) + if len(auth) != 2 { + if g.logger != nil { + g.logger.Println("Login failed", authorization) + } + return nil, errcode.ErrorCodeDenied + } + switch auth[0] { + case "Basic": + user, pass, ok := parseBasicAuth(auth[1]) + if user == "" || pass == "" { + return nil, errcode.ErrorCodeDenied + } + + if account != user { + return nil, errcode.ErrorCodeDenied + } + + var u *url.Userinfo + if ok { + u = url.UserPassword(user, pass) + } else { + u = url.User(user) + } + + attribute, login := g.authFunc(r, u) + if !login { + if g.logger != nil { + g.logger.Println("Login failed user and password", u) + } + return nil, errcode.ErrorCodeDenied + } + t.Attribute = attribute + + if g.logger != nil { + g.logger.Println("Login succeed user", u.Username()) + } + default: + if g.logger != nil { + g.logger.Println("Unsupported authorization", authorization) + } + return nil, errcode.ErrorCodeDenied + } + + return &t, nil +} + +type tokenInfo struct { + Token string `json:"token,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + IssuedAt time.Time `json:"issued_at,omitempty"` +} + +func parseBasicAuth(auth string) (username, password string, ok bool) { + c, err := base64.StdEncoding.DecodeString(auth) + if err != nil { + return "", "", false + } + cs := string(c) + username, password, ok = strings.Cut(cs, ":") + if !ok { + return "", "", false + } + return username, password, true +}