diff --git a/api/bmclapi/hijacker.go b/api/bmclapi/hijacker.go index f45d0b24..3218c2b0 100644 --- a/api/bmclapi/hijacker.go +++ b/api/bmclapi/hijacker.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package bmclapi import ( "context" diff --git a/cluster/cluster.go b/cluster/cluster.go index 3f7943f2..51d2cf6d 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -70,9 +70,13 @@ type Cluster struct { func NewCluster( opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, - storageManager *storage.Manager, storages []int, + storageManager *storage.Manager, statManager *StatManager, ) (cr *Cluster) { + storages := make([]int, len(opts.Storages)) + for i, name := range opts.Storages { + storages[i] = storageManager.GetIndex(name) + } cr = &Cluster{ opts: opts, gcfg: gcfg, @@ -120,10 +124,23 @@ func (cr *Cluster) AcceptHost(host string) bool { return false } +func (cr *Cluster) Options() *config.ClusterOptions { + return &cr.opts +} + +func (cr *Cluster) GeneralConfig() *config.ClusterGeneralConfig { + return &cr.gcfg +} + // Init do setup on the cluster // Init should only be called once during the cluster's whole life // The context passed in only affect the logical of Init method func (cr *Cluster) Init(ctx context.Context) error { + for i, ind := range cr.storages { + if ind == -1 { + return fmt.Errorf("Storage %q does not exists", cr.opts.Storages[i]) + } + } return nil } @@ -172,7 +189,7 @@ func (cr *Cluster) enable(ctx context.Context) error { Host: cr.gcfg.PublicHost, Port: cr.gcfg.PublicPort, Version: build.ClusterVersion, - Byoc: cr.gcfg.Byoc, + Byoc: cr.opts.Byoc, NoFastEnable: cr.gcfg.NoFastEnable, Flavor: ConfigFlavor{ Runtime: "golang/" + runtime.GOOS + "-" + runtime.GOARCH, diff --git a/cluster/handler.go b/cluster/handler.go index dd2151a4..4dbd3a18 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -36,7 +36,7 @@ import ( "github.com/LiterMC/go-openbmclapi/storage" ) -func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash string, size int64) { +func (cr *Cluster) HandleFile(rw http.ResponseWriter, req *http.Request, hash string) { defer log.RecoverPanic(nil) if !cr.Enabled() { @@ -88,6 +88,8 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st api.SetAccessInfo(req, "cluster", cr.ID()) + var size int64 = -1 // TODO: get the size + var ( sto storage.Storage err error @@ -121,7 +123,7 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st http.Error(rw, err.Error(), http.StatusInternalServerError) } -func (cr *Cluster) HandleMeasure(req *http.Request, rw http.ResponseWriter, size int) { +func (cr *Cluster) HandleMeasure(rw http.ResponseWriter, req *http.Request, size int) { if !cr.Enabled() { // do not serve file if cluster is not enabled yet http.Error(rw, "Cluster is not enabled yet", http.StatusServiceUnavailable) diff --git a/cluster/http.go b/cluster/http.go index f66f3f54..4910444c 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -90,7 +90,7 @@ func (cr *Cluster) makeReqWithBody( query url.Values, body io.Reader, ) (req *http.Request, err error) { var u *url.URL - if u, err = url.Parse(cr.opts.Prefix); err != nil { + if u, err = url.Parse(cr.opts.Server); err != nil { return } u.Path = path.Join(u.Path, relpath) diff --git a/cluster/socket.go b/cluster/socket.go index 7b2b78f1..db238397 100644 --- a/cluster/socket.go +++ b/cluster/socket.go @@ -48,10 +48,10 @@ func (cr *Cluster) Connect(ctx context.Context) error { } engio, err := engine.NewSocket(engine.Options{ - Host: cr.opts.Prefix, + Host: cr.opts.Server, Path: "/socket.io/", ExtraHeaders: http.Header{ - "Origin": {cr.opts.Prefix}, + "Origin": {cr.opts.Server}, "User-Agent": {build.ClusterUserAgent}, }, DialTimeout: time.Minute * 6, diff --git a/config.go b/config.go index 43298f58..b5259111 100644 --- a/config.go +++ b/config.go @@ -21,12 +21,20 @@ package main import ( "bytes" + "errors" + "fmt" + "net/url" + "os" "gopkg.in/yaml.v3" "github.com/LiterMC/go-openbmclapi/config" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" ) +const DefaultBMCLAPIServer = "https://openbmclapi.bangbang93.com" + func migrateConfig(data []byte, cfg *config.Config) { var oldConfig map[string]any if err := yaml.Unmarshal(data, &oldConfig); err != nil { @@ -45,11 +53,13 @@ func migrateConfig(data []byte, cfg *config.Config) { if oldConfig["clusters"].(map[string]any) == nil { id, ok1 := oldConfig["cluster-id"].(string) secret, ok2 := oldConfig["cluster-secret"].(string) - if ok1 && ok2 { - cfg.Clusters = map[string]ClusterItem{ + publicHost, ok3 := oldConfig["public-host"].(string) + if ok1 && ok2 && ok3 { + cfg.Clusters = map[string]config.ClusterOptions{ "main": { - Id: id, - Secret: secret, + Id: id, + Secret: secret, + PublicHosts: []string{publicHost}, }, } } @@ -67,7 +77,7 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { log.TrErrorf("error.config.read.failed", err) os.Exit(1) } - log.TrError("error.config.not.exists") + log.TrErrorf("error.config.not.exists") notexists = true } else { migrateConfig(data, cfg) @@ -76,15 +86,18 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { os.Exit(1) } if len(cfg.Clusters) == 0 { - cfg.Clusters = map[string]ClusterItem{ + cfg.Clusters = map[string]config.ClusterOptions{ "main": { - Id: "${CLUSTER_ID}", - Secret: "${CLUSTER_SECRET}", + Id: "${CLUSTER_ID}", + Secret: "${CLUSTER_SECRET}", + PublicHosts: []string{}, + Server: DefaultBMCLAPIServer, + SkipSignatureCheck: false, }, } } if len(cfg.Certificates) == 0 { - cfg.Certificates = []CertificateConfig{ + cfg.Certificates = []config.CertificateConfig{ { Cert: "/path/to/cert.pem", Key: "/path/to/key.pem", @@ -123,12 +136,6 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { os.Exit(1) } ids[s.Id] = i - if s.Cluster != "" && s.Cluster != "-" { - if _, ok := cfg.Clusters[s.Cluster]; !ok { - log.Errorf("Storage %q is trying to connect to a not exists cluster %q.", s.Id, s.Cluster) - os.Exit(1) - } - } } } @@ -173,7 +180,8 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { os.Exit(1) } if notexists { - log.TrError("error.config.created") + log.TrErrorf("error.config.created") + return nil, errors.New("Please edit the config before continue!") } return } diff --git a/config/config.go b/config/config.go index 5c843c9d..ec18ca2f 100644 --- a/config/config.go +++ b/config/config.go @@ -36,7 +36,6 @@ type Config struct { PublicPort uint16 `yaml:"public-port"` Host string `yaml:"host"` Port uint16 `yaml:"port"` - Byoc bool `yaml:"byoc"` UseCert bool `yaml:"use-cert"` TrustedXForwardedFor bool `yaml:"trusted-x-forwarded-for"` @@ -65,7 +64,7 @@ type Config struct { Advanced AdvancedConfig `yaml:"advanced"` } -func (cfg *Config) applyWebManifest(manifest map[string]any) { +func (cfg *Config) ApplyWebManifest(manifest map[string]any) { if cfg.Dashboard.Enable { manifest["name"] = cfg.Dashboard.PwaName manifest["short_name"] = cfg.Dashboard.PwaShortName @@ -79,7 +78,6 @@ func NewDefaultConfig() *Config { PublicPort: 0, Host: "0.0.0.0", Port: 4000, - Byoc: false, TrustedXForwardedFor: false, OnlyGcWhenStart: false, diff --git a/config/server.go b/config/server.go index 32963a58..a1763a8f 100644 --- a/config/server.go +++ b/config/server.go @@ -34,15 +34,16 @@ import ( type ClusterOptions struct { Id string `json:"id" yaml:"id"` Secret string `json:"secret" yaml:"secret"` + Byoc bool `json:"byoc"` PublicHosts []string `json:"public-hosts" yaml:"public-hosts"` - Prefix string `json:"prefix" yaml:"prefix"` + Server string `json:"server" yaml:"server"` SkipSignatureCheck bool `json:"skip-signature-check" yaml:"skip-signature-check"` + Storages []string `json:"storages" yaml:"storages"` } type ClusterGeneralConfig struct { PublicHost string `json:"public-host"` PublicPort uint16 `json:"public-port"` - Byoc bool `json:"byoc"` NoFastEnable bool `json:"no-fast-enable"` MaxReconnectCount int `json:"max-reconnect-count"` } @@ -77,6 +78,10 @@ type CacheConfig struct { newCache func() cache.Cache `yaml:"-"` } +func (c *CacheConfig) NewCache() cache.Cache { + return c.newCache() +} + func (c *CacheConfig) UnmarshalYAML(n *yaml.Node) (err error) { var cfg struct { Type string `yaml:"type"` @@ -148,3 +153,11 @@ func (c *TunnelConfig) UnmarshalYAML(n *yaml.Node) (err error) { } return } + +func (c *TunnelConfig) MatchTunnelOutput(line []byte) (host, port []byte, ok bool) { + res := c.outputRegex.FindSubmatch(line) + if res == nil { + return + } + return res[c.hostOut], res[c.portOut], true +} diff --git a/dashboard.go b/dashboard.go index f8fbcaef..3bcd689d 100644 --- a/dashboard.go +++ b/dashboard.go @@ -60,13 +60,14 @@ var dsbManifest = func() (dsbManifest map[string]any) { return }() -func (r *Runner) serveDashboard(rw http.ResponseWriter, req *http.Request, pth string) { +func (r *Runner) serveDashboard(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet && req.Method != http.MethodHead { rw.Header().Set("Allow", http.MethodGet+", "+http.MethodHead) http.Error(rw, "405 Method Not Allowed", http.StatusMethodNotAllowed) return } acceptEncoding := utils.SplitCSV(req.Header.Get("Accept-Encoding")) + pth := strings.TrimPrefix(req.URL.Path, "/") switch pth { case "": break diff --git a/handler.go b/handler.go index bedc3fb6..97e4e85a 100644 --- a/handler.go +++ b/handler.go @@ -24,16 +24,11 @@ import ( "context" "crypto" _ "embed" - "encoding/base64" "encoding/hex" "encoding/json" - "errors" "fmt" - "io" "net" "net/http" - "net/textproto" - "os" "strconv" "strings" "time" @@ -43,9 +38,9 @@ import ( "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/api/v0" "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/internal/gosrc" "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" - "github.com/LiterMC/go-openbmclapi/storage" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -108,13 +103,13 @@ var wsUpgrader = &websocket.Upgrader{ } func (r *Runner) GetHandler() http.Handler { - r.apiRateLimiter = limited.NewAPIRateMiddleWare(RealAddrCtxKey, loggedUserKey) - r.apiRateLimiter.SetAnonymousRateLimit(r.RateLimit.Anonymous) - r.apiRateLimiter.SetLoggedRateLimit(r.RateLimit.Logged) + r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) + r.apiRateLimiter.SetAnonymousRateLimit(r.Config.RateLimit.Anonymous) + r.apiRateLimiter.SetLoggedRateLimit(r.Config.RateLimit.Logged) r.handlerAPIv0 = http.StripPrefix("/api/v0", v0.NewHandler(wsUpgrader)) - r.hijackHandler = http.StripPrefix("/bmclapi", r.hijackProxy) + r.hijackHandler = http.StripPrefix("/bmclapi", r.hijacker) - handler := utils.NewHttpMiddleWareHandler(r) + handler := utils.NewHttpMiddleWareHandler((http.HandlerFunc)(r.serveHTTP)) // recover panic and log it handler.UseFunc(func(rw http.ResponseWriter, req *http.Request, next http.Handler) { defer log.RecoverPanic(func(any) { @@ -124,67 +119,57 @@ func (r *Runner) GetHandler() http.Handler { }) handler.Use(r.apiRateLimiter) - handler.Use(r.getRecordMiddleWare()) + handler.UseFunc(r.recordMiddleWare) return handler } -func (r *Runner) getRecordMiddleWare() utils.MiddleWareFunc { - type record struct { - used float64 - bytes float64 - ua string - skipUA bool +func (r *Runner) recordMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { + ua := req.UserAgent() + var addr string + if r.Config.TrustedXForwardedFor { + // X-Forwarded-For: , , + adr, _, _ := strings.Cut(req.Header.Get("X-Forwarded-For"), ",") + addr = strings.TrimSpace(adr) } - recordCh := make(chan record, 1024) - - return func(rw http.ResponseWriter, req *http.Request, next http.Handler) { - ua := req.UserAgent() - var addr string - if config.TrustedXForwardedFor { - // X-Forwarded-For: , , - adr, _, _ := strings.Cut(req.Header.Get("X-Forwarded-For"), ",") - addr = strings.TrimSpace(adr) - } - if addr == "" { - addr, _, _ = net.SplitHostPort(req.RemoteAddr) - } - srw := utils.WrapAsStatusResponseWriter(rw) - start := time.Now() + if addr == "" { + addr, _, _ = net.SplitHostPort(req.RemoteAddr) + } + srw := utils.WrapAsStatusResponseWriter(rw) + start := time.Now() - log.LogAccess(log.LevelDebug, &preAccessRecord{ - Type: "pre-access", - Time: start, - Addr: addr, - Method: req.Method, - URI: req.RequestURI, - UA: ua, - }) + log.LogAccess(log.LevelDebug, &preAccessRecord{ + Type: "pre-access", + Time: start, + Addr: addr, + Method: req.Method, + URI: req.RequestURI, + UA: ua, + }) - extraInfoMap := make(map[string]any) - ctx := req.Context() - ctx = context.WithValue(ctx, RealAddrCtxKey, addr) - ctx = context.WithValue(ctx, RealPathCtxKey, req.URL.Path) - ctx = context.WithValue(ctx, AccessLogExtraCtxKey, extraInfoMap) - req = req.WithContext(ctx) - next.ServeHTTP(srw, req) + extraInfoMap := make(map[string]any) + ctx := req.Context() + ctx = context.WithValue(ctx, api.RealAddrCtxKey, addr) + ctx = context.WithValue(ctx, api.RealPathCtxKey, req.URL.Path) + ctx = context.WithValue(ctx, api.AccessLogExtraCtxKey, extraInfoMap) + req = req.WithContext(ctx) + next.ServeHTTP(srw, req) - used := time.Since(start) - accRec := &accessRecord{ - Type: "access", - Status: srw.Status, - Used: used, - Content: srw.Wrote, - Addr: addr, - Proto: req.Proto, - Method: req.Method, - URI: req.RequestURI, - UA: ua, - } - if len(extraInfoMap) > 0 { - accRec.Extra = extraInfoMap - } - log.LogAccess(log.LevelInfo, accRec) + used := time.Since(start) + accRec := &accessRecord{ + Type: "access", + Status: srw.Status, + Used: used, + Content: srw.Wrote, + Addr: addr, + Proto: req.Proto, + Method: req.Method, + URI: req.RequestURI, + UA: ua, + } + if len(extraInfoMap) > 0 { + accRec.Extra = extraInfoMap } + log.LogAccess(log.LevelInfo, accRec) } var emptyHashes = func() (hashes map[string]struct{}) { @@ -202,11 +187,11 @@ var emptyHashes = func() (hashes map[string]struct{}) { //go:embed robots.txt var robotTxtContent string -func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (r *Runner) serveHTTP(rw http.ResponseWriter, req *http.Request) { method := req.Method u := req.URL - rw.Header().Set("X-Powered-By", HeaderXPoweredBy) + rw.Header().Set("X-Powered-By", build.HeaderXPoweredBy) rawpath := u.EscapedPath() switch { @@ -249,7 +234,7 @@ func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) { for _, cr := range r.clusters { if cr.AcceptHost(req.Host) { - cr.HandleFile(rw, req, hash) + cr.HandleMeasure(rw, req, size) return } } @@ -264,23 +249,24 @@ func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) { case "v0": r.handlerAPIv0.ServeHTTP(rw, req) return - case "v1": - r.handlerAPIv1.ServeHTTP(rw, req) - return + // case "v1": + // r.handlerAPIv1.ServeHTTP(rw, req) + // return } case rawpath == "/" || rawpath == "/dashboard": http.Redirect(rw, req, "/dashboard/", http.StatusFound) return case strings.HasPrefix(rawpath, "/dashboard/"): - if !r.DashboardEnabled { + if !r.Config.Dashboard.Enable { http.NotFound(rw, req) return } - pth := rawpath[len("/dashboard/"):] - r.serveDashboard(rw, req, pth) + req2 := gosrc.RequestStripPrefix(req, "/dashboard") + r.serveDashboard(rw, req2) return case strings.HasPrefix(rawpath, "/bmclapi/"): - r.hijackHandler.ServeHTTP(rw, req) + req2 := gosrc.RequestStripPrefix(req, "/bmclapi") + r.hijackHandler.ServeHTTP(rw, req2) return } http.NotFound(rw, req) diff --git a/internal/gosrc/httpstrip.go b/internal/gosrc/httpstrip.go new file mode 100644 index 00000000..8dcc129a --- /dev/null +++ b/internal/gosrc/httpstrip.go @@ -0,0 +1,26 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gosrc + +import ( + "net/http" + "net/url" + "strings" +) + +func RequestStripPrefix(r *http.Request, prefix string) *http.Request { + p, ok := strings.CutPrefix(r.URL.Path, prefix) + rp, ok2 := strings.CutPrefix(r.URL.RawPath, prefix) + if ok && (ok2 || r.URL.RawPath == "") { + r2 := new(http.Request) + *r2 = *r + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.Path = p + r2.URL.RawPath = rp + return r2 + } + return nil +} diff --git a/main.go b/main.go index fff4384d..55f42dc4 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,7 @@ import ( "runtime" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -44,6 +45,7 @@ import ( doh "github.com/libp2p/go-doh-resolver" + "github.com/LiterMC/go-openbmclapi/api/bmclapi" "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/database" @@ -51,7 +53,9 @@ import ( "github.com/LiterMC/go-openbmclapi/lang" "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" subcmds "github.com/LiterMC/go-openbmclapi/sub_commands" + "github.com/LiterMC/go-openbmclapi/utils" _ "github.com/LiterMC/go-openbmclapi/lang/en" _ "github.com/LiterMC/go-openbmclapi/lang/zh" @@ -124,18 +128,8 @@ func main() { } else { r.Config = config } - if r.Config.Advanced.DebugLog { - log.SetLevel(log.LevelDebug) - } else { - log.SetLevel(log.LevelInfo) - } - if r.Config.NoAccessLog { - log.SetAccessLogSlots(-1) - } else { - log.SetAccessLogSlots(r.Config.AccessLogSlots) - } - r.Config.applyWebManifest(dsbManifest) + r.SetupLogger(ctx) log.TrInfof("program.starting", build.ClusterVersion, build.BuildVersion) @@ -143,13 +137,37 @@ func main() { r.StartTunneler() } r.InitServer() + r.StartServer(ctx) + r.InitClusters(ctx) go func(ctx context.Context) { defer log.RecordPanic() - if !r.cluster.Connect(ctx) { - osExit(CodeClientOrServerError) + var wg sync.WaitGroup + errs := make([]error, len(r.clusters)) + { + i := 0 + for _, cr := range r.clusters { + i++ + go func(i int, cr *cluster.Cluster) { + defer wg.Done() + errs[i] = cr.Connect(ctx) + }(i, cr) + } + } + wg.Wait() + if ctx.Err() != nil { + return + } + + { + var err error + r.tlsConfig, err = r.PatchTLSWithClusterCert(ctx, r.tlsConfig) + if err != nil { + return + } + r.listener.TLSConfig.Store(r.tlsConfig) } firstSyncDone := make(chan struct{}, 0) @@ -159,30 +177,11 @@ func main() { r.InitSynchronizer(ctx) }() - listener := r.CreateHTTPServerListener(ctx) - go func(listener net.Listener) { - defer listener.Close() - if err := r.clusterSvr.Serve(listener); !errors.Is(err, http.ErrServerClosed) { - log.Error("Error when serving:", err) - os.Exit(1) - } - }(listener) - - var publicHost string - if len(r.publicHosts) == 0 { - publicHost = config.PublicHost - } else { - publicHost = r.publicHosts[0] - } - if !config.Tunneler.Enable { + if !r.Config.Tunneler.Enable { strPort := strconv.Itoa((int)(r.getPublicPort())) - log.TrInfof("info.server.public.at", net.JoinHostPort(publicHost, strPort), r.clusterSvr.Addr, r.getCertCount()) - if len(r.publicHosts) > 1 { - log.TrInfof("info.server.alternative.hosts") - for _, h := range r.publicHosts[1:] { - log.Infof("\t- https://%s", net.JoinHostPort(h, strPort)) - } - } + pubAddr := net.JoinHostPort(r.Config.PublicHost, strPort) + localAddr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port))) + log.TrInfof("info.server.public.at", pubAddr, localAddr, r.getCertCount()) } log.TrInfof("info.wait.first.sync") @@ -192,14 +191,14 @@ func main() { return } - r.EnableCluster(ctx) + // r.EnableCluster(ctx) }(ctx) - code := r.DoSignals(cancel) + code := r.ListenSignals(ctx, cancel) if code != 0 { log.TrErrorf("program.exited", code) log.TrErrorf("error.exit.please.read.faq") - if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { + if runtime.GOOS == "windows" && !r.Config.Advanced.DoNotOpenFAQOnWindows { // log.TrWarnf("warn.exit.detected.windows.open.browser") // cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") // cmd.Start() @@ -212,12 +211,21 @@ func main() { type Runner struct { Config *config.Config - clusters map[string]*Cluster - server *http.Server + clusters map[string]*cluster.Cluster + apiRateLimiter *limited.APIRateMiddleWare + storageManager *storage.Manager + statManager *cluster.StatManager + hijacker *bmclapi.HjProxy + database database.DB + + server *http.Server + handlerAPIv0 http.Handler + hijackHandler http.Handler tlsConfig *tls.Config - listener net.Listener - publicHosts []string + publicHost string + publicPort uint16 + listener *utils.HTTPTLSListener reloading atomic.Bool updating atomic.Bool @@ -225,8 +233,8 @@ type Runner struct { } func (r *Runner) getPublicPort() uint16 { - if r.Config.PublicPort > 0 { - return r.Config.PublicPort + if r.publicPort > 0 { + return r.publicPort } return r.Config.Port } @@ -238,7 +246,80 @@ func (r *Runner) getCertCount() int { return len(r.tlsConfig.Certificates) } -func (r *Runner) DoSignals(cancel context.CancelFunc) int { +func (r *Runner) InitServer() { + r.server = &http.Server{ + ReadTimeout: 10 * time.Second, + IdleTimeout: 5 * time.Second, + Handler: r.GetHandler(), + ErrorLog: log.ProxiedStdLog, + } +} + +// StartServer will start the HTTP server +// If a server is already running on an old listener, the listener will be closed. +func (r *Runner) StartServer(ctx context.Context) error { + htListener, err := r.CreateHTTPListener(ctx) + if err != nil { + return err + } + if r.listener != nil { + r.listener.Close() + } + r.listener = htListener + go func() { + defer htListener.Close() + if err := r.server.Serve(htListener); !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) { + log.Error("Error when serving:", err) + os.Exit(1) + } + }() + return nil +} + +func (r *Runner) GetClusterGeneralConfig() config.ClusterGeneralConfig { + return config.ClusterGeneralConfig{ + PublicHost: r.publicHost, + PublicPort: r.getPublicPort(), + NoFastEnable: r.Config.Advanced.NoFastEnable, + MaxReconnectCount: r.Config.MaxReconnectCount, + } +} + +func (r *Runner) InitClusters(ctx context.Context) { + // var ( + // dialer *net.Dialer + // cache = r.Config.Cache.NewCache() + // ) + + _ = doh.NewResolver // TODO: use doh resolver + + r.clusters = make(map[string]*cluster.Cluster) + gcfg := r.GetClusterGeneralConfig() + for name, opts := range r.Config.Clusters { + cr := cluster.NewCluster(opts, gcfg, r.storageManager, r.statManager) + if err := cr.Init(ctx); err != nil { + log.TrErrorf("error.init.failed", err) + } else { + r.clusters[name] = cr + } + } + + // r.cluster = NewCluster(ctx, + // ClusterServerURL, + // baseDir, + // config.PublicHost, r.getPublicPort(), + // config.ClusterId, config.ClusterSecret, + // config.Byoc, dialer, + // config.Storages, + // cache, + // ) + // if err := r.cluster.Init(ctx); err != nil { + // log.TrErrorf("error.init.failed"), err) + // os.Exit(1) + // } +} + +func (r *Runner) ListenSignals(ctx context.Context, cancel context.CancelFunc) int { signalCh := make(chan os.Signal, 1) log.Debugf("Receiving signals") signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) @@ -290,7 +371,7 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { } } case syscall.SIGHUP: - go r.ReloadConfig() + go r.ReloadConfig(ctx) default: cancel() if forceStop == nil { @@ -312,7 +393,7 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { return 0 } -func (r *Runner) ReloadConfig() { +func (r *Runner) ReloadConfig(ctx context.Context) { if r.reloading.CompareAndSwap(false, true) { log.Error("Config is already reloading!") return @@ -321,10 +402,54 @@ func (r *Runner) ReloadConfig() { config, err := readAndRewriteConfig() if err != nil { - log.Errorf("Config error: %s", err) + log.Errorf("Config error: %v", err) } else { - r.Config = config + if err := r.updateConfig(ctx, config); err != nil { + log.Errorf("Error when reloading config: %v", err) + } + } +} + +func (r *Runner) updateConfig(ctx context.Context, newConfig *config.Config) error { + oldConfig := r.Config + reloadProcesses := make([]func(context.Context) error, 0, 8) + + if newConfig.LogSlots != oldConfig.LogSlots || newConfig.NoAccessLog != oldConfig.NoAccessLog || newConfig.AccessLogSlots != oldConfig.AccessLogSlots || newConfig.Advanced.DebugLog != oldConfig.Advanced.DebugLog { + reloadProcesses = append(reloadProcesses, r.SetupLogger) + } + if newConfig.Host != oldConfig.Host || newConfig.Port != oldConfig.Port { + reloadProcesses = append(reloadProcesses, r.StartServer) + } + if newConfig.PublicHost != oldConfig.PublicHost || newConfig.PublicPort != oldConfig.PublicPort || newConfig.Advanced.NoFastEnable != oldConfig.Advanced.NoFastEnable || newConfig.MaxReconnectCount != oldConfig.MaxReconnectCount { + reloadProcesses = append(reloadProcesses, r.updateClustersWithGeneralConfig) + } + + r.Config = newConfig + r.publicHost = r.Config.PublicHost + r.publicPort = r.Config.PublicPort + for _, proc := range reloadProcesses { + if err := proc(ctx); err != nil { + return err + } + } + return nil +} + +func (r *Runner) SetupLogger(ctx context.Context) error { + if r.Config.Advanced.DebugLog { + log.SetLevel(log.LevelDebug) + } else { + log.SetLevel(log.LevelInfo) + } + log.SetLogSlots(r.Config.LogSlots) + if r.Config.NoAccessLog { + log.SetAccessLogSlots(-1) + } else { + log.SetAccessLogSlots(r.Config.AccessLogSlots) } + + r.Config.ApplyWebManifest(dsbManifest) + return nil } func (r *Runner) StopServer(ctx context.Context) { @@ -346,6 +471,8 @@ func (r *Runner) StopServer(ctx context.Context) { wg.Wait() log.TrWarnf("warn.httpserver.closing") r.server.Shutdown(shutCtx) + r.listener.Close() + r.listener = nil }() select { case <-shutDone: @@ -355,41 +482,8 @@ func (r *Runner) StopServer(ctx context.Context) { log.TrWarnf("warn.server.closed") } -func (r *Runner) InitServer() { - r.server = &http.Server{ - Addr: fmt.Sprintf("%s:%d", d.Config.Host, d.Config.Port), - ReadTimeout: 10 * time.Second, - IdleTimeout: 5 * time.Second, - Handler: r, - ErrorLog: log.ProxiedStdLog, - } -} - -func (r *Runner) InitClusters(ctx context.Context) { - var ( - dialer *net.Dialer - cache = r.Config.Cache.newCache() - ) - - _ = doh.NewResolver // TODO: use doh resolver - - r.cluster = NewCluster(ctx, - ClusterServerURL, - baseDir, - config.PublicHost, r.getPublicPort(), - config.ClusterId, config.ClusterSecret, - config.Byoc, dialer, - config.Storages, - cache, - ) - if err := r.cluster.Init(ctx); err != nil { - log.Errorf(Tr("error.init.failed"), err) - os.Exit(1) - } -} - func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, oldfileset map[string]int64) { - if !r.hijacker.Enabled { + if !r.Config.Hijack.Enable { return } if !r.updating.CompareAndSwap(false, true) { @@ -409,7 +503,7 @@ func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, ol sem.Acquire() go func(rec database.FileRecord) { defer sem.Release() - r.cluster.database.SetFileRecord(rec) + r.database.SetFileRecord(rec) }(database.FileRecord{ Path: f.Path, Hash: f.Hash, @@ -421,11 +515,11 @@ func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, ol } func (r *Runner) InitSynchronizer(ctx context.Context) { - fileMap := make(map[string]*StorageFileInfo) + fileMap := make(map[string]*cluster.StorageFileInfo) for _, cr := range r.clusters { - log.Info(Tr("info.filelist.fetching"), cr.ID()) + log.TrInfof("info.filelist.fetching", cr.ID()) if err := cr.GetFileList(ctx, fileMap, true); err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) if errors.Is(err, context.Canceled) { return } @@ -433,33 +527,34 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } checkCount := -1 - heavyCheck := !config.Advanced.NoHeavyCheck - heavyCheckInterval := config.Advanced.HeavyCheckInterval + heavyCheck := !r.Config.Advanced.NoHeavyCheck + heavyCheckInterval := r.Config.Advanced.HeavyCheckInterval if heavyCheckInterval <= 0 { heavyCheck = false } - if !config.Advanced.SkipFirstSync { - if !r.cluster.SyncFiles(ctx, fileMap, false) { - return - } - go r.UpdateFileRecords(fileMap, nil) + // if !r.Config.Advanced.SkipFirstSync { + // if !r.cluster.SyncFiles(ctx, fileMap, false) { + // return + // } + // go r.UpdateFileRecords(fileMap, nil) - if !config.Advanced.NoGC { - go r.cluster.Gc() - } - } else if fl != nil { - if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { - return - } - } + // if !r.Config.Advanced.NoGC { + // go r.cluster.Gc() + // } + // } else + // if fl != nil { + // if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { + // return + // } + // } createInterval(ctx, func() { - fileMap := make(map[string]*StorageFileInfo) + fileMap := make(map[string]*cluster.StorageFileInfo) for _, cr := range r.clusters { - log.Info(Tr("info.filelist.fetching"), cr.ID()) + log.TrInfof("info.filelist.fetching", cr.ID()) if err := cr.GetFileList(ctx, fileMap, false); err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) return } } @@ -472,104 +567,122 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { oldfileset := r.cluster.CloneFileset() if r.cluster.SyncFiles(ctx, fl, heavyCheck && checkCount == 0) { go r.UpdateFileRecords(fl, oldfileset) - if !config.Advanced.NoGC && !config.OnlyGcWhenStart { + if !r.Config.Advanced.NoGC && !r.Config.OnlyGcWhenStart { go r.cluster.Gc() } } - }, (time.Duration)(config.SyncInterval)*time.Minute) + }, (time.Duration)(r.Config.SyncInterval)*time.Minute) } -func (r *Runner) CreateHTTPServerListener(ctx context.Context) (listener net.Listener) { - listener, err := net.Listen("tcp", r.Addr) +func (r *Runner) CreateHTTPListener(ctx context.Context) (*utils.HTTPTLSListener, error) { + addr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port))) + listener, err := net.Listen("tcp", addr) if err != nil { - log.Errorf(Tr("error.address.listen.failed"), r.Addr, err) - osExit(CodeEnvironmentError) + log.TrErrorf("error.address.listen.failed", addr, err) + return nil, err } if r.Config.ServeLimit.Enable { - limted := limited.NewLimitedListener(listener, config.ServeLimit.MaxConn, 0, config.ServeLimit.UploadRate*1024) + limted := limited.NewLimitedListener(listener, r.Config.ServeLimit.MaxConn, 0, r.Config.ServeLimit.UploadRate*1024) limted.SetMinWriteRate(1024) listener = limted } - tlsConfig := r.GenerateTLSConfig(ctx) - r.publicHosts = make([]string, 0, 2) - if tlsConfig != nil { - for _, cert := range tlsConfig.Certificates { - if h, err := parseCertCommonName(cert.Certificate[0]); err == nil { - r.publicHosts = append(r.publicHosts, strings.ToLower(h)) - } + if r.Config.UseCert { + var err error + r.tlsConfig, err = r.GenerateTLSConfig() + if err != nil { + log.Errorf("Failed to generate TLS config: %v", err) + return nil, err } - listener = utils.NewHttpTLSListener(listener, tlsConfig, r.publicHosts, r.getPublicPort()) } - r.listener = listener - return + return utils.NewHttpTLSListener(listener, r.tlsConfig), nil } -func (r *Runner) GenerateTLSConfig(ctx context.Context) (tlsConfig *tls.Config) { - if config.UseCert { - if len(config.Certificates) == 0 { - log.Error(Tr("error.cert.not.set")) - os.Exit(1) +func (r *Runner) GenerateTLSConfig() (*tls.Config, error) { + if len(r.Config.Certificates) == 0 { + log.TrErrorf("error.cert.not.set") + return nil, errors.New("No certificate is defined") + } + tlsConfig := new(tls.Config) + tlsConfig.Certificates = make([]tls.Certificate, len(r.Config.Certificates)) + for i, c := range r.Config.Certificates { + var err error + tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) + if err != nil { + log.TrErrorf("error.cert.parse.failed", i, err) + return nil, err } - tlsConfig = new(tls.Config) - tlsConfig.Certificates = make([]tls.Certificate, len(config.Certificates)) - for i, c := range config.Certificates { - var err error - tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) - if err != nil { - log.Errorf(Tr("error.cert.parse.failed"), i, err) - os.Exit(1) - } + } + return tlsConfig, nil +} + +func (r *Runner) PatchTLSWithClusterCert(ctx context.Context, tlsConfig *tls.Config) (*tls.Config, error) { + certs := make([]tls.Certificate, 0) + for _, cr := range r.clusters { + if cr.Options().Byoc { + continue + } + log.TrInfof("info.cert.requesting", cr.ID()) + tctx, cancel := context.WithTimeout(ctx, time.Minute*10) + pair, err := cr.RequestCert(tctx) + cancel() + if err != nil { + log.TrErrorf("error.cert.request.failed", err) + continue } + cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) + if err != nil { + log.TrErrorf("error.cert.requested.parse.failed", err) + continue + } + certs = append(certs, cert) + certHost, _ := parseCertCommonName(cert.Certificate[0]) + log.TrInfof("info.cert.requested", certHost) } - if !config.Byoc { - for _, cr := range r.clusters { - log.Info(Tr("info.cert.requesting"), cr.ID()) - tctx, cancel := context.WithTimeout(ctx, time.Minute*10) - pair, err := cr.RequestCert(tctx) - cancel() - if err != nil { - log.Errorf(Tr("error.cert.request.failed"), err) - os.Exit(2) - } - if tlsConfig == nil { - tlsConfig = new(tls.Config) - } - var cert tls.Certificate - cert, err = tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) - if err != nil { - log.Errorf(Tr("error.cert.requested.parse.failed"), err) - os.Exit(2) - } - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) - certHost, _ := parseCertCommonName(cert.Certificate[0]) - log.Infof(Tr("info.cert.requested"), certHost) + if len(certs) == 0 { + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } else { + tlsConfig = tlsConfig.Clone() } + tlsConfig.Certificates = append(tlsConfig.Certificates, certs...) } - r.tlsConfig = tlsConfig - return + return tlsConfig, nil } -func (r *Runner) EnableCluster(ctx context.Context) { - if config.Advanced.WaitBeforeEnable > 0 { - select { - case <-time.After(time.Second * (time.Duration)(config.Advanced.WaitBeforeEnable)): - case <-ctx.Done(): - return - } +// updateClustersWithGeneralConfig will re-enable all clusters with latest general config +func (r *Runner) updateClustersWithGeneralConfig(ctx context.Context) error { + gcfg := r.GetClusterGeneralConfig() + var wg sync.WaitGroup + for _, cr := range r.clusters { + wg.Add(1) + go func(cr *cluster.Cluster) { + defer wg.Done() + cr.Disable(ctx) + *cr.GeneralConfig() = gcfg + if err := cr.Enable(ctx); err != nil { + log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) + return + } + }(cr) } + wg.Wait() + return nil +} - if config.Tunneler.Enable { - r.enableClusterByTunnel(ctx) - } else { - if err := r.cluster.Enable(ctx); err != nil { - log.Errorf(Tr("error.cluster.enable.failed"), err) - if ctx.Err() != nil { +func (r *Runner) EnableClusterAll(ctx context.Context) { + var wg sync.WaitGroup + for _, cr := range r.clusters { + wg.Add(1) + go func(cr *cluster.Cluster) { + defer wg.Done() + if err := cr.Enable(ctx); err != nil { + log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) return } - osExit(CodeServerOrEnvionmentError) - } + }(cr) } + wg.Wait() } func (r *Runner) StartTunneler() { @@ -597,8 +710,8 @@ func (r *Runner) StartTunneler() { } func (r *Runner) RunTunneler(ctx context.Context) { - cmd := exec.CommandContext(ctx, config.Tunneler.TunnelProg) - log.Infof(Tr("info.tunnel.running"), cmd.String()) + cmd := exec.CommandContext(ctx, r.Config.Tunneler.TunnelProg) + log.TrInfof("info.tunnel.running", cmd.String()) var ( cmdOut, cmdErr io.ReadCloser err error @@ -606,15 +719,15 @@ func (r *Runner) RunTunneler(ctx context.Context) { cmd.Env = append(os.Environ(), "CLUSTER_PORT="+strconv.Itoa((int)(r.Config.Port))) if cmdOut, err = cmd.StdoutPipe(); err != nil { - log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) + log.TrErrorf("error.tunnel.command.prepare.failed", err) os.Exit(1) } if cmdErr, err = cmd.StderrPipe(); err != nil { - log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) + log.TrErrorf("error.tunnel.command.prepare.failed", err) os.Exit(1) } if err = cmd.Start(); err != nil { - log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) + log.TrErrorf("error.tunnel.command.prepare.failed", err) os.Exit(1) } type addrOut struct { @@ -623,11 +736,10 @@ func (r *Runner) RunTunneler(ctx context.Context) { } detectedCh := make(chan addrOut, 1) onLog := func(line []byte) { - res := config.Tunneler.outputRegex.FindSubmatch(line) - if res == nil { + tunnelHost, tunnelPort, ok := r.Config.Tunneler.MatchTunnelOutput(line) + if !ok { return } - tunnelHost, tunnelPort := res[config.Tunneler.hostOut], res[config.Tunneler.portOut] if len(tunnelHost) > 0 && tunnelHost[0] == '[' && tunnelHost[len(tunnelHost)-1] == ']' { // a IPv6 with port []: tunnelHost = tunnelHost[1 : len(tunnelHost)-1] } @@ -667,33 +779,11 @@ func (r *Runner) RunTunneler(ctx context.Context) { for { select { case addr := <-detectedCh: - log.Infof(Tr("info.tunnel.detected"), addr.host, addr.port) - r.cluster.publicPort = addr.port - if !r.cluster.byoc { - r.cluster.host = addr.host - } - strPort := strconv.Itoa((int)(r.getPublicPort())) - if spp, ok := r.listener.(interface{ SetPublicPort(port string) }); ok { - spp.SetPublicPort(strPort) - } - log.Infof(Tr("info.server.public.at"), net.JoinHostPort(addr.host, strPort), r.clusterSvr.Addr, r.getCertCount()) - if len(r.publicHosts) > 1 { - log.Info(Tr("info.server.alternative.hosts")) - for _, h := range r.publicHosts[1:] { - log.Infof("\t- https://%s", net.JoinHostPort(h, strPort)) - } - } - if !r.cluster.Enabled() { - shutCtx, cancel := context.WithTimeout(ctx, time.Minute) - r.cluster.Disable(shutCtx) - cancel() - } - if err := r.cluster.Enable(ctx); err != nil { - log.Errorf(Tr("error.cluster.enable.failed"), err) - if ctx.Err() != nil { - return - } - os.Exit(2) + log.TrInfof("info.tunnel.detected", addr.host, addr.port) + r.publicHost, r.publicPort = addr.host, addr.port + r.updateClustersWithGeneralConfig(ctx) + if ctx.Err() != nil { + return } case <-ctx.Done(): return diff --git a/storage/manager.go b/storage/manager.go index 79353dd6..164a34a7 100644 --- a/storage/manager.go +++ b/storage/manager.go @@ -57,6 +57,15 @@ func (m *Manager) Get(id string) Storage { return nil } +func (m *Manager) GetIndex(id string) int { + for i, s := range m.Storages { + if s.Id() == id { + return i + } + } + return -1 +} + func (m *Manager) GetFlavorString(storages []int) string { typeCount := make(map[string]int, 2) for _, i := range storages { diff --git a/sync.go b/sync.go index bded4ac1..12430b40 100644 --- a/sync.go +++ b/sync.go @@ -1,3 +1,5 @@ +//go:build ignore + /** * OpenBmclAPI (Golang Edition) * Copyright (C) 2024 Kevin Z diff --git a/util.go b/util.go index 539cf9a3..3bd2ba97 100644 --- a/util.go +++ b/util.go @@ -25,7 +25,6 @@ import ( "crypto/x509" "fmt" "io" - "math/rand" "net/http" "net/url" "os" @@ -83,54 +82,6 @@ func parseCertCommonName(body []byte) (string, error) { return cert.Subject.CommonName, nil } -func forEachFromRandomIndex(leng int, cb func(i int) (done bool)) (done bool) { - if leng <= 0 { - return false - } - start := randIntn(leng) - for i := start; i < leng; i++ { - if cb(i) { - return true - } - } - for i := 0; i < start; i++ { - if cb(i) { - return true - } - } - return false -} - -func forEachFromRandomIndexWithPossibility(poss []uint, total uint, cb func(i int) (done bool)) (done bool) { - leng := len(poss) - if leng == 0 { - return false - } - if total == 0 { - return forEachFromRandomIndex(leng, cb) - } - n := (uint)(randIntn((int)(total))) - start := 0 - for i, p := range poss { - if n < p { - start = i - break - } - n -= p - } - for i := start; i < leng; i++ { - if cb(i) { - return true - } - } - for i := 0; i < start; i++ { - if cb(i) { - return true - } - } - return false -} - func copyFile(src, dst string, mode os.FileMode) (err error) { var srcFd, dstFd *os.File if srcFd, err = os.Open(src); err != nil { diff --git a/utils/http.go b/utils/http.go index d83fbdbf..0412bf27 100644 --- a/utils/http.go +++ b/utils/http.go @@ -30,7 +30,6 @@ import ( "net/url" "path" "runtime" - "strconv" "strings" "sync" "sync/atomic" @@ -290,13 +289,10 @@ func (m *HttpMethodHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) // Else it will just return the tls connection type HTTPTLSListener struct { net.Listener - TLSConfig *tls.Config + TLSConfig atomic.Pointer[tls.Config] + DoRedirect bool AllowUnsecure bool - mux sync.RWMutex - hosts []string - port string - accepting atomic.Bool acceptedCh chan net.Conn errCh chan error @@ -304,15 +300,17 @@ type HTTPTLSListener struct { var _ net.Listener = (*HTTPTLSListener)(nil) -func NewHttpTLSListener(l net.Listener, cfg *tls.Config, publicHosts []string, port uint16) net.Listener { - return &HTTPTLSListener{ - Listener: l, - TLSConfig: cfg, - hosts: publicHosts, - port: strconv.Itoa((int)(port)), +func NewHttpTLSListener(l net.Listener, cfg *tls.Config) *HTTPTLSListener { + h := &HTTPTLSListener{ + Listener: l, + DoRedirect: true, + AllowUnsecure: false, + acceptedCh: make(chan net.Conn, 1), errCh: make(chan error, 1), } + h.TLSConfig.Store(cfg) + return h } func (s *HTTPTLSListener) Close() (err error) { @@ -329,22 +327,7 @@ func (s *HTTPTLSListener) Close() (err error) { return } -func (s *HTTPTLSListener) SetPublicPort(port string) { - s.mux.Lock() - defer s.mux.Unlock() - s.port = port -} - -func (s *HTTPTLSListener) GetPublicPort() string { - s.mux.RLock() - defer s.mux.RUnlock() - return s.port -} - func (s *HTTPTLSListener) maybeHTTPConn(c *connHeadReader) (ishttp bool) { - if len(s.hosts) == 0 { - return false - } var buf [4096]byte i, n := 0, 0 READ_HEAD: @@ -389,6 +372,11 @@ func (s *HTTPTLSListener) accepter() { s.errCh <- err return } + tlsCfg := s.TLSConfig.Load() + if tlsCfg == nil { + s.acceptedCh <- conn + return + } go s.accepter() hr := &connHeadReader{Conn: conn} hr.SetReadDeadline(time.Now().Add(time.Second * 5)) @@ -396,7 +384,7 @@ func (s *HTTPTLSListener) accepter() { hr.SetReadDeadline(time.Time{}) if !ishttp { // if it's not a http connection, it must be a tls connection - s.acceptedCh <- tls.Server(hr, s.TLSConfig) + s.acceptedCh <- tls.Server(hr, tlsCfg) return } if s.AllowUnsecure { @@ -416,41 +404,13 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { return } conn.SetReadDeadline(time.Time{}) - host, _, err := net.SplitHostPort(req.Host) - if err != nil { - host = req.Host - } - inhosts := false - if host != "" { - host = strings.ToLower(host) - for _, h := range s.hosts { - if h == "*" { - inhosts = true - break - } - if h, ok := strings.CutPrefix(h, "*."); ok { - if strings.HasSuffix(host, h) { - inhosts = true - break - } - } else if h == host { - inhosts = true - break - } - } - } + // host, _, err := net.SplitHostPort(req.Host) + // if err != nil { + // host = req.Host + // } u := *req.URL u.Scheme = "https" - if !inhosts { - host = "" - for _, h := range s.hosts { - if h != "*" && !strings.HasSuffix(h, "*.") { - host = h - break - } - } - } - if host == "" { + if !s.DoRedirect { body := strings.NewReader("Sent http request on https server") resp := &http.Response{ StatusCode: http.StatusBadRequest, @@ -468,7 +428,7 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { io.Copy(conn, body) return } - u.Host = net.JoinHostPort(host, s.GetPublicPort()) + // u.Host = net.JoinHostPort(host, s.GetPublicPort()) resp := &http.Response{ StatusCode: http.StatusPermanentRedirect, ProtoMajor: req.ProtoMajor,