From d60d706b638fd1ca2b3c50d8e4608ffc9ce8ec4b Mon Sep 17 00:00:00 2001
From: Paolo Fabio Zaino
Date: Fri, 27 Sep 2024 17:42:18 +0100
Subject: [PATCH] Completed the work to improve API performance under heavy
load
---
services/api/console.go | 72 +++-------------------
services/api/main.go | 105 +++++++++----------------------
services/api/search_engine.go | 112 ++++++----------------------------
3 files changed, 56 insertions(+), 233 deletions(-)
diff --git a/services/api/console.go b/services/api/console.go
index cfbc473..4d94a7d 100644
--- a/services/api/console.go
+++ b/services/api/console.go
@@ -32,7 +32,7 @@ const (
errFailedToCommitTransaction = "Failed to commit transaction"
)
-func performAddSource(query string, qType int) (ConsoleResponse, error) {
+func performAddSource(query string, qType int, db *cdb.Handler) (ConsoleResponse, error) {
var sqlQuery string
var sqlParams addSourceRequest
if qType == getQuery {
@@ -53,7 +53,7 @@ func performAddSource(query string, qType int) (ConsoleResponse, error) {
}
// Perform the addSource operation
- results, err := addSource(sqlQuery, sqlParams)
+ results, err := addSource(sqlQuery, sqlParams, db)
if err != nil {
cmn.DebugMsg(cmn.DbgLvlError, "adding the source: %v", err)
return results, err
@@ -105,7 +105,7 @@ func extractAddSourceParams(query string, params *addSourceRequest) {
}
-func addSource(sqlQuery string, params addSourceRequest) (ConsoleResponse, error) {
+func addSource(sqlQuery string, params addSourceRequest, db *cdb.Handler) (ConsoleResponse, error) {
var results ConsoleResponse
results.Message = "Failed to add the source"
@@ -119,20 +119,6 @@ func addSource(sqlQuery string, params addSourceRequest) (ConsoleResponse, error
}
}
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return results, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- cmn.DebugMsg(cmn.DbgLvlError, "connecting to the database: %v", err)
- return results, err
- }
- defer db.Close()
-
// Get the JSON string for the Config field
configJSON, err := json.Marshal(params.Config)
if err != nil {
@@ -140,7 +126,7 @@ func addSource(sqlQuery string, params addSourceRequest) (ConsoleResponse, error
}
// Execute the SQL statement
- _, err = db.Exec(sqlQuery, params.URL, params.Status, params.Restricted, params.Disabled, params.Flags, string(configJSON), params.CategoryID, params.UsrID)
+ _, err = (*db).Exec(sqlQuery, params.URL, params.Status, params.Restricted, params.Disabled, params.Flags, string(configJSON), params.CategoryID, params.UsrID)
if err != nil {
return results, err
}
@@ -180,7 +166,7 @@ func validateAndReformatConfig(config *cfg.SourceConfig) error {
return nil
}
-func performRemoveSource(query string, qType int) (ConsoleResponse, error) {
+func performRemoveSource(query string, qType int, db *cdb.Handler) (ConsoleResponse, error) {
var results ConsoleResponse
var sourceURL string // Assuming the source URL is passed. Adjust as necessary based on input.
@@ -193,21 +179,8 @@ func performRemoveSource(query string, qType int) (ConsoleResponse, error) {
return ConsoleResponse{Message: "Invalid request"}, nil
}
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return ConsoleResponse{Message: errFailedToInitializeDBHandler}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- return ConsoleResponse{Message: errFailedToConnectToDB}, err
- }
- defer db.Close()
-
// Start a transaction
- tx, err := db.Begin()
+ tx, err := (*db).Begin()
if err != nil {
return ConsoleResponse{Message: errFailedToStartTransaction}, err
}
@@ -269,7 +242,7 @@ func removeSource(tx *sql.Tx, sourceURL string) (ConsoleResponse, error) {
return results, nil
}
-func performGetURLStatus(query string, qType int) (StatusResponse, error) {
+func performGetURLStatus(query string, qType int, db *cdb.Handler) (StatusResponse, error) {
var results StatusResponse
var sourceURL string // Assuming the source URL is passed. Adjust as necessary based on input.
@@ -282,21 +255,8 @@ func performGetURLStatus(query string, qType int) (StatusResponse, error) {
return StatusResponse{Message: "Invalid request"}, nil
}
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return StatusResponse{Message: errFailedToInitializeDBHandler}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- return StatusResponse{Message: errFailedToConnectToDB}, err
- }
- defer db.Close()
-
// Start a transaction
- tx, err := db.Begin()
+ tx, err := (*db).Begin()
if err != nil {
return StatusResponse{Message: errFailedToStartTransaction}, err
}
@@ -362,23 +322,11 @@ func getURLStatus(tx *sql.Tx, sourceURL string) (StatusResponse, error) {
return results, nil
}
-func performGetAllURLStatus(_ int) (StatusResponse, error) {
+func performGetAllURLStatus(_ int, db *cdb.Handler) (StatusResponse, error) {
// using _ instead of qType because for now we don't need it
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return StatusResponse{Message: errFailedToInitializeDBHandler}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- return StatusResponse{Message: errFailedToConnectToDB}, err
- }
- defer db.Close()
// Start a transaction
- tx, err := db.Begin()
+ tx, err := (*db).Begin()
if err != nil {
return StatusResponse{Message: errFailedToStartTransaction}, err
}
diff --git a/services/api/main.go b/services/api/main.go
index 121a264..2fd38b0 100644
--- a/services/api/main.go
+++ b/services/api/main.go
@@ -31,6 +31,7 @@ import (
cmn "github.com/pzaino/thecrowler/pkg/common"
cfg "github.com/pzaino/thecrowler/pkg/config"
+ cdb "github.com/pzaino/thecrowler/pkg/database"
"golang.org/x/time/rate"
)
@@ -46,6 +47,7 @@ var (
configMutex sync.Mutex
configFile *string
dbSemaphore chan struct{} // Semaphore for the database connection
+ dbHandler cdb.Handler
)
func initAll(configFile *string, config *cfg.Config, lmt **rate.Limiter) error {
@@ -91,6 +93,23 @@ func initAll(configFile *string, config *cfg.Config, lmt **rate.Limiter) error {
// Set the database semaphore
dbSemaphore = make(chan struct{}, config.Database.MaxConns-3)
+ // Initialize the database
+ var connected bool = false
+ dbHandler, err = cdb.NewHandler(*config)
+ if err != nil {
+ cmn.DebugMsg(cmn.DbgLvlError, "Error allocating database resources: %v", err)
+ } else {
+ for !connected {
+ err = dbHandler.Connect(*config)
+ if err != nil {
+ cmn.DebugMsg(cmn.DbgLvlError, "Error opening database connection: %v", err)
+ time.Sleep(5 * time.Second)
+ continue
+ }
+ connected = true
+ }
+ }
+
return nil
}
@@ -257,12 +276,6 @@ func SecurityHeadersMiddleware(next http.Handler) http.Handler {
}
func healthCheckHandler(w http.ResponseWriter, _ *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
// Create a JSON document with the health status
healthStatus := HealthCheck{
Status: "OK",
@@ -274,12 +287,6 @@ func healthCheckHandler(w http.ResponseWriter, _ *http.Request) {
// searchHandler handles the traditional search requests
func searchHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}: // Try to acquire a DB connection
defer func() { <-dbSemaphore }() // Release the connection after the work is done
@@ -291,7 +298,7 @@ func searchHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performSearch(query)
+ results, err := performSearch(query, &dbHandler)
results.SetHeaderFields(
"customsearch#search",
jsonResponse,
@@ -321,12 +328,6 @@ func searchHandler(w http.ResponseWriter, r *http.Request) {
}
func webObjectHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -338,7 +339,7 @@ func webObjectHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performWebObjectSearch(query, getQTypeFromName(r.Method))
+ results, err := performWebObjectSearch(query, getQTypeFromName(r.Method), &dbHandler)
if results.IsEmpty() {
var retCode int
if config.API.Return404 {
@@ -377,12 +378,6 @@ func webObjectHandler(w http.ResponseWriter, r *http.Request) {
}
func webCorrelatedSitesHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -394,7 +389,7 @@ func webCorrelatedSitesHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performCorrelatedSitesSearch(query, getQTypeFromName(r.Method))
+ results, err := performCorrelatedSitesSearch(query, getQTypeFromName(r.Method), &dbHandler)
if results.IsEmpty() {
var retCode int
if config.API.Return404 {
@@ -434,12 +429,6 @@ func webCorrelatedSitesHandler(w http.ResponseWriter, r *http.Request) {
// scrImgSrchHandler handles the search requests for screenshot images
func scrImgSrchHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -451,7 +440,7 @@ func scrImgSrchHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performScreenshotSearch(query, getQTypeFromName(r.Method))
+ results, err := performScreenshotSearch(query, getQTypeFromName(r.Method), &dbHandler)
if results == (ScreenshotResponse{}) {
var retCode int
if config.API.Return404 {
@@ -473,12 +462,6 @@ func scrImgSrchHandler(w http.ResponseWriter, r *http.Request) {
// netInfoHandler handles the network information requests
func netInfoHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -490,7 +473,7 @@ func netInfoHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performNetInfoSearch(query, getQTypeFromName(r.Method))
+ results, err := performNetInfoSearch(query, getQTypeFromName(r.Method), &dbHandler)
if results.isEmpty() {
var retCode int
if config.API.Return404 {
@@ -530,12 +513,6 @@ func netInfoHandler(w http.ResponseWriter, r *http.Request) {
// httpInfoHandler handles the http information requests
func httpInfoHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -547,7 +524,7 @@ func httpInfoHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performHTTPInfoSearch(query, getQTypeFromName(r.Method))
+ results, err := performHTTPInfoSearch(query, getQTypeFromName(r.Method), &dbHandler)
if results.IsEmpty() {
var retCode int
if config.API.Return404 {
@@ -587,12 +564,6 @@ func httpInfoHandler(w http.ResponseWriter, r *http.Request) {
// addSourceHandler handles the addition of new sources
func addSourceHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -604,7 +575,7 @@ func addSourceHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performAddSource(query, getQTypeFromName(r.Method))
+ results, err := performAddSource(query, getQTypeFromName(r.Method), &dbHandler)
handleErrorAndRespond(w, err, results, "Error performing addSource: %v", http.StatusInternalServerError, successCode)
case <-time.After(5 * time.Second): // Wait for a connection with timeout
healthStatus := HealthCheck{
@@ -616,12 +587,6 @@ func addSourceHandler(w http.ResponseWriter, r *http.Request) {
// removeSourceHandler handles the removal of sources
func removeSourceHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -633,7 +598,7 @@ func removeSourceHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performRemoveSource(query, getQTypeFromName(r.Method))
+ results, err := performRemoveSource(query, getQTypeFromName(r.Method), &dbHandler)
handleErrorAndRespond(w, err, results, "Error performing removeSource: %v", http.StatusInternalServerError, successCode)
case <-time.After(5 * time.Second): // Wait for a connection with timeout
healthStatus := HealthCheck{
@@ -645,12 +610,6 @@ func removeSourceHandler(w http.ResponseWriter, r *http.Request) {
// singleURLstatusHandler handles the status requests
func singleURLstatusHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
@@ -662,7 +621,7 @@ func singleURLstatusHandler(w http.ResponseWriter, r *http.Request) {
return
}
- results, err := performGetURLStatus(query, getQTypeFromName(r.Method))
+ results, err := performGetURLStatus(query, getQTypeFromName(r.Method), &dbHandler)
handleErrorAndRespond(w, err, results, "Error performing status: %v", http.StatusInternalServerError, successCode)
case <-time.After(5 * time.Second): // Wait for a connection with timeout
healthStatus := HealthCheck{
@@ -674,19 +633,13 @@ func singleURLstatusHandler(w http.ResponseWriter, r *http.Request) {
// allURLstatusHandler handles the status requests for all sources
func allURLstatusHandler(w http.ResponseWriter, r *http.Request) {
- if !limiter.Allow() {
- cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed)
- http.Error(w, errTooManyRequests, http.StatusTooManyRequests)
- return
- }
-
select {
case dbSemaphore <- struct{}{}:
defer func() { <-dbSemaphore }()
successCode := http.StatusOK
- results, err := performGetAllURLStatus(getQTypeFromName(r.Method))
+ results, err := performGetAllURLStatus(getQTypeFromName(r.Method), &dbHandler)
handleErrorAndRespond(w, err, results, "Error performing status: %v", http.StatusInternalServerError, successCode)
case <-time.After(5 * time.Second): // Wait for a connection with timeout
healthStatus := HealthCheck{
diff --git a/services/api/search_engine.go b/services/api/search_engine.go
index 1cb49e5..7c09303 100644
--- a/services/api/search_engine.go
+++ b/services/api/search_engine.go
@@ -405,21 +405,7 @@ func isLogicalOperator(op string) bool {
return op == "AND" || op == "OR"
}
-func performSearch(query string) (SearchResult, error) {
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return SearchResult{}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err)
- return SearchResult{}, err
- }
- defer db.Close()
-
+func performSearch(query string, db *cdb.Handler) (SearchResult, error) {
cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query)
// Prepare the query body
@@ -471,7 +457,7 @@ func performSearch(query string) (SearchResult, error) {
start := time.Now()
// Execute the query
- rows, err := db.ExecuteQuery(sqlQuery, sqlParams...)
+ rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...)
if err != nil {
return SearchResult{}, err
}
@@ -513,20 +499,8 @@ func performSearch(query string) (SearchResult, error) {
return results, nil
}
-func performScreenshotSearch(query string, qType int) (ScreenshotResponse, error) {
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return ScreenshotResponse{}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err)
- return ScreenshotResponse{}, err
- }
- defer db.Close()
+func performScreenshotSearch(query string, qType int, db *cdb.Handler) (ScreenshotResponse, error) {
+ var err error
cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query)
@@ -558,7 +532,7 @@ func performScreenshotSearch(query string, qType int) (ScreenshotResponse, error
start := time.Now()
// Execute the query
- rows, err := db.ExecuteQuery(sqlQuery, sqlParams...)
+ rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...)
if err != nil {
return ScreenshotResponse{}, err
}
@@ -694,21 +668,8 @@ func parseScreenshotQuery(input string) (SearchQuery, error) {
return SearchQuery{sqlQuery, sqlParams, 10, 0, Details{}}, nil
}
-func performWebObjectSearch(query string, qType int) (WebObjectResponse, error) {
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return WebObjectResponse{}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err)
- return WebObjectResponse{}, err
- }
- defer db.Close()
-
+func performWebObjectSearch(query string, qType int, db *cdb.Handler) (WebObjectResponse, error) {
+ var err error
cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query)
// Parse the user input
@@ -739,7 +700,7 @@ func performWebObjectSearch(query string, qType int) (WebObjectResponse, error)
start := time.Now()
// Execute the query
- rows, err := db.ExecuteQuery(sqlQuery, sqlParams...)
+ rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...)
if err != nil {
return WebObjectResponse{}, err
}
@@ -879,21 +840,8 @@ func parseWebObjectQuery(input string) (SearchQuery, error) {
return SearchQuery{sqlQuery, sqlParams, 10, 0, Details{}}, nil
}
-func performCorrelatedSitesSearch(query string, qType int) (CorrelatedSitesResponse, error) {
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return CorrelatedSitesResponse{}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err)
- return CorrelatedSitesResponse{}, err
- }
- defer db.Close()
-
+func performCorrelatedSitesSearch(query string, qType int, db *cdb.Handler) (CorrelatedSitesResponse, error) {
+ var err error
cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query)
// Parse the user input
@@ -924,7 +872,7 @@ func performCorrelatedSitesSearch(query string, qType int) (CorrelatedSitesRespo
start := time.Now()
// Execute the query
- rows, err := db.ExecuteQuery(sqlQuery, sqlParams...)
+ rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...)
if err != nil {
return CorrelatedSitesResponse{}, err
}
@@ -1101,21 +1049,8 @@ func parseCorrelatedSitesQuery(input string) (SearchQuery, error) {
}
// performNetInfoSearch performs a search for network information.
-func performNetInfoSearch(query string, qType int) (NetInfoResponse, error) {
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return NetInfoResponse{}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err)
- return NetInfoResponse{}, err
- }
- defer db.Close()
-
+func performNetInfoSearch(query string, qType int, db *cdb.Handler) (NetInfoResponse, error) {
+ var err error
cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query)
// Parse the user input
@@ -1146,7 +1081,7 @@ func performNetInfoSearch(query string, qType int) (NetInfoResponse, error) {
start := time.Now()
// Execute the query
- rows, err := db.ExecuteQuery(sqlQuery, sqlParams...)
+ rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...)
if err != nil {
return NetInfoResponse{}, err
}
@@ -1280,21 +1215,8 @@ func parseNetInfoQuery(input string) (SearchQuery, error) {
}
// performNetInfoSearch performs a search for network information.
-func performHTTPInfoSearch(query string, qType int) (HTTPInfoResponse, error) {
- // Initialize the database handler
- db, err := cdb.NewHandler(config)
- if err != nil {
- return HTTPInfoResponse{}, err
- }
-
- // Connect to the database
- err = db.Connect(config)
- if err != nil {
- cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err)
- return HTTPInfoResponse{}, err
- }
- defer db.Close()
-
+func performHTTPInfoSearch(query string, qType int, db *cdb.Handler) (HTTPInfoResponse, error) {
+ var err error
cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query)
// Parse the user input
@@ -1325,7 +1247,7 @@ func performHTTPInfoSearch(query string, qType int) (HTTPInfoResponse, error) {
start := time.Now()
// Execute the query
- rows, err := db.ExecuteQuery(sqlQuery, sqlParams...)
+ rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...)
if err != nil {
return HTTPInfoResponse{}, err
}