diff --git a/api/bmclapi/hijacker.go b/api/bmclapi/hijacker.go
index f45d0b24..3218c2b0 100644
--- a/api/bmclapi/hijacker.go
+++ b/api/bmclapi/hijacker.go
@@ -17,7 +17,7 @@
* along with this program. If not, see .
*/
-package main
+package bmclapi
import (
"context"
diff --git a/cluster/cluster.go b/cluster/cluster.go
index 3f7943f2..51d2cf6d 100644
--- a/cluster/cluster.go
+++ b/cluster/cluster.go
@@ -70,9 +70,13 @@ type Cluster struct {
func NewCluster(
opts config.ClusterOptions, gcfg config.ClusterGeneralConfig,
- storageManager *storage.Manager, storages []int,
+ storageManager *storage.Manager,
statManager *StatManager,
) (cr *Cluster) {
+ storages := make([]int, len(opts.Storages))
+ for i, name := range opts.Storages {
+ storages[i] = storageManager.GetIndex(name)
+ }
cr = &Cluster{
opts: opts,
gcfg: gcfg,
@@ -120,10 +124,23 @@ func (cr *Cluster) AcceptHost(host string) bool {
return false
}
+func (cr *Cluster) Options() *config.ClusterOptions {
+ return &cr.opts
+}
+
+func (cr *Cluster) GeneralConfig() *config.ClusterGeneralConfig {
+ return &cr.gcfg
+}
+
// Init do setup on the cluster
// Init should only be called once during the cluster's whole life
// The context passed in only affect the logical of Init method
func (cr *Cluster) Init(ctx context.Context) error {
+ for i, ind := range cr.storages {
+ if ind == -1 {
+ return fmt.Errorf("Storage %q does not exists", cr.opts.Storages[i])
+ }
+ }
return nil
}
@@ -172,7 +189,7 @@ func (cr *Cluster) enable(ctx context.Context) error {
Host: cr.gcfg.PublicHost,
Port: cr.gcfg.PublicPort,
Version: build.ClusterVersion,
- Byoc: cr.gcfg.Byoc,
+ Byoc: cr.opts.Byoc,
NoFastEnable: cr.gcfg.NoFastEnable,
Flavor: ConfigFlavor{
Runtime: "golang/" + runtime.GOOS + "-" + runtime.GOARCH,
diff --git a/cluster/handler.go b/cluster/handler.go
index dd2151a4..4dbd3a18 100644
--- a/cluster/handler.go
+++ b/cluster/handler.go
@@ -36,7 +36,7 @@ import (
"github.com/LiterMC/go-openbmclapi/storage"
)
-func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash string, size int64) {
+func (cr *Cluster) HandleFile(rw http.ResponseWriter, req *http.Request, hash string) {
defer log.RecoverPanic(nil)
if !cr.Enabled() {
@@ -88,6 +88,8 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st
api.SetAccessInfo(req, "cluster", cr.ID())
+ var size int64 = -1 // TODO: get the size
+
var (
sto storage.Storage
err error
@@ -121,7 +123,7 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
-func (cr *Cluster) HandleMeasure(req *http.Request, rw http.ResponseWriter, size int) {
+func (cr *Cluster) HandleMeasure(rw http.ResponseWriter, req *http.Request, size int) {
if !cr.Enabled() {
// do not serve file if cluster is not enabled yet
http.Error(rw, "Cluster is not enabled yet", http.StatusServiceUnavailable)
diff --git a/cluster/http.go b/cluster/http.go
index f66f3f54..4910444c 100644
--- a/cluster/http.go
+++ b/cluster/http.go
@@ -90,7 +90,7 @@ func (cr *Cluster) makeReqWithBody(
query url.Values, body io.Reader,
) (req *http.Request, err error) {
var u *url.URL
- if u, err = url.Parse(cr.opts.Prefix); err != nil {
+ if u, err = url.Parse(cr.opts.Server); err != nil {
return
}
u.Path = path.Join(u.Path, relpath)
diff --git a/cluster/socket.go b/cluster/socket.go
index 7b2b78f1..db238397 100644
--- a/cluster/socket.go
+++ b/cluster/socket.go
@@ -48,10 +48,10 @@ func (cr *Cluster) Connect(ctx context.Context) error {
}
engio, err := engine.NewSocket(engine.Options{
- Host: cr.opts.Prefix,
+ Host: cr.opts.Server,
Path: "/socket.io/",
ExtraHeaders: http.Header{
- "Origin": {cr.opts.Prefix},
+ "Origin": {cr.opts.Server},
"User-Agent": {build.ClusterUserAgent},
},
DialTimeout: time.Minute * 6,
diff --git a/config.go b/config.go
index 43298f58..b5259111 100644
--- a/config.go
+++ b/config.go
@@ -21,12 +21,20 @@ package main
import (
"bytes"
+ "errors"
+ "fmt"
+ "net/url"
+ "os"
"gopkg.in/yaml.v3"
"github.com/LiterMC/go-openbmclapi/config"
+ "github.com/LiterMC/go-openbmclapi/log"
+ "github.com/LiterMC/go-openbmclapi/storage"
)
+const DefaultBMCLAPIServer = "https://openbmclapi.bangbang93.com"
+
func migrateConfig(data []byte, cfg *config.Config) {
var oldConfig map[string]any
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
@@ -45,11 +53,13 @@ func migrateConfig(data []byte, cfg *config.Config) {
if oldConfig["clusters"].(map[string]any) == nil {
id, ok1 := oldConfig["cluster-id"].(string)
secret, ok2 := oldConfig["cluster-secret"].(string)
- if ok1 && ok2 {
- cfg.Clusters = map[string]ClusterItem{
+ publicHost, ok3 := oldConfig["public-host"].(string)
+ if ok1 && ok2 && ok3 {
+ cfg.Clusters = map[string]config.ClusterOptions{
"main": {
- Id: id,
- Secret: secret,
+ Id: id,
+ Secret: secret,
+ PublicHosts: []string{publicHost},
},
}
}
@@ -67,7 +77,7 @@ func readAndRewriteConfig() (cfg *config.Config, err error) {
log.TrErrorf("error.config.read.failed", err)
os.Exit(1)
}
- log.TrError("error.config.not.exists")
+ log.TrErrorf("error.config.not.exists")
notexists = true
} else {
migrateConfig(data, cfg)
@@ -76,15 +86,18 @@ func readAndRewriteConfig() (cfg *config.Config, err error) {
os.Exit(1)
}
if len(cfg.Clusters) == 0 {
- cfg.Clusters = map[string]ClusterItem{
+ cfg.Clusters = map[string]config.ClusterOptions{
"main": {
- Id: "${CLUSTER_ID}",
- Secret: "${CLUSTER_SECRET}",
+ Id: "${CLUSTER_ID}",
+ Secret: "${CLUSTER_SECRET}",
+ PublicHosts: []string{},
+ Server: DefaultBMCLAPIServer,
+ SkipSignatureCheck: false,
},
}
}
if len(cfg.Certificates) == 0 {
- cfg.Certificates = []CertificateConfig{
+ cfg.Certificates = []config.CertificateConfig{
{
Cert: "/path/to/cert.pem",
Key: "/path/to/key.pem",
@@ -123,12 +136,6 @@ func readAndRewriteConfig() (cfg *config.Config, err error) {
os.Exit(1)
}
ids[s.Id] = i
- if s.Cluster != "" && s.Cluster != "-" {
- if _, ok := cfg.Clusters[s.Cluster]; !ok {
- log.Errorf("Storage %q is trying to connect to a not exists cluster %q.", s.Id, s.Cluster)
- os.Exit(1)
- }
- }
}
}
@@ -173,7 +180,8 @@ func readAndRewriteConfig() (cfg *config.Config, err error) {
os.Exit(1)
}
if notexists {
- log.TrError("error.config.created")
+ log.TrErrorf("error.config.created")
+ return nil, errors.New("Please edit the config before continue!")
}
return
}
diff --git a/config/config.go b/config/config.go
index 5c843c9d..ec18ca2f 100644
--- a/config/config.go
+++ b/config/config.go
@@ -36,7 +36,6 @@ type Config struct {
PublicPort uint16 `yaml:"public-port"`
Host string `yaml:"host"`
Port uint16 `yaml:"port"`
- Byoc bool `yaml:"byoc"`
UseCert bool `yaml:"use-cert"`
TrustedXForwardedFor bool `yaml:"trusted-x-forwarded-for"`
@@ -65,7 +64,7 @@ type Config struct {
Advanced AdvancedConfig `yaml:"advanced"`
}
-func (cfg *Config) applyWebManifest(manifest map[string]any) {
+func (cfg *Config) ApplyWebManifest(manifest map[string]any) {
if cfg.Dashboard.Enable {
manifest["name"] = cfg.Dashboard.PwaName
manifest["short_name"] = cfg.Dashboard.PwaShortName
@@ -79,7 +78,6 @@ func NewDefaultConfig() *Config {
PublicPort: 0,
Host: "0.0.0.0",
Port: 4000,
- Byoc: false,
TrustedXForwardedFor: false,
OnlyGcWhenStart: false,
diff --git a/config/server.go b/config/server.go
index 32963a58..a1763a8f 100644
--- a/config/server.go
+++ b/config/server.go
@@ -34,15 +34,16 @@ import (
type ClusterOptions struct {
Id string `json:"id" yaml:"id"`
Secret string `json:"secret" yaml:"secret"`
+ Byoc bool `json:"byoc"`
PublicHosts []string `json:"public-hosts" yaml:"public-hosts"`
- Prefix string `json:"prefix" yaml:"prefix"`
+ Server string `json:"server" yaml:"server"`
SkipSignatureCheck bool `json:"skip-signature-check" yaml:"skip-signature-check"`
+ Storages []string `json:"storages" yaml:"storages"`
}
type ClusterGeneralConfig struct {
PublicHost string `json:"public-host"`
PublicPort uint16 `json:"public-port"`
- Byoc bool `json:"byoc"`
NoFastEnable bool `json:"no-fast-enable"`
MaxReconnectCount int `json:"max-reconnect-count"`
}
@@ -77,6 +78,10 @@ type CacheConfig struct {
newCache func() cache.Cache `yaml:"-"`
}
+func (c *CacheConfig) NewCache() cache.Cache {
+ return c.newCache()
+}
+
func (c *CacheConfig) UnmarshalYAML(n *yaml.Node) (err error) {
var cfg struct {
Type string `yaml:"type"`
@@ -148,3 +153,11 @@ func (c *TunnelConfig) UnmarshalYAML(n *yaml.Node) (err error) {
}
return
}
+
+func (c *TunnelConfig) MatchTunnelOutput(line []byte) (host, port []byte, ok bool) {
+ res := c.outputRegex.FindSubmatch(line)
+ if res == nil {
+ return
+ }
+ return res[c.hostOut], res[c.portOut], true
+}
diff --git a/dashboard.go b/dashboard.go
index f8fbcaef..3bcd689d 100644
--- a/dashboard.go
+++ b/dashboard.go
@@ -60,13 +60,14 @@ var dsbManifest = func() (dsbManifest map[string]any) {
return
}()
-func (r *Runner) serveDashboard(rw http.ResponseWriter, req *http.Request, pth string) {
+func (r *Runner) serveDashboard(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet && req.Method != http.MethodHead {
rw.Header().Set("Allow", http.MethodGet+", "+http.MethodHead)
http.Error(rw, "405 Method Not Allowed", http.StatusMethodNotAllowed)
return
}
acceptEncoding := utils.SplitCSV(req.Header.Get("Accept-Encoding"))
+ pth := strings.TrimPrefix(req.URL.Path, "/")
switch pth {
case "":
break
diff --git a/handler.go b/handler.go
index bedc3fb6..97e4e85a 100644
--- a/handler.go
+++ b/handler.go
@@ -24,16 +24,11 @@ import (
"context"
"crypto"
_ "embed"
- "encoding/base64"
"encoding/hex"
"encoding/json"
- "errors"
"fmt"
- "io"
"net"
"net/http"
- "net/textproto"
- "os"
"strconv"
"strings"
"time"
@@ -43,9 +38,9 @@ import (
"github.com/LiterMC/go-openbmclapi/api"
"github.com/LiterMC/go-openbmclapi/api/v0"
"github.com/LiterMC/go-openbmclapi/internal/build"
+ "github.com/LiterMC/go-openbmclapi/internal/gosrc"
"github.com/LiterMC/go-openbmclapi/limited"
"github.com/LiterMC/go-openbmclapi/log"
- "github.com/LiterMC/go-openbmclapi/storage"
"github.com/LiterMC/go-openbmclapi/utils"
)
@@ -108,13 +103,13 @@ var wsUpgrader = &websocket.Upgrader{
}
func (r *Runner) GetHandler() http.Handler {
- r.apiRateLimiter = limited.NewAPIRateMiddleWare(RealAddrCtxKey, loggedUserKey)
- r.apiRateLimiter.SetAnonymousRateLimit(r.RateLimit.Anonymous)
- r.apiRateLimiter.SetLoggedRateLimit(r.RateLimit.Logged)
+ r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */)
+ r.apiRateLimiter.SetAnonymousRateLimit(r.Config.RateLimit.Anonymous)
+ r.apiRateLimiter.SetLoggedRateLimit(r.Config.RateLimit.Logged)
r.handlerAPIv0 = http.StripPrefix("/api/v0", v0.NewHandler(wsUpgrader))
- r.hijackHandler = http.StripPrefix("/bmclapi", r.hijackProxy)
+ r.hijackHandler = http.StripPrefix("/bmclapi", r.hijacker)
- handler := utils.NewHttpMiddleWareHandler(r)
+ handler := utils.NewHttpMiddleWareHandler((http.HandlerFunc)(r.serveHTTP))
// recover panic and log it
handler.UseFunc(func(rw http.ResponseWriter, req *http.Request, next http.Handler) {
defer log.RecoverPanic(func(any) {
@@ -124,67 +119,57 @@ func (r *Runner) GetHandler() http.Handler {
})
handler.Use(r.apiRateLimiter)
- handler.Use(r.getRecordMiddleWare())
+ handler.UseFunc(r.recordMiddleWare)
return handler
}
-func (r *Runner) getRecordMiddleWare() utils.MiddleWareFunc {
- type record struct {
- used float64
- bytes float64
- ua string
- skipUA bool
+func (r *Runner) recordMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) {
+ ua := req.UserAgent()
+ var addr string
+ if r.Config.TrustedXForwardedFor {
+ // X-Forwarded-For: , ,
+ adr, _, _ := strings.Cut(req.Header.Get("X-Forwarded-For"), ",")
+ addr = strings.TrimSpace(adr)
}
- recordCh := make(chan record, 1024)
-
- return func(rw http.ResponseWriter, req *http.Request, next http.Handler) {
- ua := req.UserAgent()
- var addr string
- if config.TrustedXForwardedFor {
- // X-Forwarded-For: , ,
- adr, _, _ := strings.Cut(req.Header.Get("X-Forwarded-For"), ",")
- addr = strings.TrimSpace(adr)
- }
- if addr == "" {
- addr, _, _ = net.SplitHostPort(req.RemoteAddr)
- }
- srw := utils.WrapAsStatusResponseWriter(rw)
- start := time.Now()
+ if addr == "" {
+ addr, _, _ = net.SplitHostPort(req.RemoteAddr)
+ }
+ srw := utils.WrapAsStatusResponseWriter(rw)
+ start := time.Now()
- log.LogAccess(log.LevelDebug, &preAccessRecord{
- Type: "pre-access",
- Time: start,
- Addr: addr,
- Method: req.Method,
- URI: req.RequestURI,
- UA: ua,
- })
+ log.LogAccess(log.LevelDebug, &preAccessRecord{
+ Type: "pre-access",
+ Time: start,
+ Addr: addr,
+ Method: req.Method,
+ URI: req.RequestURI,
+ UA: ua,
+ })
- extraInfoMap := make(map[string]any)
- ctx := req.Context()
- ctx = context.WithValue(ctx, RealAddrCtxKey, addr)
- ctx = context.WithValue(ctx, RealPathCtxKey, req.URL.Path)
- ctx = context.WithValue(ctx, AccessLogExtraCtxKey, extraInfoMap)
- req = req.WithContext(ctx)
- next.ServeHTTP(srw, req)
+ extraInfoMap := make(map[string]any)
+ ctx := req.Context()
+ ctx = context.WithValue(ctx, api.RealAddrCtxKey, addr)
+ ctx = context.WithValue(ctx, api.RealPathCtxKey, req.URL.Path)
+ ctx = context.WithValue(ctx, api.AccessLogExtraCtxKey, extraInfoMap)
+ req = req.WithContext(ctx)
+ next.ServeHTTP(srw, req)
- used := time.Since(start)
- accRec := &accessRecord{
- Type: "access",
- Status: srw.Status,
- Used: used,
- Content: srw.Wrote,
- Addr: addr,
- Proto: req.Proto,
- Method: req.Method,
- URI: req.RequestURI,
- UA: ua,
- }
- if len(extraInfoMap) > 0 {
- accRec.Extra = extraInfoMap
- }
- log.LogAccess(log.LevelInfo, accRec)
+ used := time.Since(start)
+ accRec := &accessRecord{
+ Type: "access",
+ Status: srw.Status,
+ Used: used,
+ Content: srw.Wrote,
+ Addr: addr,
+ Proto: req.Proto,
+ Method: req.Method,
+ URI: req.RequestURI,
+ UA: ua,
+ }
+ if len(extraInfoMap) > 0 {
+ accRec.Extra = extraInfoMap
}
+ log.LogAccess(log.LevelInfo, accRec)
}
var emptyHashes = func() (hashes map[string]struct{}) {
@@ -202,11 +187,11 @@ var emptyHashes = func() (hashes map[string]struct{}) {
//go:embed robots.txt
var robotTxtContent string
-func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+func (r *Runner) serveHTTP(rw http.ResponseWriter, req *http.Request) {
method := req.Method
u := req.URL
- rw.Header().Set("X-Powered-By", HeaderXPoweredBy)
+ rw.Header().Set("X-Powered-By", build.HeaderXPoweredBy)
rawpath := u.EscapedPath()
switch {
@@ -249,7 +234,7 @@ func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
for _, cr := range r.clusters {
if cr.AcceptHost(req.Host) {
- cr.HandleFile(rw, req, hash)
+ cr.HandleMeasure(rw, req, size)
return
}
}
@@ -264,23 +249,24 @@ func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
case "v0":
r.handlerAPIv0.ServeHTTP(rw, req)
return
- case "v1":
- r.handlerAPIv1.ServeHTTP(rw, req)
- return
+ // case "v1":
+ // r.handlerAPIv1.ServeHTTP(rw, req)
+ // return
}
case rawpath == "/" || rawpath == "/dashboard":
http.Redirect(rw, req, "/dashboard/", http.StatusFound)
return
case strings.HasPrefix(rawpath, "/dashboard/"):
- if !r.DashboardEnabled {
+ if !r.Config.Dashboard.Enable {
http.NotFound(rw, req)
return
}
- pth := rawpath[len("/dashboard/"):]
- r.serveDashboard(rw, req, pth)
+ req2 := gosrc.RequestStripPrefix(req, "/dashboard")
+ r.serveDashboard(rw, req2)
return
case strings.HasPrefix(rawpath, "/bmclapi/"):
- r.hijackHandler.ServeHTTP(rw, req)
+ req2 := gosrc.RequestStripPrefix(req, "/bmclapi")
+ r.hijackHandler.ServeHTTP(rw, req2)
return
}
http.NotFound(rw, req)
diff --git a/internal/gosrc/httpstrip.go b/internal/gosrc/httpstrip.go
new file mode 100644
index 00000000..8dcc129a
--- /dev/null
+++ b/internal/gosrc/httpstrip.go
@@ -0,0 +1,26 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gosrc
+
+import (
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+func RequestStripPrefix(r *http.Request, prefix string) *http.Request {
+ p, ok := strings.CutPrefix(r.URL.Path, prefix)
+ rp, ok2 := strings.CutPrefix(r.URL.RawPath, prefix)
+ if ok && (ok2 || r.URL.RawPath == "") {
+ r2 := new(http.Request)
+ *r2 = *r
+ r2.URL = new(url.URL)
+ *r2.URL = *r.URL
+ r2.URL.Path = p
+ r2.URL.RawPath = rp
+ return r2
+ }
+ return nil
+}
diff --git a/main.go b/main.go
index fff4384d..55f42dc4 100644
--- a/main.go
+++ b/main.go
@@ -36,6 +36,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"sync/atomic"
"syscall"
"time"
@@ -44,6 +45,7 @@ import (
doh "github.com/libp2p/go-doh-resolver"
+ "github.com/LiterMC/go-openbmclapi/api/bmclapi"
"github.com/LiterMC/go-openbmclapi/cluster"
"github.com/LiterMC/go-openbmclapi/config"
"github.com/LiterMC/go-openbmclapi/database"
@@ -51,7 +53,9 @@ import (
"github.com/LiterMC/go-openbmclapi/lang"
"github.com/LiterMC/go-openbmclapi/limited"
"github.com/LiterMC/go-openbmclapi/log"
+ "github.com/LiterMC/go-openbmclapi/storage"
subcmds "github.com/LiterMC/go-openbmclapi/sub_commands"
+ "github.com/LiterMC/go-openbmclapi/utils"
_ "github.com/LiterMC/go-openbmclapi/lang/en"
_ "github.com/LiterMC/go-openbmclapi/lang/zh"
@@ -124,18 +128,8 @@ func main() {
} else {
r.Config = config
}
- if r.Config.Advanced.DebugLog {
- log.SetLevel(log.LevelDebug)
- } else {
- log.SetLevel(log.LevelInfo)
- }
- if r.Config.NoAccessLog {
- log.SetAccessLogSlots(-1)
- } else {
- log.SetAccessLogSlots(r.Config.AccessLogSlots)
- }
- r.Config.applyWebManifest(dsbManifest)
+ r.SetupLogger(ctx)
log.TrInfof("program.starting", build.ClusterVersion, build.BuildVersion)
@@ -143,13 +137,37 @@ func main() {
r.StartTunneler()
}
r.InitServer()
+ r.StartServer(ctx)
+
r.InitClusters(ctx)
go func(ctx context.Context) {
defer log.RecordPanic()
- if !r.cluster.Connect(ctx) {
- osExit(CodeClientOrServerError)
+ var wg sync.WaitGroup
+ errs := make([]error, len(r.clusters))
+ {
+ i := 0
+ for _, cr := range r.clusters {
+ i++
+ go func(i int, cr *cluster.Cluster) {
+ defer wg.Done()
+ errs[i] = cr.Connect(ctx)
+ }(i, cr)
+ }
+ }
+ wg.Wait()
+ if ctx.Err() != nil {
+ return
+ }
+
+ {
+ var err error
+ r.tlsConfig, err = r.PatchTLSWithClusterCert(ctx, r.tlsConfig)
+ if err != nil {
+ return
+ }
+ r.listener.TLSConfig.Store(r.tlsConfig)
}
firstSyncDone := make(chan struct{}, 0)
@@ -159,30 +177,11 @@ func main() {
r.InitSynchronizer(ctx)
}()
- listener := r.CreateHTTPServerListener(ctx)
- go func(listener net.Listener) {
- defer listener.Close()
- if err := r.clusterSvr.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
- log.Error("Error when serving:", err)
- os.Exit(1)
- }
- }(listener)
-
- var publicHost string
- if len(r.publicHosts) == 0 {
- publicHost = config.PublicHost
- } else {
- publicHost = r.publicHosts[0]
- }
- if !config.Tunneler.Enable {
+ if !r.Config.Tunneler.Enable {
strPort := strconv.Itoa((int)(r.getPublicPort()))
- log.TrInfof("info.server.public.at", net.JoinHostPort(publicHost, strPort), r.clusterSvr.Addr, r.getCertCount())
- if len(r.publicHosts) > 1 {
- log.TrInfof("info.server.alternative.hosts")
- for _, h := range r.publicHosts[1:] {
- log.Infof("\t- https://%s", net.JoinHostPort(h, strPort))
- }
- }
+ pubAddr := net.JoinHostPort(r.Config.PublicHost, strPort)
+ localAddr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port)))
+ log.TrInfof("info.server.public.at", pubAddr, localAddr, r.getCertCount())
}
log.TrInfof("info.wait.first.sync")
@@ -192,14 +191,14 @@ func main() {
return
}
- r.EnableCluster(ctx)
+ // r.EnableCluster(ctx)
}(ctx)
- code := r.DoSignals(cancel)
+ code := r.ListenSignals(ctx, cancel)
if code != 0 {
log.TrErrorf("program.exited", code)
log.TrErrorf("error.exit.please.read.faq")
- if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows {
+ if runtime.GOOS == "windows" && !r.Config.Advanced.DoNotOpenFAQOnWindows {
// log.TrWarnf("warn.exit.detected.windows.open.browser")
// cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq")
// cmd.Start()
@@ -212,12 +211,21 @@ func main() {
type Runner struct {
Config *config.Config
- clusters map[string]*Cluster
- server *http.Server
+ clusters map[string]*cluster.Cluster
+ apiRateLimiter *limited.APIRateMiddleWare
+ storageManager *storage.Manager
+ statManager *cluster.StatManager
+ hijacker *bmclapi.HjProxy
+ database database.DB
+
+ server *http.Server
+ handlerAPIv0 http.Handler
+ hijackHandler http.Handler
tlsConfig *tls.Config
- listener net.Listener
- publicHosts []string
+ publicHost string
+ publicPort uint16
+ listener *utils.HTTPTLSListener
reloading atomic.Bool
updating atomic.Bool
@@ -225,8 +233,8 @@ type Runner struct {
}
func (r *Runner) getPublicPort() uint16 {
- if r.Config.PublicPort > 0 {
- return r.Config.PublicPort
+ if r.publicPort > 0 {
+ return r.publicPort
}
return r.Config.Port
}
@@ -238,7 +246,80 @@ func (r *Runner) getCertCount() int {
return len(r.tlsConfig.Certificates)
}
-func (r *Runner) DoSignals(cancel context.CancelFunc) int {
+func (r *Runner) InitServer() {
+ r.server = &http.Server{
+ ReadTimeout: 10 * time.Second,
+ IdleTimeout: 5 * time.Second,
+ Handler: r.GetHandler(),
+ ErrorLog: log.ProxiedStdLog,
+ }
+}
+
+// StartServer will start the HTTP server
+// If a server is already running on an old listener, the listener will be closed.
+func (r *Runner) StartServer(ctx context.Context) error {
+ htListener, err := r.CreateHTTPListener(ctx)
+ if err != nil {
+ return err
+ }
+ if r.listener != nil {
+ r.listener.Close()
+ }
+ r.listener = htListener
+ go func() {
+ defer htListener.Close()
+ if err := r.server.Serve(htListener); !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) {
+ log.Error("Error when serving:", err)
+ os.Exit(1)
+ }
+ }()
+ return nil
+}
+
+func (r *Runner) GetClusterGeneralConfig() config.ClusterGeneralConfig {
+ return config.ClusterGeneralConfig{
+ PublicHost: r.publicHost,
+ PublicPort: r.getPublicPort(),
+ NoFastEnable: r.Config.Advanced.NoFastEnable,
+ MaxReconnectCount: r.Config.MaxReconnectCount,
+ }
+}
+
+func (r *Runner) InitClusters(ctx context.Context) {
+ // var (
+ // dialer *net.Dialer
+ // cache = r.Config.Cache.NewCache()
+ // )
+
+ _ = doh.NewResolver // TODO: use doh resolver
+
+ r.clusters = make(map[string]*cluster.Cluster)
+ gcfg := r.GetClusterGeneralConfig()
+ for name, opts := range r.Config.Clusters {
+ cr := cluster.NewCluster(opts, gcfg, r.storageManager, r.statManager)
+ if err := cr.Init(ctx); err != nil {
+ log.TrErrorf("error.init.failed", err)
+ } else {
+ r.clusters[name] = cr
+ }
+ }
+
+ // r.cluster = NewCluster(ctx,
+ // ClusterServerURL,
+ // baseDir,
+ // config.PublicHost, r.getPublicPort(),
+ // config.ClusterId, config.ClusterSecret,
+ // config.Byoc, dialer,
+ // config.Storages,
+ // cache,
+ // )
+ // if err := r.cluster.Init(ctx); err != nil {
+ // log.TrErrorf("error.init.failed"), err)
+ // os.Exit(1)
+ // }
+}
+
+func (r *Runner) ListenSignals(ctx context.Context, cancel context.CancelFunc) int {
signalCh := make(chan os.Signal, 1)
log.Debugf("Receiving signals")
signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
@@ -290,7 +371,7 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int {
}
}
case syscall.SIGHUP:
- go r.ReloadConfig()
+ go r.ReloadConfig(ctx)
default:
cancel()
if forceStop == nil {
@@ -312,7 +393,7 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int {
return 0
}
-func (r *Runner) ReloadConfig() {
+func (r *Runner) ReloadConfig(ctx context.Context) {
if r.reloading.CompareAndSwap(false, true) {
log.Error("Config is already reloading!")
return
@@ -321,10 +402,54 @@ func (r *Runner) ReloadConfig() {
config, err := readAndRewriteConfig()
if err != nil {
- log.Errorf("Config error: %s", err)
+ log.Errorf("Config error: %v", err)
} else {
- r.Config = config
+ if err := r.updateConfig(ctx, config); err != nil {
+ log.Errorf("Error when reloading config: %v", err)
+ }
+ }
+}
+
+func (r *Runner) updateConfig(ctx context.Context, newConfig *config.Config) error {
+ oldConfig := r.Config
+ reloadProcesses := make([]func(context.Context) error, 0, 8)
+
+ if newConfig.LogSlots != oldConfig.LogSlots || newConfig.NoAccessLog != oldConfig.NoAccessLog || newConfig.AccessLogSlots != oldConfig.AccessLogSlots || newConfig.Advanced.DebugLog != oldConfig.Advanced.DebugLog {
+ reloadProcesses = append(reloadProcesses, r.SetupLogger)
+ }
+ if newConfig.Host != oldConfig.Host || newConfig.Port != oldConfig.Port {
+ reloadProcesses = append(reloadProcesses, r.StartServer)
+ }
+ if newConfig.PublicHost != oldConfig.PublicHost || newConfig.PublicPort != oldConfig.PublicPort || newConfig.Advanced.NoFastEnable != oldConfig.Advanced.NoFastEnable || newConfig.MaxReconnectCount != oldConfig.MaxReconnectCount {
+ reloadProcesses = append(reloadProcesses, r.updateClustersWithGeneralConfig)
+ }
+
+ r.Config = newConfig
+ r.publicHost = r.Config.PublicHost
+ r.publicPort = r.Config.PublicPort
+ for _, proc := range reloadProcesses {
+ if err := proc(ctx); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (r *Runner) SetupLogger(ctx context.Context) error {
+ if r.Config.Advanced.DebugLog {
+ log.SetLevel(log.LevelDebug)
+ } else {
+ log.SetLevel(log.LevelInfo)
+ }
+ log.SetLogSlots(r.Config.LogSlots)
+ if r.Config.NoAccessLog {
+ log.SetAccessLogSlots(-1)
+ } else {
+ log.SetAccessLogSlots(r.Config.AccessLogSlots)
}
+
+ r.Config.ApplyWebManifest(dsbManifest)
+ return nil
}
func (r *Runner) StopServer(ctx context.Context) {
@@ -346,6 +471,8 @@ func (r *Runner) StopServer(ctx context.Context) {
wg.Wait()
log.TrWarnf("warn.httpserver.closing")
r.server.Shutdown(shutCtx)
+ r.listener.Close()
+ r.listener = nil
}()
select {
case <-shutDone:
@@ -355,41 +482,8 @@ func (r *Runner) StopServer(ctx context.Context) {
log.TrWarnf("warn.server.closed")
}
-func (r *Runner) InitServer() {
- r.server = &http.Server{
- Addr: fmt.Sprintf("%s:%d", d.Config.Host, d.Config.Port),
- ReadTimeout: 10 * time.Second,
- IdleTimeout: 5 * time.Second,
- Handler: r,
- ErrorLog: log.ProxiedStdLog,
- }
-}
-
-func (r *Runner) InitClusters(ctx context.Context) {
- var (
- dialer *net.Dialer
- cache = r.Config.Cache.newCache()
- )
-
- _ = doh.NewResolver // TODO: use doh resolver
-
- r.cluster = NewCluster(ctx,
- ClusterServerURL,
- baseDir,
- config.PublicHost, r.getPublicPort(),
- config.ClusterId, config.ClusterSecret,
- config.Byoc, dialer,
- config.Storages,
- cache,
- )
- if err := r.cluster.Init(ctx); err != nil {
- log.Errorf(Tr("error.init.failed"), err)
- os.Exit(1)
- }
-}
-
func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, oldfileset map[string]int64) {
- if !r.hijacker.Enabled {
+ if !r.Config.Hijack.Enable {
return
}
if !r.updating.CompareAndSwap(false, true) {
@@ -409,7 +503,7 @@ func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, ol
sem.Acquire()
go func(rec database.FileRecord) {
defer sem.Release()
- r.cluster.database.SetFileRecord(rec)
+ r.database.SetFileRecord(rec)
}(database.FileRecord{
Path: f.Path,
Hash: f.Hash,
@@ -421,11 +515,11 @@ func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, ol
}
func (r *Runner) InitSynchronizer(ctx context.Context) {
- fileMap := make(map[string]*StorageFileInfo)
+ fileMap := make(map[string]*cluster.StorageFileInfo)
for _, cr := range r.clusters {
- log.Info(Tr("info.filelist.fetching"), cr.ID())
+ log.TrInfof("info.filelist.fetching", cr.ID())
if err := cr.GetFileList(ctx, fileMap, true); err != nil {
- log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err)
+ log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err)
if errors.Is(err, context.Canceled) {
return
}
@@ -433,33 +527,34 @@ func (r *Runner) InitSynchronizer(ctx context.Context) {
}
checkCount := -1
- heavyCheck := !config.Advanced.NoHeavyCheck
- heavyCheckInterval := config.Advanced.HeavyCheckInterval
+ heavyCheck := !r.Config.Advanced.NoHeavyCheck
+ heavyCheckInterval := r.Config.Advanced.HeavyCheckInterval
if heavyCheckInterval <= 0 {
heavyCheck = false
}
- if !config.Advanced.SkipFirstSync {
- if !r.cluster.SyncFiles(ctx, fileMap, false) {
- return
- }
- go r.UpdateFileRecords(fileMap, nil)
+ // if !r.Config.Advanced.SkipFirstSync {
+ // if !r.cluster.SyncFiles(ctx, fileMap, false) {
+ // return
+ // }
+ // go r.UpdateFileRecords(fileMap, nil)
- if !config.Advanced.NoGC {
- go r.cluster.Gc()
- }
- } else if fl != nil {
- if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil {
- return
- }
- }
+ // if !r.Config.Advanced.NoGC {
+ // go r.cluster.Gc()
+ // }
+ // } else
+ // if fl != nil {
+ // if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil {
+ // return
+ // }
+ // }
createInterval(ctx, func() {
- fileMap := make(map[string]*StorageFileInfo)
+ fileMap := make(map[string]*cluster.StorageFileInfo)
for _, cr := range r.clusters {
- log.Info(Tr("info.filelist.fetching"), cr.ID())
+ log.TrInfof("info.filelist.fetching", cr.ID())
if err := cr.GetFileList(ctx, fileMap, false); err != nil {
- log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err)
+ log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err)
return
}
}
@@ -472,104 +567,122 @@ func (r *Runner) InitSynchronizer(ctx context.Context) {
oldfileset := r.cluster.CloneFileset()
if r.cluster.SyncFiles(ctx, fl, heavyCheck && checkCount == 0) {
go r.UpdateFileRecords(fl, oldfileset)
- if !config.Advanced.NoGC && !config.OnlyGcWhenStart {
+ if !r.Config.Advanced.NoGC && !r.Config.OnlyGcWhenStart {
go r.cluster.Gc()
}
}
- }, (time.Duration)(config.SyncInterval)*time.Minute)
+ }, (time.Duration)(r.Config.SyncInterval)*time.Minute)
}
-func (r *Runner) CreateHTTPServerListener(ctx context.Context) (listener net.Listener) {
- listener, err := net.Listen("tcp", r.Addr)
+func (r *Runner) CreateHTTPListener(ctx context.Context) (*utils.HTTPTLSListener, error) {
+ addr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port)))
+ listener, err := net.Listen("tcp", addr)
if err != nil {
- log.Errorf(Tr("error.address.listen.failed"), r.Addr, err)
- osExit(CodeEnvironmentError)
+ log.TrErrorf("error.address.listen.failed", addr, err)
+ return nil, err
}
if r.Config.ServeLimit.Enable {
- limted := limited.NewLimitedListener(listener, config.ServeLimit.MaxConn, 0, config.ServeLimit.UploadRate*1024)
+ limted := limited.NewLimitedListener(listener, r.Config.ServeLimit.MaxConn, 0, r.Config.ServeLimit.UploadRate*1024)
limted.SetMinWriteRate(1024)
listener = limted
}
- tlsConfig := r.GenerateTLSConfig(ctx)
- r.publicHosts = make([]string, 0, 2)
- if tlsConfig != nil {
- for _, cert := range tlsConfig.Certificates {
- if h, err := parseCertCommonName(cert.Certificate[0]); err == nil {
- r.publicHosts = append(r.publicHosts, strings.ToLower(h))
- }
+ if r.Config.UseCert {
+ var err error
+ r.tlsConfig, err = r.GenerateTLSConfig()
+ if err != nil {
+ log.Errorf("Failed to generate TLS config: %v", err)
+ return nil, err
}
- listener = utils.NewHttpTLSListener(listener, tlsConfig, r.publicHosts, r.getPublicPort())
}
- r.listener = listener
- return
+ return utils.NewHttpTLSListener(listener, r.tlsConfig), nil
}
-func (r *Runner) GenerateTLSConfig(ctx context.Context) (tlsConfig *tls.Config) {
- if config.UseCert {
- if len(config.Certificates) == 0 {
- log.Error(Tr("error.cert.not.set"))
- os.Exit(1)
+func (r *Runner) GenerateTLSConfig() (*tls.Config, error) {
+ if len(r.Config.Certificates) == 0 {
+ log.TrErrorf("error.cert.not.set")
+ return nil, errors.New("No certificate is defined")
+ }
+ tlsConfig := new(tls.Config)
+ tlsConfig.Certificates = make([]tls.Certificate, len(r.Config.Certificates))
+ for i, c := range r.Config.Certificates {
+ var err error
+ tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key)
+ if err != nil {
+ log.TrErrorf("error.cert.parse.failed", i, err)
+ return nil, err
}
- tlsConfig = new(tls.Config)
- tlsConfig.Certificates = make([]tls.Certificate, len(config.Certificates))
- for i, c := range config.Certificates {
- var err error
- tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key)
- if err != nil {
- log.Errorf(Tr("error.cert.parse.failed"), i, err)
- os.Exit(1)
- }
+ }
+ return tlsConfig, nil
+}
+
+func (r *Runner) PatchTLSWithClusterCert(ctx context.Context, tlsConfig *tls.Config) (*tls.Config, error) {
+ certs := make([]tls.Certificate, 0)
+ for _, cr := range r.clusters {
+ if cr.Options().Byoc {
+ continue
+ }
+ log.TrInfof("info.cert.requesting", cr.ID())
+ tctx, cancel := context.WithTimeout(ctx, time.Minute*10)
+ pair, err := cr.RequestCert(tctx)
+ cancel()
+ if err != nil {
+ log.TrErrorf("error.cert.request.failed", err)
+ continue
}
+ cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key))
+ if err != nil {
+ log.TrErrorf("error.cert.requested.parse.failed", err)
+ continue
+ }
+ certs = append(certs, cert)
+ certHost, _ := parseCertCommonName(cert.Certificate[0])
+ log.TrInfof("info.cert.requested", certHost)
}
- if !config.Byoc {
- for _, cr := range r.clusters {
- log.Info(Tr("info.cert.requesting"), cr.ID())
- tctx, cancel := context.WithTimeout(ctx, time.Minute*10)
- pair, err := cr.RequestCert(tctx)
- cancel()
- if err != nil {
- log.Errorf(Tr("error.cert.request.failed"), err)
- os.Exit(2)
- }
- if tlsConfig == nil {
- tlsConfig = new(tls.Config)
- }
- var cert tls.Certificate
- cert, err = tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key))
- if err != nil {
- log.Errorf(Tr("error.cert.requested.parse.failed"), err)
- os.Exit(2)
- }
- tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
- certHost, _ := parseCertCommonName(cert.Certificate[0])
- log.Infof(Tr("info.cert.requested"), certHost)
+ if len(certs) == 0 {
+ if tlsConfig == nil {
+ tlsConfig = new(tls.Config)
+ } else {
+ tlsConfig = tlsConfig.Clone()
}
+ tlsConfig.Certificates = append(tlsConfig.Certificates, certs...)
}
- r.tlsConfig = tlsConfig
- return
+ return tlsConfig, nil
}
-func (r *Runner) EnableCluster(ctx context.Context) {
- if config.Advanced.WaitBeforeEnable > 0 {
- select {
- case <-time.After(time.Second * (time.Duration)(config.Advanced.WaitBeforeEnable)):
- case <-ctx.Done():
- return
- }
+// updateClustersWithGeneralConfig will re-enable all clusters with latest general config
+func (r *Runner) updateClustersWithGeneralConfig(ctx context.Context) error {
+ gcfg := r.GetClusterGeneralConfig()
+ var wg sync.WaitGroup
+ for _, cr := range r.clusters {
+ wg.Add(1)
+ go func(cr *cluster.Cluster) {
+ defer wg.Done()
+ cr.Disable(ctx)
+ *cr.GeneralConfig() = gcfg
+ if err := cr.Enable(ctx); err != nil {
+ log.TrErrorf("error.cluster.enable.failed", cr.ID(), err)
+ return
+ }
+ }(cr)
}
+ wg.Wait()
+ return nil
+}
- if config.Tunneler.Enable {
- r.enableClusterByTunnel(ctx)
- } else {
- if err := r.cluster.Enable(ctx); err != nil {
- log.Errorf(Tr("error.cluster.enable.failed"), err)
- if ctx.Err() != nil {
+func (r *Runner) EnableClusterAll(ctx context.Context) {
+ var wg sync.WaitGroup
+ for _, cr := range r.clusters {
+ wg.Add(1)
+ go func(cr *cluster.Cluster) {
+ defer wg.Done()
+ if err := cr.Enable(ctx); err != nil {
+ log.TrErrorf("error.cluster.enable.failed", cr.ID(), err)
return
}
- osExit(CodeServerOrEnvionmentError)
- }
+ }(cr)
}
+ wg.Wait()
}
func (r *Runner) StartTunneler() {
@@ -597,8 +710,8 @@ func (r *Runner) StartTunneler() {
}
func (r *Runner) RunTunneler(ctx context.Context) {
- cmd := exec.CommandContext(ctx, config.Tunneler.TunnelProg)
- log.Infof(Tr("info.tunnel.running"), cmd.String())
+ cmd := exec.CommandContext(ctx, r.Config.Tunneler.TunnelProg)
+ log.TrInfof("info.tunnel.running", cmd.String())
var (
cmdOut, cmdErr io.ReadCloser
err error
@@ -606,15 +719,15 @@ func (r *Runner) RunTunneler(ctx context.Context) {
cmd.Env = append(os.Environ(),
"CLUSTER_PORT="+strconv.Itoa((int)(r.Config.Port)))
if cmdOut, err = cmd.StdoutPipe(); err != nil {
- log.Errorf(Tr("error.tunnel.command.prepare.failed"), err)
+ log.TrErrorf("error.tunnel.command.prepare.failed", err)
os.Exit(1)
}
if cmdErr, err = cmd.StderrPipe(); err != nil {
- log.Errorf(Tr("error.tunnel.command.prepare.failed"), err)
+ log.TrErrorf("error.tunnel.command.prepare.failed", err)
os.Exit(1)
}
if err = cmd.Start(); err != nil {
- log.Errorf(Tr("error.tunnel.command.prepare.failed"), err)
+ log.TrErrorf("error.tunnel.command.prepare.failed", err)
os.Exit(1)
}
type addrOut struct {
@@ -623,11 +736,10 @@ func (r *Runner) RunTunneler(ctx context.Context) {
}
detectedCh := make(chan addrOut, 1)
onLog := func(line []byte) {
- res := config.Tunneler.outputRegex.FindSubmatch(line)
- if res == nil {
+ tunnelHost, tunnelPort, ok := r.Config.Tunneler.MatchTunnelOutput(line)
+ if !ok {
return
}
- tunnelHost, tunnelPort := res[config.Tunneler.hostOut], res[config.Tunneler.portOut]
if len(tunnelHost) > 0 && tunnelHost[0] == '[' && tunnelHost[len(tunnelHost)-1] == ']' { // a IPv6 with port []:
tunnelHost = tunnelHost[1 : len(tunnelHost)-1]
}
@@ -667,33 +779,11 @@ func (r *Runner) RunTunneler(ctx context.Context) {
for {
select {
case addr := <-detectedCh:
- log.Infof(Tr("info.tunnel.detected"), addr.host, addr.port)
- r.cluster.publicPort = addr.port
- if !r.cluster.byoc {
- r.cluster.host = addr.host
- }
- strPort := strconv.Itoa((int)(r.getPublicPort()))
- if spp, ok := r.listener.(interface{ SetPublicPort(port string) }); ok {
- spp.SetPublicPort(strPort)
- }
- log.Infof(Tr("info.server.public.at"), net.JoinHostPort(addr.host, strPort), r.clusterSvr.Addr, r.getCertCount())
- if len(r.publicHosts) > 1 {
- log.Info(Tr("info.server.alternative.hosts"))
- for _, h := range r.publicHosts[1:] {
- log.Infof("\t- https://%s", net.JoinHostPort(h, strPort))
- }
- }
- if !r.cluster.Enabled() {
- shutCtx, cancel := context.WithTimeout(ctx, time.Minute)
- r.cluster.Disable(shutCtx)
- cancel()
- }
- if err := r.cluster.Enable(ctx); err != nil {
- log.Errorf(Tr("error.cluster.enable.failed"), err)
- if ctx.Err() != nil {
- return
- }
- os.Exit(2)
+ log.TrInfof("info.tunnel.detected", addr.host, addr.port)
+ r.publicHost, r.publicPort = addr.host, addr.port
+ r.updateClustersWithGeneralConfig(ctx)
+ if ctx.Err() != nil {
+ return
}
case <-ctx.Done():
return
diff --git a/storage/manager.go b/storage/manager.go
index 79353dd6..164a34a7 100644
--- a/storage/manager.go
+++ b/storage/manager.go
@@ -57,6 +57,15 @@ func (m *Manager) Get(id string) Storage {
return nil
}
+func (m *Manager) GetIndex(id string) int {
+ for i, s := range m.Storages {
+ if s.Id() == id {
+ return i
+ }
+ }
+ return -1
+}
+
func (m *Manager) GetFlavorString(storages []int) string {
typeCount := make(map[string]int, 2)
for _, i := range storages {
diff --git a/sync.go b/sync.go
index bded4ac1..12430b40 100644
--- a/sync.go
+++ b/sync.go
@@ -1,3 +1,5 @@
+//go:build ignore
+
/**
* OpenBmclAPI (Golang Edition)
* Copyright (C) 2024 Kevin Z
diff --git a/util.go b/util.go
index 539cf9a3..3bd2ba97 100644
--- a/util.go
+++ b/util.go
@@ -25,7 +25,6 @@ import (
"crypto/x509"
"fmt"
"io"
- "math/rand"
"net/http"
"net/url"
"os"
@@ -83,54 +82,6 @@ func parseCertCommonName(body []byte) (string, error) {
return cert.Subject.CommonName, nil
}
-func forEachFromRandomIndex(leng int, cb func(i int) (done bool)) (done bool) {
- if leng <= 0 {
- return false
- }
- start := randIntn(leng)
- for i := start; i < leng; i++ {
- if cb(i) {
- return true
- }
- }
- for i := 0; i < start; i++ {
- if cb(i) {
- return true
- }
- }
- return false
-}
-
-func forEachFromRandomIndexWithPossibility(poss []uint, total uint, cb func(i int) (done bool)) (done bool) {
- leng := len(poss)
- if leng == 0 {
- return false
- }
- if total == 0 {
- return forEachFromRandomIndex(leng, cb)
- }
- n := (uint)(randIntn((int)(total)))
- start := 0
- for i, p := range poss {
- if n < p {
- start = i
- break
- }
- n -= p
- }
- for i := start; i < leng; i++ {
- if cb(i) {
- return true
- }
- }
- for i := 0; i < start; i++ {
- if cb(i) {
- return true
- }
- }
- return false
-}
-
func copyFile(src, dst string, mode os.FileMode) (err error) {
var srcFd, dstFd *os.File
if srcFd, err = os.Open(src); err != nil {
diff --git a/utils/http.go b/utils/http.go
index d83fbdbf..0412bf27 100644
--- a/utils/http.go
+++ b/utils/http.go
@@ -30,7 +30,6 @@ import (
"net/url"
"path"
"runtime"
- "strconv"
"strings"
"sync"
"sync/atomic"
@@ -290,13 +289,10 @@ func (m *HttpMethodHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request)
// Else it will just return the tls connection
type HTTPTLSListener struct {
net.Listener
- TLSConfig *tls.Config
+ TLSConfig atomic.Pointer[tls.Config]
+ DoRedirect bool
AllowUnsecure bool
- mux sync.RWMutex
- hosts []string
- port string
-
accepting atomic.Bool
acceptedCh chan net.Conn
errCh chan error
@@ -304,15 +300,17 @@ type HTTPTLSListener struct {
var _ net.Listener = (*HTTPTLSListener)(nil)
-func NewHttpTLSListener(l net.Listener, cfg *tls.Config, publicHosts []string, port uint16) net.Listener {
- return &HTTPTLSListener{
- Listener: l,
- TLSConfig: cfg,
- hosts: publicHosts,
- port: strconv.Itoa((int)(port)),
+func NewHttpTLSListener(l net.Listener, cfg *tls.Config) *HTTPTLSListener {
+ h := &HTTPTLSListener{
+ Listener: l,
+ DoRedirect: true,
+ AllowUnsecure: false,
+
acceptedCh: make(chan net.Conn, 1),
errCh: make(chan error, 1),
}
+ h.TLSConfig.Store(cfg)
+ return h
}
func (s *HTTPTLSListener) Close() (err error) {
@@ -329,22 +327,7 @@ func (s *HTTPTLSListener) Close() (err error) {
return
}
-func (s *HTTPTLSListener) SetPublicPort(port string) {
- s.mux.Lock()
- defer s.mux.Unlock()
- s.port = port
-}
-
-func (s *HTTPTLSListener) GetPublicPort() string {
- s.mux.RLock()
- defer s.mux.RUnlock()
- return s.port
-}
-
func (s *HTTPTLSListener) maybeHTTPConn(c *connHeadReader) (ishttp bool) {
- if len(s.hosts) == 0 {
- return false
- }
var buf [4096]byte
i, n := 0, 0
READ_HEAD:
@@ -389,6 +372,11 @@ func (s *HTTPTLSListener) accepter() {
s.errCh <- err
return
}
+ tlsCfg := s.TLSConfig.Load()
+ if tlsCfg == nil {
+ s.acceptedCh <- conn
+ return
+ }
go s.accepter()
hr := &connHeadReader{Conn: conn}
hr.SetReadDeadline(time.Now().Add(time.Second * 5))
@@ -396,7 +384,7 @@ func (s *HTTPTLSListener) accepter() {
hr.SetReadDeadline(time.Time{})
if !ishttp {
// if it's not a http connection, it must be a tls connection
- s.acceptedCh <- tls.Server(hr, s.TLSConfig)
+ s.acceptedCh <- tls.Server(hr, tlsCfg)
return
}
if s.AllowUnsecure {
@@ -416,41 +404,13 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) {
return
}
conn.SetReadDeadline(time.Time{})
- host, _, err := net.SplitHostPort(req.Host)
- if err != nil {
- host = req.Host
- }
- inhosts := false
- if host != "" {
- host = strings.ToLower(host)
- for _, h := range s.hosts {
- if h == "*" {
- inhosts = true
- break
- }
- if h, ok := strings.CutPrefix(h, "*."); ok {
- if strings.HasSuffix(host, h) {
- inhosts = true
- break
- }
- } else if h == host {
- inhosts = true
- break
- }
- }
- }
+ // host, _, err := net.SplitHostPort(req.Host)
+ // if err != nil {
+ // host = req.Host
+ // }
u := *req.URL
u.Scheme = "https"
- if !inhosts {
- host = ""
- for _, h := range s.hosts {
- if h != "*" && !strings.HasSuffix(h, "*.") {
- host = h
- break
- }
- }
- }
- if host == "" {
+ if !s.DoRedirect {
body := strings.NewReader("Sent http request on https server")
resp := &http.Response{
StatusCode: http.StatusBadRequest,
@@ -468,7 +428,7 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) {
io.Copy(conn, body)
return
}
- u.Host = net.JoinHostPort(host, s.GetPublicPort())
+ // u.Host = net.JoinHostPort(host, s.GetPublicPort())
resp := &http.Response{
StatusCode: http.StatusPermanentRedirect,
ProtoMajor: req.ProtoMajor,