diff --git a/cluster/cluster.go b/cluster/cluster.go index 51d2cf6..08a850e 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -44,6 +44,7 @@ var ( ) type Cluster struct { + name string opts config.ClusterOptions gcfg config.ClusterGeneralConfig @@ -69,7 +70,7 @@ type Cluster struct { } func NewCluster( - opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, + name string, opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, storageManager *storage.Manager, statManager *StatManager, ) (cr *Cluster) { @@ -89,10 +90,17 @@ func NewCluster( } // ID returns the cluster id +// The ID may not be unique in the openbmclapi cluster runtime func (cr *Cluster) ID() string { return cr.opts.Id } +// Name returns the cluster's alias name +// The name must be unique in the openbmclapi cluster runtime +func (cr *Cluster) Name() string { + return cr.name +} + // Secret returns the cluster secret func (cr *Cluster) Secret() string { return cr.opts.Secret diff --git a/main.go b/main.go index 0d46a91..0dd9007 100644 --- a/main.go +++ b/main.go @@ -145,38 +145,41 @@ func main() { go func(ctx context.Context) { defer log.RecordPanic() - var wg sync.WaitGroup - errs := make([]error, len(r.clusters)) { - i := 0 + type clusterSetupRes struct { + cluster *cluster.Cluster + err error + cert *tls.Certificate + } + resCh := make(chan clusterSetupRes) for _, cr := range r.clusters { - i++ - go func(i int, cr *cluster.Cluster) { - defer wg.Done() - errs[i] = cr.Connect(ctx) - }(i, cr) + go func(cr *cluster.Cluster) { + defer log.RecordPanic() + if err := cr.Connect(ctx); err != nil { + log.Errorf("Failed to connect cluster %s to server %q: %v", cr.ID(), cr.Options().Server, err) + resCh <- clusterSetupRes{cluster: cr, err: err} + return + } + cert, err := r.RequestClusterCert(ctx, cr) + if err != nil { + log.Errorf("Failed to request certificate for cluster %s: %v", cr.ID(), err) + resCh <- clusterSetupRes{cluster: cr, err: err} + return + } + resCh <- clusterSetupRes{cluster: cr, cert: cert} + }(cr) } - } - wg.Wait() - if ctx.Err() != nil { - return - } - - { - var err error - r.tlsConfig, err = r.PatchTLSWithClusterCert(ctx, r.tlsConfig) - if err != nil { - return + for range len(r.clusters) { + select { + case res := <-resCh: + r.certificates[res.cluster.Name()] = res.cert + case <-ctx.Done(): + return + } } - r.listener.TLSConfig.Store(r.tlsConfig) } - firstSyncDone := make(chan struct{}, 0) - go func() { - defer log.RecordPanic() - defer close(firstSyncDone) - r.InitSynchronizer(ctx) - }() + r.listener.TLSConfig.Store(r.PatchTLSWithClusterCertificates(r.tlsConfig)) if !r.Config.Tunneler.Enable { strPort := strconv.Itoa((int)(r.getPublicPort())) @@ -186,13 +189,9 @@ func main() { } log.TrInfof("info.wait.first.sync") - select { - case <-firstSyncDone: - case <-ctx.Done(): - return - } + r.InitSynchronizer(ctx) - // r.EnableCluster(ctx) + r.EnableClusterAll(ctx) }(ctx) code := r.ListenSignals(ctx, cancel) @@ -229,10 +228,11 @@ type Runner struct { handlerAPIv0 http.Handler hijackHandler http.Handler - tlsConfig *tls.Config - publicHost string - publicPort uint16 - listener *utils.HTTPTLSListener + tlsConfig *tls.Config + certificates map[string]*tls.Certificate + publicHost string + publicPort uint16 + listener *utils.HTTPTLSListener reloading atomic.Bool updating atomic.Bool @@ -314,7 +314,7 @@ func (r *Runner) InitClusters(ctx context.Context) { r.clusters = make(map[string]*cluster.Cluster) gcfg := r.GetClusterGeneralConfig() for name, opts := range r.Config.Clusters { - cr := cluster.NewCluster(opts, gcfg, r.storageManager, r.statManager) + cr := cluster.NewCluster(name, opts, gcfg, r.storageManager, r.statManager) if err := cr.Init(ctx); err != nil { log.TrErrorf("error.init.failed", err) } else { @@ -618,28 +618,34 @@ func (r *Runner) GenerateTLSConfig() (*tls.Config, error) { return tlsConfig, nil } -func (r *Runner) PatchTLSWithClusterCert(ctx context.Context, tlsConfig *tls.Config) (*tls.Config, error) { - certs := make([]tls.Certificate, 0) - for _, cr := range r.clusters { - if cr.Options().Byoc { - continue - } - log.TrInfof("info.cert.requesting", cr.ID()) - tctx, cancel := context.WithTimeout(ctx, time.Minute*10) - pair, err := cr.RequestCert(tctx) - cancel() - if err != nil { - log.TrErrorf("error.cert.request.failed", err) - continue - } - cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) - if err != nil { - log.TrErrorf("error.cert.requested.parse.failed", err) - continue +func (r *Runner) RequestClusterCert(ctx context.Context, cr *cluster.Cluster) (*tls.Certificate, error) { + if cr.Options().Byoc { + return nil, nil + } + log.TrInfof("info.cert.requesting", cr.ID()) + tctx, cancel := context.WithTimeout(ctx, time.Minute*10) + pair, err := cr.RequestCert(tctx) + cancel() + if err != nil { + log.TrErrorf("error.cert.request.failed", err) + return nil, err + } + cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) + if err != nil { + log.TrErrorf("error.cert.requested.parse.failed", err) + return nil, err + } + certHost, _ := parseCertCommonName(cert.Certificate[0]) + log.TrInfof("info.cert.requested", certHost) + return &cert, nil +} + +func (r *Runner) PatchTLSWithClusterCertificates(tlsConfig *tls.Config) *tls.Config { + certs := make([]tls.Certificate, 0, len(r.certificates)) + for _, c := range r.certificates { + if c != nil { + certs = append(certs, *c) } - certs = append(certs, cert) - certHost, _ := parseCertCommonName(cert.Certificate[0]) - log.TrInfof("info.cert.requested", certHost) } if len(certs) == 0 { if tlsConfig == nil { @@ -649,7 +655,7 @@ func (r *Runner) PatchTLSWithClusterCert(ctx context.Context, tlsConfig *tls.Con } tlsConfig.Certificates = append(tlsConfig.Certificates, certs...) } - return tlsConfig, nil + return tlsConfig } // updateClustersWithGeneralConfig will re-enable all clusters with latest general config