diff --git a/backend/destination.go b/backend/destination.go index f18e6e3..44f4288 100644 --- a/backend/destination.go +++ b/backend/destination.go @@ -56,6 +56,8 @@ func CheckOfflineDestinations(nowTimeStamp int64) { if err != nil { utils.DebugPrintln("Unmarshal K8S API", err) } + dest.Mutex.Lock() + defer dest.Mutex.Unlock() dest.Pods = "" for _, podItem := range pods.Items { if podItem.Status.Phase == "Running" { diff --git a/backend/k8s.go b/backend/k8s.go index c07cc97..b52f51d 100644 --- a/backend/k8s.go +++ b/backend/k8s.go @@ -13,36 +13,64 @@ import ( "janusec/utils" "net/http" "strings" + "sync" "time" ) +func UpdatePods(dest *models.Destination, nowTimeStamp int64) { + dest.IsUpdating = true + dest.Mutex.Lock() // write lock + defer dest.Mutex.Unlock() + request, _ := http.NewRequest("GET", dest.PodsAPI, nil) + request.Header.Set("Content-Type", "application/json") + resp, err := utils.GetResponse(request) + if err != nil { + utils.DebugPrintln("Check K8S API GetResponse", err) + dest.CheckTime = nowTimeStamp + dest.Online = false + } + pods := models.PODS{} + err = json.Unmarshal(resp, &pods) + if err != nil { + utils.DebugPrintln("Unmarshal K8S API", err) + } + dest.Pods = "" + for _, podItem := range pods.Items { + if podItem.Status.Phase == "Running" { + if len(dest.Pods) > 0 { + dest.Pods += "|" + } + dest.Pods += podItem.Status.PodIP + ":" + dest.PodPort + } + } + dest.IsUpdating = false +} + func SelectPodFromDestination(dest *models.Destination, srcIP string, r *http.Request) string { nowTimeStamp := time.Now().Unix() - if len(dest.Pods) == 0 || (nowTimeStamp-dest.CheckTime) > 60 { - // check k8s api if exceed 60 seconds - request, _ := http.NewRequest("GET", dest.PodsAPI, nil) - request.Header.Set("Content-Type", "application/json") - resp, err := utils.GetResponse(request) - if err != nil { - utils.DebugPrintln("Check K8S API GetResponse", err) - dest.CheckTime = nowTimeStamp - dest.Online = false - } - pods := models.PODS{} - err = json.Unmarshal(resp, &pods) - if err != nil { - utils.DebugPrintln("Unmarshal K8S API", err) + var isEmptyPods bool + if len(dest.Pods) == 0 { + isEmptyPods = true + } else { + isEmptyPods = false + } + wg := new(sync.WaitGroup) + if !dest.IsUpdating && (isEmptyPods || (nowTimeStamp-dest.CheckTime) > 60) { + if isEmptyPods { + wg.Add(1) } - dest.Pods = "" - for _, podItem := range pods.Items { - if podItem.Status.Phase == "Running" { - if len(dest.Pods) > 0 { - dest.Pods += "|" - } - dest.Pods += podItem.Status.PodIP + ":" + dest.PodPort + // check k8s api if exceed 60 seconds + go func(dest *models.Destination, nowTimeStamp int64, wg *sync.WaitGroup) { + UpdatePods(dest, nowTimeStamp) + if isEmptyPods { + wg.Done() } - } + }(dest, nowTimeStamp, wg) + } + if isEmptyPods { + wg.Wait() } + dest.Mutex.RLock() // select target pod from dest.Pods directly dests := strings.Split(dest.Pods, "|") // According to Hash(IP+UA) @@ -53,6 +81,7 @@ func SelectPodFromDestination(dest *models.Destination, srcIP string, r *http.Re } hashUInt32 := h.Sum32() destIndex := hashUInt32 % uint32(len(dests)) + dest.Mutex.RUnlock() return dests[destIndex] } diff --git a/data/data.go b/data/data.go index 84b072e..6f094aa 100644 --- a/data/data.go +++ b/data/data.go @@ -33,7 +33,7 @@ var ( // IsPrimary i.e. Is Primary Node IsPrimary bool // Version of JANUSEC - Version = "1.3.0" + Version = "1.3.1" // NodeKey share with all nodes NodeKey []byte ) diff --git a/firewall/cc.go b/firewall/cc.go index 3b55b34..3408e94 100644 --- a/firewall/cc.go +++ b/firewall/cc.go @@ -53,6 +53,8 @@ func CCAttackTick(appID int64) { clientID := key.(string) stat := value.(*models.ClientStat) //fmt.Println("CCAttackTick:", appID, clientID, stat) + stat.Mutex.Lock() + defer stat.Mutex.Unlock() if stat.IsBadIP { stat.RemainSeconds -= ccPolicy.IntervalMilliSeconds / 1000.0 if stat.RemainSeconds <= 0 { @@ -135,6 +137,8 @@ func IsCCAttack(r *http.Request, app *models.Application, srcIP string) (bool, * clientID := data.SHA256Hash(preHashContent) clientIDStat, _ := appCCCount.LoadOrStore(clientID, &models.ClientStat{QuickCount: 0, SlowCount: 0, TimeFrameCount: 0, IsBadIP: false, RemainSeconds: 0}) clientStat := clientIDStat.(*models.ClientStat) + clientStat.Mutex.Lock() + defer clientStat.Mutex.Unlock() if clientStat.IsBadIP { needLog := false if clientStat.QuickCount == 0 { diff --git a/gateway/gateway.go b/gateway/gateway.go index 6ec9f66..fc8c3c0 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -388,6 +388,8 @@ func ReverseHandlerFunc(w http.ResponseWriter, r *http.Request) { conn, err := net.Dial("tcp", targetDest) dest.CheckTime = nowTimeStamp if err != nil { + dest.Mutex.Lock() + defer dest.Mutex.Unlock() dest.Online = false utils.DebugPrintln("DialContext error", err) if data.NodeSetting.SMTP.SMTPEnabled { diff --git a/models/backend.go b/models/backend.go index 0de181b..fc6d7df 100644 --- a/models/backend.go +++ b/models/backend.go @@ -143,6 +143,10 @@ type Destination struct { // Online status of Destination (IP:Port), added in V0.9.11 Online bool `json:"online"` CheckTime int64 `json:"check_time"` + + // added in 1.3.1, K8s routine updating and avoid race + Mutex sync.RWMutex `json:"-"` + IsUpdating bool `json:"-"` } // PODS for k8s /api/v1/namespaces/default/pods diff --git a/models/firewall.go b/models/firewall.go index 802f24b..2e8af71 100644 --- a/models/firewall.go +++ b/models/firewall.go @@ -9,6 +9,7 @@ package models import ( "database/sql" + "sync" ) type PolicyKey string @@ -137,6 +138,9 @@ type ClientStat struct { // RemainSeconds used for block time frame RemainSeconds float64 //time.Duration + + // added v1.3.1 + Mutex sync.Mutex } type VulnType struct { diff --git a/release_batch.sh b/release_batch.sh index e69dbeb..179b710 100755 --- a/release_batch.sh +++ b/release_batch.sh @@ -2,7 +2,7 @@ printf "Creating installation package\n" printf "Checklist:\n" printf "* Angular Admin Version Check. \n" printf "* Janusec Version Check. \n" -version="1.3.0" +version="1.3.1" printf "Version: ${version} \n" read -r -p "Are You Sure? [Y/n] " option