diff --git a/cmd/crproxy/main.go b/cmd/crproxy/main.go index b01f9c5..926edfd 100644 --- a/cmd/crproxy/main.go +++ b/cmd/crproxy/main.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "encoding/csv" "fmt" "io" "log" @@ -13,6 +14,7 @@ import ( "os" "slices" "strings" + "sync" "sync/atomic" "time" @@ -45,6 +47,7 @@ var ( allowImageListFromFile string blockImageList []string blockMessage string + blockIPListFromFile string privilegedIPList []string privilegedImageListFromFile string privilegedNoAuth bool @@ -92,6 +95,7 @@ func init() { pflag.StringVar(&allowImageListFromFile, "allow-image-list-from-file", "", "allow image list from file") pflag.StringSliceVar(&blockImageList, "block-image-list", nil, "block image list (deprecated)") pflag.StringVar(&blockMessage, "block-message", "", "block message") + pflag.StringVar(&blockIPListFromFile, "block-ip-list-from-file", "", "block ip list from file") pflag.StringSliceVar(&privilegedIPList, "privileged-ip-list", nil, "privileged IP list") pflag.BoolVar(&privilegedNoAuth, "privileged-no-auth", false, "privileged no auth (deprecated)") pflag.StringVar(&privilegedImageListFromFile, "privileged-image-list-from-file", "", "privileged image list from file") @@ -365,6 +369,41 @@ func main() { opts = append(opts, crproxy.WithPrivilegedNoAuth(true)) } + if blockIPListFromFile != "" { + f, err := os.ReadFile(blockIPListFromFile) + if err != nil { + logger.Println("can't read block ip list file", blockIPListFromFile, ":", err) + os.Exit(1) + } + bf, err := getIPReasonCSVListFrom(bytes.NewReader(f)) + if err != nil { + logger.Println("can't read block ip list file", blockIPListFromFile, ":", err) + os.Exit(1) + } + + var bfMutex sync.RWMutex + block := func(info *crproxy.BlockInfo) (string, bool) { + bfMutex.RLock() + defer bfMutex.RUnlock() + return bf(info) + } + opts = append(opts, crproxy.WithBlockFunc(block)) + if enableInternalAPI { + mux.HandleFunc("PUT /internal/api/block-ips", func(rw http.ResponseWriter, r *http.Request) { + blockFunc, err := getIPReasonCSVListFrom(r.Body) + if err != nil { + logger.Println("can't read block ip list file", blockIPListFromFile, ":", err) + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte(err.Error())) + return + } + bfMutex.Lock() + bf = blockFunc + bfMutex.Unlock() + }) + } + } + if len(userpass) != 0 { bc, err := toUserAndPass(userpass) if err != nil { @@ -542,3 +581,27 @@ func getListFrom(r io.Reader) (hostmatcher.Matcher, error) { } return hostmatcher.NewMatcher(hosts), nil } + +func getIPReasonCSVListFrom(r io.Reader) (func(*crproxy.BlockInfo) (string, bool), error) { + kv := map[string]string{} + + reader := csv.NewReader(r) + for { + record, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + kv[record[0]] = record[1] + } + + return func(info *crproxy.BlockInfo) (string, bool) { + reason, ok := kv[info.IP] + if !ok { + return "", false + } + return reason, true + }, nil +}