diff --git a/cluster/cluster.go b/cluster/cluster.go index f48ae38..3f7943f 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -63,8 +63,9 @@ type Cluster struct { client *http.Client cachedCli *http.Client - authTokenMux sync.RWMutex - authToken *ClusterToken + authTokenMux sync.RWMutex + authToken *ClusterToken + fileListLastMod int64 } func NewCluster( @@ -95,12 +96,12 @@ func (cr *Cluster) Secret() string { // Host returns the cluster public host func (cr *Cluster) Host() string { - return cr.gcfg.Host + return cr.gcfg.PublicHost } // Port returns the cluster public port func (cr *Cluster) Port() uint16 { - return cr.gcfg.Port + return cr.gcfg.PublicPort } // PublicHosts returns the cluster public hosts @@ -168,8 +169,8 @@ func (cr *Cluster) enable(ctx context.Context) error { log.TrInfof("info.cluster.enable.sending") resCh, err := cr.socket.EmitWithAck("enable", EnableData{ - Host: cr.gcfg.Host, - Port: cr.gcfg.Port, + Host: cr.gcfg.PublicHost, + Port: cr.gcfg.PublicPort, Version: build.ClusterVersion, Byoc: cr.gcfg.Byoc, NoFastEnable: cr.gcfg.NoFastEnable, diff --git a/cluster/config.go b/cluster/config.go index c8fac95..2eebc61 100644 --- a/cluster/config.go +++ b/cluster/config.go @@ -29,12 +29,8 @@ import ( "fmt" "net/http" "net/url" - "strconv" "time" - "github.com/hamba/avro/v2" - "github.com/klauspost/compress/zstd" - "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -263,63 +259,3 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error } return } - -type FileInfo struct { - Path string `json:"path" avro:"path"` - Hash string `json:"hash" avro:"hash"` - Size int64 `json:"size" avro:"size"` - Mtime int64 `json:"mtime" avro:"mtime"` -} - -// from -var fileListSchema = avro.MustParse(`{ - "type": "array", - "items": { - "type": "record", - "name": "fileinfo", - "fields": [ - {"name": "path", "type": "string"}, - {"name": "hash", "type": "string"}, - {"name": "size", "type": "long"}, 
- {"name": "mtime", "type": "long"} - ] - } -}`) - -func (cr *Cluster) GetFileList(ctx context.Context, lastMod int64) (files []FileInfo, err error) { - var query url.Values - if lastMod > 0 { - query = url.Values{ - "lastModified": {strconv.FormatInt(lastMod, 10)}, - } - } - req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) - if err != nil { - return - } - res, err := cr.cachedCli.Do(req) - if err != nil { - return - } - defer res.Body.Close() - switch res.StatusCode { - case http.StatusOK: - // - case http.StatusNoContent, http.StatusNotModified: - return - default: - err = utils.NewHTTPStatusErrorFromResponse(res) - return - } - log.Debug("Parsing filelist body ...") - zr, err := zstd.NewReader(res.Body) - if err != nil { - return - } - defer zr.Close() - if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { - return - } - log.Debugf("Filelist parsed, length = %d", len(files)) - return -} diff --git a/cluster/handler.go b/cluster/handler.go index a14ca24..dd2151a 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -113,7 +113,7 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st return true }) if sto != nil { - api.SetAccessInfo(req, "storage", sto.Options().Id) + api.SetAccessInfo(req, "storage", sto.Id()) } if ok { return @@ -135,7 +135,7 @@ func (cr *Cluster) HandleMeasure(req *http.Request, rw http.ResponseWriter, size api.SetAccessInfo(req, "cluster", cr.ID()) storage := cr.storageManager.Storages[cr.storages[0]] - api.SetAccessInfo(req, "storage", storage.Options().Id) + api.SetAccessInfo(req, "storage", storage.Id()) if err := storage.ServeMeasure(rw, req, size); err != nil { log.Errorf("Could not serve measure %d: %v", size, err) api.SetAccessInfo(req, "error", err.Error()) diff --git a/cluster/storage.go b/cluster/storage.go new file mode 100644 index 0000000..9b7a65c --- /dev/null +++ b/cluster/storage.go @@ -0,0 +1,354 @@ +/** + * OpenBmclAPI (Golang 
Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "runtime" + "slices" + "strconv" + "sync" + "sync/atomic" + "crypto" + "time" + "encoding/hex" + + "github.com/hamba/avro/v2" + "github.com/klauspost/compress/zstd" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" + + "github.com/LiterMC/go-openbmclapi/lang" + "github.com/LiterMC/go-openbmclapi/limited" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" + "github.com/LiterMC/go-openbmclapi/utils" +) + +// from +var fileListSchema = avro.MustParse(`{ + "type": "array", + "items": { + "type": "record", + "name": "fileinfo", + "fields": [ + {"name": "path", "type": "string"}, + {"name": "hash", "type": "string"}, + {"name": "size", "type": "long"}, + {"name": "mtime", "type": "long"} + ] + } +}`) + +type FileInfo struct { + Path string `json:"path" avro:"path"` + Hash string `json:"hash" avro:"hash"` + Size int64 `json:"size" avro:"size"` + Mtime int64 `json:"mtime" avro:"mtime"` +} + +type StorageFileInfo struct { + FileInfo + Storages []storage.Storage +} + +func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageFileInfo, forceAll bool) (err error) { + var query url.Values + lastMod := 
cr.fileListLastMod + if forceAll { + lastMod = 0 + } + if lastMod > 0 { + query = url.Values{ + "lastModified": {strconv.FormatInt(lastMod, 10)}, + } + } + req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) + if err != nil { + return + } + res, err := cr.cachedCli.Do(req) + if err != nil { + return + } + defer res.Body.Close() + switch res.StatusCode { + case http.StatusOK: + // + case http.StatusNoContent, http.StatusNotModified: + return + default: + err = utils.NewHTTPStatusErrorFromResponse(res) + return + } + log.Debug("Parsing filelist body ...") + zr, err := zstd.NewReader(res.Body) + if err != nil { + return + } + defer zr.Close() + var files []FileInfo + if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { + return + } + + for _, f := range files { + if f.Mtime > lastMod { + lastMod = f.Mtime + } + if ff, ok := fileMap[f.Hash]; ok { + if ff.Size != f.Size { + log.Panicf("Hash conflict detected, hash of both %q (%dB) and %q (%dB) is %s", ff.Path, ff.Size, f.Path, f.Size, f.Hash) + } + for _, s := range cr.storages { + sto := cr.storageManager.Storages[s] + if i, ok := slices.BinarySearchFunc(ff.Storages, sto, storageIdSortFunc); !ok { + ff.Storages = slices.Insert(ff.Storages, i, sto) + } + } + } else { + ff := &StorageFileInfo{ + FileInfo: f, + Storages: make([]storage.Storage, len(cr.storages)), + } + for i, s := range cr.storages { + ff.Storages[i] = cr.storageManager.Storages[s] + } + slices.SortFunc(ff.Storages, storageIdSortFunc) + fileMap[f.Hash] = ff + } + } + cr.fileListLastMod = lastMod + log.Debugf("Filelist parsed, length = %d, lastMod = %d", len(files), lastMod) + return +} + +func storageIdSortFunc(a, b storage.Storage) int { + if a.Id() < b.Id() { + return -1 + } + return 1 +} + +// func SyncFiles(ctx context.Context, manager *storage.Manager, files map[string]*StorageFileInfo, heavyCheck bool) bool { +// log.TrInfof("info.sync.prepare", len(files)) + +// slices.SortFunc(files, 
func(a, b *StorageFileInfo) int { return a.Size - b.Size }) +// if cr.syncFiles(ctx, files, heavyCheck) != nil { +// return false +// } + +// cr.filesetMux.Lock() +// for _, f := range files { +// cr.fileset[f.Hash] = f.Size +// } +// cr.filesetMux.Unlock() + +// return true +// } + +var emptyStr string + +func checkFile( + ctx context.Context, + manager *storage.Manager, + files map[string]*StorageFileInfo, + heavy bool, + missing map[string]*StorageFileInfo, + pg *mpb.Progress, +) (err error) { + var missingCount atomic.Int32 + addMissing := func(f FileInfo, sto storage.Storage) { + missingCount.Add(1) + if info, ok := missing[f.Hash]; ok { + info.Storages = append(info.Storages, sto) + } else { + missing[f.Hash] = &StorageFileInfo{ + FileInfo: f, + Storages: []storage.Storage{sto}, + } + } + } + + log.TrInfof("info.check.start", heavy) + + var ( + checkingHash atomic.Pointer[string] + lastCheckingHash string + slots *limited.BufSlots + wg sync.WaitGroup + ) + checkingHash.Store(&emptyStr) + + if heavy { + slots = limited.NewBufSlots(runtime.GOMAXPROCS(0) * 2) + } + + bar := pg.AddBar(0, + mpb.BarRemoveOnComplete(), + mpb.PrependDecorators( + decor.Name(lang.Tr("hint.check.checking")), + decor.OnCondition( + decor.Any(func(decor.Statistics) string { + c, l := slots.Cap(), slots.Len() + return fmt.Sprintf(" (%d / %d)", c-l, c) + }), + heavy, + ), + ), + mpb.AppendDecorators( + decor.CountersNoUnit("%d / %d", decor.WCSyncSpaceR), + decor.NewPercentage("%d", decor.WCSyncSpaceR), + decor.EwmaETA(decor.ET_STYLE_GO, 60), + ), + mpb.BarExtender((mpb.BarFillerFunc)(func(w io.Writer, _ decor.Statistics) (err error) { + lastCheckingHash = *checkingHash.Load() + if lastCheckingHash != "" { + _, err = fmt.Fprintln(w, "\t", lastCheckingHash) + } + return + }), false), + ) + defer bar.Wait() + defer bar.Abort(true) + + bar.SetTotal(0x100, false) + + ssizeMap := make(map[storage.Storage]map[string]int64, len(manager.Storages)) + for _, sto := range manager.Storages { + sizeMap 
:= make(map[string]int64, len(files)) + ssizeMap[sto] = sizeMap + wg.Add(1) + go func(sto storage.Storage, sizeMap map[string]int64) { + defer wg.Done() + start := time.Now() + var checkedMp [256]bool + if err := sto.WalkDir(func(hash string, size int64) error { + if n := utils.HexTo256(hash); !checkedMp[n] { + checkedMp[n] = true + now := time.Now() + bar.EwmaIncrement(now.Sub(start)) + start = now + } + sizeMap[hash] = size + return nil + }); err != nil { + log.Errorf("Cannot walk %s: %v", sto.Id(), err) + return + } + }(sto, sizeMap) + } + wg.Wait() + + bar.SetCurrent(0) + bar.SetTotal((int64)(len(files)), false) + for _, f := range files { + if err := ctx.Err(); err != nil { + return err + } + start := time.Now() + hash := f.Hash + checkingHash.Store(&hash) + if f.Size == 0 { + log.Debugf("Skipped empty file %s", hash) + bar.EwmaIncrement(time.Since(start)) + continue + } + for _, sto := range f.Storages { + name := sto.Id() + "/" + hash + size, ok := ssizeMap[sto][hash] + if !ok { + // log.Debugf("Could not found file %q", name) + addMissing(f.FileInfo, sto) + bar.EwmaIncrement(time.Since(start)) + continue + } + if size != f.Size { + log.TrWarnf("warn.check.modified.size", name, size, f.Size) + addMissing(f.FileInfo, sto) + bar.EwmaIncrement(time.Since(start)) + continue + } + if !heavy { + bar.EwmaIncrement(time.Since(start)) + continue + } + hashMethod, err := getHashMethod(len(hash)) + if err != nil { + log.TrErrorf("error.check.unknown.hash.method", hash) + bar.EwmaIncrement(time.Since(start)) + continue + } + _, buf, free := slots.Alloc(ctx) + if buf == nil { + return ctx.Err() + } + wg.Add(1) + go func(f FileInfo, buf []byte, free func()) { + defer log.RecoverPanic(nil) + defer wg.Done() + miss := true + r, err := sto.Open(hash) + if err != nil { + log.TrErrorf("error.check.open.failed", name, err) + } else { + hw := hashMethod.New() + _, err = io.CopyBuffer(hw, r, buf[:]) + r.Close() + if err != nil { + log.TrErrorf("error.check.hash.failed", name, 
err) + } else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != hash { + log.TrWarnf("warn.check.modified.hash", name, hs, hash) + } else { + miss = false + } + } + bar.EwmaIncrement(time.Since(start)) + free() + if miss { + addMissing(f, sto) + } + }(f.FileInfo, buf, free) + } + } + wg.Wait() + + checkingHash.Store(&emptyStr) + + bar.SetTotal(-1, true) + log.TrInfof("info.check.done", missingCount.Load()) + return nil +} + +func getHashMethod(l int) (hashMethod crypto.Hash, err error) { + switch l { + case 32: + hashMethod = crypto.MD5 + case 40: + hashMethod = crypto.SHA1 + default: + err = fmt.Errorf("Unknown hash length %d", l) + } + return +} diff --git a/config/advanced.go b/config/advanced.go index 7b3fcfb..99fc132 100644 --- a/config/advanced.go +++ b/config/advanced.go @@ -30,5 +30,5 @@ type AdvancedConfig struct { WaitBeforeEnable int `yaml:"wait-before-enable"` // DoNotRedirectHTTPSToSecureHostname bool `yaml:"do-NOT-redirect-https-to-SECURE-hostname"` - DoNotOpenFAQOnWindows bool `yaml:"do-not-open-faq-on-windows"` + DoNotOpenFAQOnWindows bool `yaml:"do-not-open-faq-on-windows"` } diff --git a/config/config.go b/config/config.go index 747b961..5c843c9 100644 --- a/config/config.go +++ b/config/config.go @@ -34,6 +34,7 @@ import ( type Config struct { PublicHost string `yaml:"public-host"` PublicPort uint16 `yaml:"public-port"` + Host string `yaml:"host"` Port uint16 `yaml:"port"` Byoc bool `yaml:"byoc"` UseCert bool `yaml:"use-cert"` @@ -76,6 +77,7 @@ func NewDefaultConfig() *Config { return &Config{ PublicHost: "", PublicPort: 0, + Host: "0.0.0.0", Port: 4000, Byoc: false, TrustedXForwardedFor: false, diff --git a/config/server.go b/config/server.go index 6b16e00..32963a5 100644 --- a/config/server.go +++ b/config/server.go @@ -40,8 +40,8 @@ type ClusterOptions struct { } type ClusterGeneralConfig struct { - Host string `json:"host"` - Port uint16 `json:"port"` + PublicHost string `json:"public-host"` + PublicPort uint16 `json:"public-port"` Byoc 
bool `json:"byoc"` NoFastEnable bool `json:"no-fast-enable"` MaxReconnectCount int `json:"max-reconnect-count"` diff --git a/handler.go b/handler.go index 3c76264..bedc3fb 100644 --- a/handler.go +++ b/handler.go @@ -38,6 +38,8 @@ import ( "strings" "time" + "github.com/gorilla/websocket" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/api/v0" "github.com/LiterMC/go-openbmclapi/internal/build" @@ -126,7 +128,7 @@ func (r *Runner) GetHandler() http.Handler { return handler } -func (cr *Cluster) getRecordMiddleWare() utils.MiddleWareFunc { +func (r *Runner) getRecordMiddleWare() utils.MiddleWareFunc { type record struct { used float64 bytes float64 diff --git a/main.go b/main.go index 0bb6cb0..21fd45f 100644 --- a/main.go +++ b/main.go @@ -44,11 +44,14 @@ import ( doh "github.com/libp2p/go-doh-resolver" + "github.com/LiterMC/go-openbmclapi/config" + "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/lang" "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" + subcmds "github.com/LiterMC/go-openbmclapi/sub_commands" _ "github.com/LiterMC/go-openbmclapi/lang/en" _ "github.com/LiterMC/go-openbmclapi/lang/zh" @@ -79,14 +82,14 @@ func parseArgs() { case "help", "--help": printHelp() os.Exit(0) - case "zip-cache": - cmdZipCache(os.Args[2:]) - os.Exit(0) - case "unzip-cache": - cmdUnzipCache(os.Args[2:]) - os.Exit(0) + // case "zip-cache": + // cmdZipCache(os.Args[2:]) + // os.Exit(0) + // case "unzip-cache": + // cmdUnzipCache(os.Args[2:]) + // os.Exit(0) case "upload-webdav": - cmdUploadWebdav(os.Args[2:]) + subcmds.CmdUploadWebdav(os.Args[2:]) os.Exit(0) default: fmt.Println("Unknown sub command:", subcmd) @@ -96,16 +99,6 @@ func parseArgs() { } } -var exitCh = make(chan int, 1) - -func osExit(n int) { - select { - case exitCh <- n: - default: - } - runtime.Goexit() -} - func main() { if 
runtime.GOOS == "windows" { lang.SetLang("zh-cn") @@ -118,28 +111,6 @@ func main() { printShortLicense() parseArgs() - exitCode := -1 - defer func() { - code := exitCode - if code == -1 { - select { - case code = <-exitCh: - default: - code = 0 - } - } - if code != 0 { - log.TrErrorf("program.exited", code) - log.TrErrorf("error.exit.please.read.faq") - if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { - // log.TrWarnf("warn.exit.detected.windows.open.browser") - // cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") - // cmd.Start() - time.Sleep(time.Hour) - } - } - os.Exit(code) - }() defer log.RecordPanic() log.StartFlushLogFile() @@ -147,28 +118,32 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - config = readConfig() - if config.Advanced.DebugLog { + if config, err := readAndRewriteConfig(); err != nil { + log.Errorf("Config error: %s", err) + os.Exit(1) + } else { + r.Config = config + } + if r.Config.Advanced.DebugLog { log.SetLevel(log.LevelDebug) } else { log.SetLevel(log.LevelInfo) } - if config.NoAccessLog { + if r.Config.NoAccessLog { log.SetAccessLogSlots(-1) } else { - log.SetAccessLogSlots(config.AccessLogSlots) + log.SetAccessLogSlots(r.Config.AccessLogSlots) } - config.applyWebManifest(dsbManifest) + r.Config.applyWebManifest(dsbManifest) log.TrInfof("program.starting", build.ClusterVersion, build.BuildVersion) - if config.ClusterId == defaultConfig.ClusterId || config.ClusterSecret == defaultConfig.ClusterSecret { - log.TrErrorf("error.set.cluster.id") - osExit(CodeClientError) + if r.Config.Tunneler.Enable { + r.StartTunneler() } - - r.InitCluster(ctx) + r.InitServer() + r.InitClusters(ctx) go func(ctx context.Context) { defer log.RecordPanic() @@ -189,7 +164,7 @@ func main() { defer listener.Close() if err := r.clusterSvr.Serve(listener); !errors.Is(err, http.ErrServerClosed) { log.Error("Error when serving:", err) - 
osExit(CodeClientError) + os.Exit(1) } }(listener) @@ -221,25 +196,39 @@ func main() { }(ctx) code := r.DoSignals(cancel) - exitCode = code + if code != 0 { + log.TrErrorf("program.exited", code) + log.TrErrorf("error.exit.please.read.faq") + if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { + // log.TrWarnf("warn.exit.detected.windows.open.browser") + // cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") + // cmd.Start() + time.Sleep(time.Hour) + } + } + os.Exit(code) } type Runner struct { - cluster *Cluster - clusterSvr *http.Server + Config *config.Config + + clusters map[string]*Cluster + server *http.Server tlsConfig *tls.Config listener net.Listener publicHosts []string - updating atomic.Bool + reloading atomic.Bool + updating atomic.Bool + tunnelCancel context.CancelFunc } func (r *Runner) getPublicPort() uint16 { - if config.PublicPort > 0 { - return config.PublicPort + if r.Config.PublicPort > 0 { + return r.Config.PublicPort } - return config.Port + return r.Config.Port } func (r *Runner) getCertCount() int { @@ -255,10 +244,15 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) defer signal.Stop(signalCh) + var ( + forceStop context.CancelFunc + exited = make(chan struct{}, 0) + ) + for { select { - case code := <-exitCh: - return code + case <-exited: + return 0 case s := <-signalCh: switch s { case syscall.SIGQUIT: @@ -295,20 +289,46 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { log.Info("Dump file created") } } - continue case syscall.SIGHUP: - r.ReloadConfig() + go r.ReloadConfig() default: cancel() - r.StopServer(signalCh) + if forceStop == nil { + ctx, cancel := context.WithCancel(context.Background()) + forceStop = cancel + go func() { + defer close(exited) + r.StopServer(ctx) + }() + } else { + log.Warn("signal:", s) + 
log.Error("Second close signal received, forcely shutting down")
					forceStop()
				}
			}
		}
	}
	return 0
}

// ReloadConfig re-reads and applies the configuration file; triggered by SIGHUP.
// At most one reload runs at a time.
func (r *Runner) ReloadConfig() {
	// BUG FIX: the original condition was inverted — it logged
	// "already reloading" and returned exactly when the CAS *succeeded*,
	// so a reload could never actually run.
	if !r.reloading.CompareAndSwap(false, true) {
		log.Error("Config is already reloading!")
		return
	}
	defer r.reloading.Store(false)

	config, err := readAndRewriteConfig()
	if err != nil {
		log.Errorf("Config error: %s", err)
	} else {
		r.Config = config
	}
}

// StopServer gracefully disables every cluster and shuts down the HTTP server.
// Cancelling ctx aborts the graceful wait (forced shutdown path).
func (r *Runner) StopServer(ctx context.Context) {
	// BUG FIX: tunnelCancel is only assigned when the tunneler is enabled
	// (see StartTunneler); calling it unconditionally paniced on a nil func.
	if r.tunnelCancel != nil {
		r.tunnelCancel()
	}
	shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15)
	defer cancelShut()
	log.TrWarnf("warn.server.closing")
	shutDone := make(chan struct{}, 0)
	go func() {
		defer close(shutDone)
		defer cancelShut()
		var wg sync.WaitGroup
		for _, cr := range r.clusters {
			cr := cr    // capture the range variable for the goroutine (pre-Go 1.22 semantics)
			wg.Add(1)   // BUG FIX: Add was missing, so Done underflowed the counter / Wait returned before clusters were disabled
			go func() {
				defer wg.Done()
				cr.Disable(shutCtx)
			}()
		}
		wg.Wait()
		log.TrWarnf("warn.httpserver.closing")
		r.server.Shutdown(shutCtx)
	}()
	select {
	case <-shutDone:
	case <-ctx.Done():
		return
	}
	log.TrWarnf("warn.server.closed")
}

// InitServer constructs the shared HTTP server (listening happens later).
func (r *Runner) InitServer() {
	r.server = &http.Server{
		// BUG FIX: the original referenced an undefined receiver "d"
		// (d.Config.Host / d.Config.Port); the method receiver is r.
		Addr:        fmt.Sprintf("%s:%d", r.Config.Host, r.Config.Port),
		ReadTimeout: 10 * time.Second,
		IdleTimeout: 5 * time.Second,
		Handler:     r,
		ErrorLog:    log.ProxiedStdLog,
	}
}

// InitClusters creates and initializes every configured cluster.
func (r *Runner) InitClusters(ctx context.Context) {
	var (
		dialer *net.Dialer
		cache  = r.Config.Cache.newCache()
	)

	_ = doh.NewResolver // TODO: use doh resolver

	// …(hunk gap in the underlying diff; unchanged code elided)…

	// NOTE(review): r.cluster no longer exists after the clusters-map
	// refactor — this should iterate r.clusters and Init each one. Confirm.
	if err := r.cluster.Init(ctx); err != nil {
		log.Errorf(Tr("error.init.failed"), err)
		os.Exit(1)
	}
}
fmt.Sprintf("%s:%d", "0.0.0.0", config.Port), - ReadTimeout: 10 * time.Second, - IdleTimeout: 5 * time.Second, - Handler: r.cluster.GetHandler(), - ErrorLog: log.ProxiedStdLog, + os.Exit(1) } } -func (r *Runner) UpdateFileRecords(files []FileInfo, oldfileset map[string]int64) { - if !config.Hijack.Enable { +func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, oldfileset map[string]int64) { + if !r.hijacker.Enabled { return } if !r.updating.CompareAndSwap(false, true) { @@ -394,15 +421,14 @@ func (r *Runner) UpdateFileRecords(files []FileInfo, oldfileset map[string]int64 } func (r *Runner) InitSynchronizer(ctx context.Context) { - log.Info(Tr("info.filelist.fetching")) - fl, err := r.cluster.GetFileList(ctx, 0) - if err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), err) - if errors.Is(err, context.Canceled) { - return - } - if !config.Advanced.SkipFirstSync { - osExit(CodeClientOrServerError) + fileMap := make(map[string]*StorageFileInfo) + for _, cr := range r.clusters { + log.Info(Tr("info.filelist.fetching"), cr.ID()) + if err := cr.GetFileList(ctx, fileMap, true); err != nil { + log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + if errors.Is(err, context.Canceled) { + return + } } } @@ -414,10 +440,10 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } if !config.Advanced.SkipFirstSync { - if !r.cluster.SyncFiles(ctx, fl, false) { + if !r.cluster.SyncFiles(ctx, fileMap, false) { return } - go r.UpdateFileRecords(fl, nil) + go r.UpdateFileRecords(fileMap, nil) if !config.Advanced.NoGC { go r.cluster.Gc() @@ -428,29 +454,19 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } } - var lastMod int64 - for _, f := range fl { - if f.Mtime > lastMod { - lastMod = f.Mtime - } - } - createInterval(ctx, func() { - log.Info(Tr("info.filelist.fetching")) - fl, err := r.cluster.GetFileList(ctx, lastMod) - if err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), err) - return + fileMap := 
make(map[string]*StorageFileInfo) + for _, cr := range r.clusters { + log.Info(Tr("info.filelist.fetching"), cr.ID()) + if err := cr.GetFileList(ctx, fileMap, false); err != nil { + log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + return + } } - if len(fl) == 0 { - log.Infof("No file was updated since %s", time.UnixMilli(lastMod).Format(time.DateTime)) + if len(fileMap) == 0 { + log.Infof("No file was updated since last check") return } - for _, f := range fl { - if f.Mtime > lastMod { - lastMod = f.Mtime - } - } checkCount = (checkCount + 1) % heavyCheckInterval oldfileset := r.cluster.CloneFileset() @@ -464,12 +480,12 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } func (r *Runner) CreateHTTPServerListener(ctx context.Context) (listener net.Listener) { - listener, err := net.Listen("tcp", r.clusterSvr.Addr) + listener, err := net.Listen("tcp", r.Addr) if err != nil { - log.Errorf(Tr("error.address.listen.failed"), r.clusterSvr.Addr, err) + log.Errorf(Tr("error.address.listen.failed"), r.Addr, err) osExit(CodeEnvironmentError) } - if config.ServeLimit.Enable { + if r.Config.ServeLimit.Enable { limted := limited.NewLimitedListener(listener, config.ServeLimit.MaxConn, 0, config.ServeLimit.UploadRate*1024) limted.SetMinWriteRate(1024) listener = limted @@ -483,7 +499,7 @@ func (r *Runner) CreateHTTPServerListener(ctx context.Context) (listener net.Lis r.publicHosts = append(r.publicHosts, strings.ToLower(h)) } } - listener = newHttpTLSListener(listener, tlsConfig, r.publicHosts, r.getPublicPort()) + listener = utils.NewHttpTLSListener(listener, tlsConfig, r.publicHosts, r.getPublicPort()) } r.listener = listener return @@ -493,7 +509,7 @@ func (r *Runner) GenerateTLSConfig(ctx context.Context) (tlsConfig *tls.Config) if config.UseCert { if len(config.Certificates) == 0 { log.Error(Tr("error.cert.not.set")) - osExit(CodeClientError) + os.Exit(1) } tlsConfig = new(tls.Config) tlsConfig.Certificates = make([]tls.Certificate, 
len(config.Certificates)) @@ -502,31 +518,33 @@ func (r *Runner) GenerateTLSConfig(ctx context.Context) (tlsConfig *tls.Config) tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) if err != nil { log.Errorf(Tr("error.cert.parse.failed"), i, err) - osExit(CodeClientError) + os.Exit(1) } } } if !config.Byoc { - log.Info(Tr("info.cert.requesting")) - tctx, cancel := context.WithTimeout(ctx, time.Minute*10) - pair, err := r.cluster.RequestCert(tctx) - cancel() - if err != nil { - log.Errorf(Tr("error.cert.request.failed"), err) - osExit(CodeServerError) - } - if tlsConfig == nil { - tlsConfig = new(tls.Config) - } - var cert tls.Certificate - cert, err = tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) - if err != nil { - log.Errorf(Tr("error.cert.requested.parse.failed"), err) - osExit(CodeServerUnexpectedError) + for _, cr := range r.clusters { + log.Info(Tr("info.cert.requesting"), cr.ID()) + tctx, cancel := context.WithTimeout(ctx, time.Minute*10) + pair, err := cr.RequestCert(tctx) + cancel() + if err != nil { + log.Errorf(Tr("error.cert.request.failed"), err) + os.Exit(2) + } + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } + var cert tls.Certificate + cert, err = tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) + if err != nil { + log.Errorf(Tr("error.cert.requested.parse.failed"), err) + os.Exit(2) + } + tlsConfig.Certificates = append(tlsConfig.Certificates, cert) + certHost, _ := parseCertCommonName(cert.Certificate[0]) + log.Infof(Tr("info.cert.requested"), certHost) } - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) - certHost, _ := parseCertCommonName(cert.Certificate[0]) - log.Infof(Tr("info.cert.requested"), certHost) } r.tlsConfig = tlsConfig return @@ -554,7 +572,31 @@ func (r *Runner) EnableCluster(ctx context.Context) { } } -func (r *Runner) enableClusterByTunnel(ctx context.Context) { +func (r *Runner) StartTunneler() { + ctx, cancel := context.WithCancel(context.Background()) + r.tunnelCancel 
= cancel + go func() { + dur := time.Second + for { + start := time.Now() + r.RunTunneler(ctx) + used := time.Since(start) + // If the program runs no longer than 30s, then it fails too fast. + if used < time.Second*30 { + dur = min(dur*2, time.Minute*10) + } else { + dur = time.Second + } + select { + case <-time.After(dur): + case <-ctx.Done(): + return + } + } + }() +} + +func (r *Runner) RunTunneler(ctx context.Context) { cmd := exec.CommandContext(ctx, config.Tunneler.TunnelProg) log.Infof(Tr("info.tunnel.running"), cmd.String()) var ( @@ -562,18 +604,18 @@ func (r *Runner) enableClusterByTunnel(ctx context.Context) { err error ) cmd.Env = append(os.Environ(), - "CLUSTER_PORT="+strconv.Itoa((int)(config.Port))) + "CLUSTER_PORT="+strconv.Itoa((int)(r.Config.Port))) if cmdOut, err = cmd.StdoutPipe(); err != nil { log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) - osExit(CodeClientUnexpectedError) + os.Exit(1) } if cmdErr, err = cmd.StderrPipe(); err != nil { log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) - osExit(CodeClientUnexpectedError) + os.Exit(1) } if err = cmd.Start(); err != nil { log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) - osExit(CodeClientError) + os.Exit(1) } type addrOut struct { host string @@ -651,7 +693,7 @@ func (r *Runner) enableClusterByTunnel(ctx context.Context) { if ctx.Err() != nil { return } - osExit(CodeServerOrEnvionmentError) + os.Exit(2) } case <-ctx.Done(): return @@ -663,11 +705,5 @@ func (r *Runner) enableClusterByTunnel(ctx context.Context) { return } log.Errorf("Tunnel program exited: %v", err) - osExit(CodeClientError) } - // TODO: maybe restart the tunnel program? 
-} - -func Tr(name string) string { - return lang.Tr(name) } diff --git a/storage/manager.go b/storage/manager.go index 36c4340..79353dd 100644 --- a/storage/manager.go +++ b/storage/manager.go @@ -50,7 +50,7 @@ func NewManager(storages []Storage) (m *Manager) { func (m *Manager) Get(id string) Storage { for _, s := range m.Storages { - if s.Options().Id == id { + if s.Id() == id { return s } } diff --git a/storage/storage.go b/storage/storage.go index 5c5c4a8..91c9952 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -35,6 +35,7 @@ import ( type Storage interface { fmt.Stringer + Id() string // Options should return the pointer of the StorageOption that should not be modified. Options() *StorageOption // Init will be called before start to use a storage diff --git a/storage/storage_local.go b/storage/storage_local.go index b912e1d..a0179b5 100644 --- a/storage/storage_local.go +++ b/storage/storage_local.go @@ -66,6 +66,10 @@ func (s *LocalStorage) String() string { return fmt.Sprintf("", s.opt.CachePath) } +func (s *LocalStorage) Id() string { + return s.basicOpt.Id +} + func (s *LocalStorage) Options() *StorageOption { return &s.basicOpt } diff --git a/storage/storage_mount.go b/storage/storage_mount.go index 62a75e4..ea2481d 100644 --- a/storage/storage_mount.go +++ b/storage/storage_mount.go @@ -80,6 +80,10 @@ func (s *MountStorage) String() string { return fmt.Sprintf("", s.opt.Path, s.opt.RedirectBase) } +func (s *MountStorage) Id() string { + return s.basicOpt.Id +} + func (s *MountStorage) Options() *StorageOption { return &s.basicOpt } diff --git a/storage/storage_webdav.go b/storage/storage_webdav.go index 5b472ce..940067f 100644 --- a/storage/storage_webdav.go +++ b/storage/storage_webdav.go @@ -156,6 +156,10 @@ func (s *WebDavStorage) String() string { return fmt.Sprintf("", s.opt.GetEndPoint(), s.opt.GetUsername()) } +func (s *WebDavStorage) Id() string { + return s.basicOpt.Id +} + func (s *WebDavStorage) Options() *StorageOption { return 
&s.basicOpt } diff --git a/sub_commands/cmd_compress.go b/sub_commands/cmd_compress.go index 4f60637..680a352 100644 --- a/sub_commands/cmd_compress.go +++ b/sub_commands/cmd_compress.go @@ -19,7 +19,7 @@ * along with this program. If not, see . */ -package main +package sub_commands import ( "compress/gzip" diff --git a/sub_commands/cmd_webdav.go b/sub_commands/cmd_webdav.go index 5351df3..2826bd4 100644 --- a/sub_commands/cmd_webdav.go +++ b/sub_commands/cmd_webdav.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package sub_commands import ( "context" @@ -38,7 +38,7 @@ import ( "github.com/LiterMC/go-openbmclapi/utils" ) -func cmdUploadWebdav(args []string) { +func CmdUploadWebdav(args []string) { cfg := readConfig() var ( diff --git a/sync.go b/sync.go index 0053653..bded4ac 100644 --- a/sync.go +++ b/sync.go @@ -88,34 +88,6 @@ type syncStats struct { lastInc atomic.Int64 } -func (cr *Cluster) SyncFiles(ctx context.Context, files []FileInfo, heavyCheck bool) bool { - log.Infof(Tr("info.sync.prepare"), len(files)) - if !cr.issync.CompareAndSwap(false, true) { - log.Warn("Another sync task is running!") - return false - } - defer cr.issync.Store(false) - - sort.Slice(files, func(i, j int) bool { return files[i].Hash < files[j].Hash }) - if cr.syncFiles(ctx, files, heavyCheck) != nil { - return false - } - - cr.filesetMux.Lock() - for _, f := range files { - cr.fileset[f.Hash] = f.Size - } - cr.filesetMux.Unlock() - - return true -} - -type fileInfoWithTargets struct { - FileInfo - tgMux sync.Mutex - targets []storage.Storage -} - func (cr *Cluster) checkFileFor( ctx context.Context, sto storage.Storage, files []FileInfo, diff --git a/utils/http.go b/utils/http.go index 0843715..d83fbdb 100644 --- a/utils/http.go +++ b/utils/http.go @@ -290,10 +290,12 @@ func (m *HttpMethodHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) // Else it will just return the tls connection type HTTPTLSListener struct { net.Listener - TLSConfig 
*tls.Config - mux sync.RWMutex - hosts []string - port string + TLSConfig *tls.Config + AllowUnsecure bool + + mux sync.RWMutex + hosts []string + port string accepting atomic.Bool acceptedCh chan net.Conn @@ -397,6 +399,10 @@ func (s *HTTPTLSListener) accepter() { s.acceptedCh <- tls.Server(hr, s.TLSConfig) return } + if s.AllowUnsecure { + s.acceptedCh <- hr + return + } go s.serveHTTP(hr) } } @@ -418,6 +424,10 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { if host != "" { host = strings.ToLower(host) for _, h := range s.hosts { + if h == "*" { + inhosts = true + break + } if h, ok := strings.CutPrefix(h, "*."); ok { if strings.HasSuffix(host, h) { inhosts = true @@ -432,15 +442,15 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { u := *req.URL u.Scheme = "https" if !inhosts { + host = "" for _, h := range s.hosts { - if !strings.HasSuffix(h, "*.") { + if h != "*" && !strings.HasSuffix(h, "*.") { host = h break } } } if host == "" { - // we have nowhere to redirect body := strings.NewReader("Sent http request on https server") resp := &http.Response{ StatusCode: http.StatusBadRequest,