From d60d706b638fd1ca2b3c50d8e4608ffc9ce8ec4b Mon Sep 17 00:00:00 2001 From: Paolo Fabio Zaino Date: Fri, 27 Sep 2024 17:42:18 +0100 Subject: [PATCH] Completed the work to improve API performance under heavy load --- services/api/console.go | 72 +++------------------- services/api/main.go | 105 +++++++++---------------------- services/api/search_engine.go | 112 ++++++---------------------------- 3 files changed, 56 insertions(+), 233 deletions(-) diff --git a/services/api/console.go b/services/api/console.go index cfbc473..4d94a7d 100644 --- a/services/api/console.go +++ b/services/api/console.go @@ -32,7 +32,7 @@ const ( errFailedToCommitTransaction = "Failed to commit transaction" ) -func performAddSource(query string, qType int) (ConsoleResponse, error) { +func performAddSource(query string, qType int, db *cdb.Handler) (ConsoleResponse, error) { var sqlQuery string var sqlParams addSourceRequest if qType == getQuery { @@ -53,7 +53,7 @@ func performAddSource(query string, qType int) (ConsoleResponse, error) { } // Perform the addSource operation - results, err := addSource(sqlQuery, sqlParams) + results, err := addSource(sqlQuery, sqlParams, db) if err != nil { cmn.DebugMsg(cmn.DbgLvlError, "adding the source: %v", err) return results, err @@ -105,7 +105,7 @@ func extractAddSourceParams(query string, params *addSourceRequest) { } -func addSource(sqlQuery string, params addSourceRequest) (ConsoleResponse, error) { +func addSource(sqlQuery string, params addSourceRequest, db *cdb.Handler) (ConsoleResponse, error) { var results ConsoleResponse results.Message = "Failed to add the source" @@ -119,20 +119,6 @@ func addSource(sqlQuery string, params addSourceRequest) (ConsoleResponse, error } } - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return results, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - cmn.DebugMsg(cmn.DbgLvlError, "connecting to the database: %v", err) - return results, err - } - defer db.Close() - // Get the JSON string for the Config field configJSON, err := json.Marshal(params.Config) if err != nil { @@ -140,7 +126,7 @@ func addSource(sqlQuery string, params addSourceRequest) (ConsoleResponse, error } // Execute the SQL statement - _, err = db.Exec(sqlQuery, params.URL, params.Status, params.Restricted, params.Disabled, params.Flags, string(configJSON), params.CategoryID, params.UsrID) + _, err = (*db).Exec(sqlQuery, params.URL, params.Status, params.Restricted, params.Disabled, params.Flags, string(configJSON), params.CategoryID, params.UsrID) if err != nil { return results, err } @@ -180,7 +166,7 @@ func validateAndReformatConfig(config *cfg.SourceConfig) error { return nil } -func performRemoveSource(query string, qType int) (ConsoleResponse, error) { +func performRemoveSource(query string, qType int, db *cdb.Handler) (ConsoleResponse, error) { var results ConsoleResponse var sourceURL string // Assuming the source URL is passed. Adjust as necessary based on input. @@ -193,21 +179,8 @@ func performRemoveSource(query string, qType int) (ConsoleResponse, error) { return ConsoleResponse{Message: "Invalid request"}, nil } - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return ConsoleResponse{Message: errFailedToInitializeDBHandler}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - return ConsoleResponse{Message: errFailedToConnectToDB}, err - } - defer db.Close() - // Start a transaction - tx, err := db.Begin() + tx, err := (*db).Begin() if err != nil { return ConsoleResponse{Message: errFailedToStartTransaction}, err } @@ -269,7 +242,7 @@ func removeSource(tx *sql.Tx, sourceURL string) (ConsoleResponse, error) { return results, nil } -func performGetURLStatus(query string, qType int) (StatusResponse, error) { +func performGetURLStatus(query string, qType int, db *cdb.Handler) (StatusResponse, error) { var results StatusResponse var sourceURL string // Assuming the source URL is passed. Adjust as necessary based on input. @@ -282,21 +255,8 @@ func performGetURLStatus(query string, qType int) (StatusResponse, error) { return StatusResponse{Message: "Invalid request"}, nil } - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return StatusResponse{Message: errFailedToInitializeDBHandler}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - return StatusResponse{Message: errFailedToConnectToDB}, err - } - defer db.Close() - // Start a transaction - tx, err := db.Begin() + tx, err := (*db).Begin() if err != nil { return StatusResponse{Message: errFailedToStartTransaction}, err } @@ -362,23 +322,11 @@ func getURLStatus(tx *sql.Tx, sourceURL string) (StatusResponse, error) { return results, nil } -func performGetAllURLStatus(_ int) (StatusResponse, error) { +func performGetAllURLStatus(_ int, db *cdb.Handler) (StatusResponse, error) { // using _ instead of qType because for now we don't need it - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return StatusResponse{Message: errFailedToInitializeDBHandler}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - return StatusResponse{Message: errFailedToConnectToDB}, err - } - defer db.Close() // Start a transaction - tx, err := db.Begin() + tx, err := (*db).Begin() if err != nil { return StatusResponse{Message: errFailedToStartTransaction}, err } diff --git a/services/api/main.go b/services/api/main.go index 121a264..2fd38b0 100644 --- a/services/api/main.go +++ b/services/api/main.go @@ -31,6 +31,7 @@ import ( cmn "github.com/pzaino/thecrowler/pkg/common" cfg "github.com/pzaino/thecrowler/pkg/config" + cdb "github.com/pzaino/thecrowler/pkg/database" "golang.org/x/time/rate" ) @@ -46,6 +47,7 @@ var ( configMutex sync.Mutex configFile *string dbSemaphore chan struct{} // Semaphore for the database connection + dbHandler cdb.Handler ) func initAll(configFile *string, config *cfg.Config, lmt **rate.Limiter) error { @@ -91,6 +93,23 @@ func initAll(configFile *string, config *cfg.Config, lmt **rate.Limiter) error { // Set the database semaphore dbSemaphore = make(chan struct{}, config.Database.MaxConns-3) + // Initialize the database + var connected bool = false + dbHandler, err = cdb.NewHandler(*config) + if err != nil { + cmn.DebugMsg(cmn.DbgLvlError, "Error allocating database resources: %v", err) + } else { + for !connected { + err = dbHandler.Connect(*config) + if err != nil { + cmn.DebugMsg(cmn.DbgLvlError, "Error opening database connection: %v", err) + time.Sleep(5 * time.Second) + continue + } + connected = true + } + } + return nil } @@ -257,12 +276,6 @@ func SecurityHeadersMiddleware(next http.Handler) http.Handler { } func healthCheckHandler(w http.ResponseWriter, _ *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - // Create a JSON document with the health status healthStatus := HealthCheck{ Status: "OK", @@ -274,12 +287,6 @@ func healthCheckHandler(w http.ResponseWriter, _ *http.Request) { // searchHandler handles the traditional search requests func searchHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: // Try to acquire a DB connection defer func() { <-dbSemaphore }() // Release the connection after the work is done @@ -291,7 +298,7 @@ func searchHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performSearch(query) + results, err := performSearch(query, &dbHandler) results.SetHeaderFields( "customsearch#search", jsonResponse, @@ -321,12 +328,6 @@ func searchHandler(w http.ResponseWriter, r *http.Request) { } func webObjectHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -338,7 +339,7 @@ func webObjectHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performWebObjectSearch(query, getQTypeFromName(r.Method)) + results, err := performWebObjectSearch(query, getQTypeFromName(r.Method), &dbHandler) if results.IsEmpty() { var retCode int if config.API.Return404 { @@ -377,12 +378,6 @@ func webObjectHandler(w http.ResponseWriter, r *http.Request) { } func webCorrelatedSitesHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -394,7 +389,7 @@ func webCorrelatedSitesHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performCorrelatedSitesSearch(query, getQTypeFromName(r.Method)) + results, err := performCorrelatedSitesSearch(query, getQTypeFromName(r.Method), &dbHandler) if results.IsEmpty() { var retCode int if config.API.Return404 { @@ -434,12 +429,6 @@ func webCorrelatedSitesHandler(w http.ResponseWriter, r *http.Request) { // scrImgSrchHandler handles the search requests for screenshot images func scrImgSrchHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -451,7 +440,7 @@ func scrImgSrchHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performScreenshotSearch(query, getQTypeFromName(r.Method)) + results, err := performScreenshotSearch(query, getQTypeFromName(r.Method), &dbHandler) if results == (ScreenshotResponse{}) { var retCode int if config.API.Return404 { @@ -473,12 +462,6 @@ func scrImgSrchHandler(w http.ResponseWriter, r *http.Request) { // netInfoHandler handles the network information requests func netInfoHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -490,7 +473,7 @@ func netInfoHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performNetInfoSearch(query, getQTypeFromName(r.Method)) + results, err := performNetInfoSearch(query, getQTypeFromName(r.Method), &dbHandler) if results.isEmpty() { var retCode int if config.API.Return404 { @@ -530,12 +513,6 @@ func netInfoHandler(w http.ResponseWriter, r *http.Request) { // httpInfoHandler handles the http information requests func httpInfoHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -547,7 +524,7 @@ func httpInfoHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performHTTPInfoSearch(query, getQTypeFromName(r.Method)) + results, err := performHTTPInfoSearch(query, getQTypeFromName(r.Method), &dbHandler) if results.IsEmpty() { var retCode int if config.API.Return404 { @@ -587,12 +564,6 @@ func httpInfoHandler(w http.ResponseWriter, r *http.Request) { // addSourceHandler handles the addition of new sources func addSourceHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -604,7 +575,7 @@ func addSourceHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performAddSource(query, getQTypeFromName(r.Method)) + results, err := performAddSource(query, getQTypeFromName(r.Method), &dbHandler) handleErrorAndRespond(w, err, results, "Error performing addSource: %v", http.StatusInternalServerError, successCode) case <-time.After(5 * time.Second): // Wait for a connection with timeout healthStatus := HealthCheck{ @@ -616,12 +587,6 @@ func addSourceHandler(w http.ResponseWriter, r *http.Request) { // removeSourceHandler handles the removal of sources func removeSourceHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -633,7 +598,7 @@ func removeSourceHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performRemoveSource(query, getQTypeFromName(r.Method)) + results, err := performRemoveSource(query, getQTypeFromName(r.Method), &dbHandler) handleErrorAndRespond(w, err, results, "Error performing removeSource: %v", http.StatusInternalServerError, successCode) case <-time.After(5 * time.Second): // Wait for a connection with timeout healthStatus := HealthCheck{ @@ -645,12 +610,6 @@ func removeSourceHandler(w http.ResponseWriter, r *http.Request) { // singleURLstatusHandler handles the status requests func singleURLstatusHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() @@ -662,7 +621,7 @@ func singleURLstatusHandler(w http.ResponseWriter, r *http.Request) { return } - results, err := performGetURLStatus(query, getQTypeFromName(r.Method)) + results, err := performGetURLStatus(query, getQTypeFromName(r.Method), &dbHandler) handleErrorAndRespond(w, err, results, "Error performing status: %v", http.StatusInternalServerError, successCode) case <-time.After(5 * time.Second): // Wait for a connection with timeout healthStatus := HealthCheck{ @@ -674,19 +633,13 @@ func singleURLstatusHandler(w http.ResponseWriter, r *http.Request) { // allURLstatusHandler handles the status requests for all sources func allURLstatusHandler(w http.ResponseWriter, r *http.Request) { - if !limiter.Allow() { - cmn.DebugMsg(cmn.DbgLvlDebug, errRateLimitExceed) - http.Error(w, errTooManyRequests, http.StatusTooManyRequests) - return - } - select { case dbSemaphore <- struct{}{}: defer func() { <-dbSemaphore }() successCode := http.StatusOK - results, err := performGetAllURLStatus(getQTypeFromName(r.Method)) + results, err := performGetAllURLStatus(getQTypeFromName(r.Method), &dbHandler) handleErrorAndRespond(w, err, results, "Error performing status: %v", http.StatusInternalServerError, successCode) case <-time.After(5 * time.Second): // Wait for a connection with timeout healthStatus := HealthCheck{ diff --git a/services/api/search_engine.go b/services/api/search_engine.go index 1cb49e5..7c09303 100644 --- a/services/api/search_engine.go +++ b/services/api/search_engine.go @@ -405,21 +405,7 @@ func isLogicalOperator(op string) bool { return op == "AND" || op == "OR" } -func performSearch(query string) (SearchResult, error) { - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return SearchResult{}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err) - return SearchResult{}, err - } - defer db.Close() - +func performSearch(query string, db *cdb.Handler) (SearchResult, error) { cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query) // Prepare the query body @@ -471,7 +457,7 @@ func performSearch(query string) (SearchResult, error) { start := time.Now() // Execute the query - rows, err := db.ExecuteQuery(sqlQuery, sqlParams...) + rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...) if err != nil { return SearchResult{}, err } @@ -513,20 +499,8 @@ func performSearch(query string) (SearchResult, error) { return results, nil } -func performScreenshotSearch(query string, qType int) (ScreenshotResponse, error) { - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return ScreenshotResponse{}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err) - return ScreenshotResponse{}, err - } - defer db.Close() +func performScreenshotSearch(query string, qType int, db *cdb.Handler) (ScreenshotResponse, error) { + var err error cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query) @@ -558,7 +532,7 @@ func performScreenshotSearch(query string, qType int) (ScreenshotResponse, error start := time.Now() // Execute the query - rows, err := db.ExecuteQuery(sqlQuery, sqlParams...) + rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...) if err != nil { return ScreenshotResponse{}, err } @@ -694,21 +668,8 @@ func parseScreenshotQuery(input string) (SearchQuery, error) { return SearchQuery{sqlQuery, sqlParams, 10, 0, Details{}}, nil } -func performWebObjectSearch(query string, qType int) (WebObjectResponse, error) { - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return WebObjectResponse{}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err) - return WebObjectResponse{}, err - } - defer db.Close() - +func performWebObjectSearch(query string, qType int, db *cdb.Handler) (WebObjectResponse, error) { + var err error cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query) // Parse the user input @@ -739,7 +700,7 @@ func performWebObjectSearch(query string, qType int) (WebObjectResponse, error) start := time.Now() // Execute the query - rows, err := db.ExecuteQuery(sqlQuery, sqlParams...) + rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...) if err != nil { return WebObjectResponse{}, err } @@ -879,21 +840,8 @@ func parseWebObjectQuery(input string) (SearchQuery, error) { return SearchQuery{sqlQuery, sqlParams, 10, 0, Details{}}, nil } -func performCorrelatedSitesSearch(query string, qType int) (CorrelatedSitesResponse, error) { - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return CorrelatedSitesResponse{}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err) - return CorrelatedSitesResponse{}, err - } - defer db.Close() - +func performCorrelatedSitesSearch(query string, qType int, db *cdb.Handler) (CorrelatedSitesResponse, error) { + var err error cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query) // Parse the user input @@ -924,7 +872,7 @@ func performCorrelatedSitesSearch(query string, qType int) (CorrelatedSitesRespo start := time.Now() // Execute the query - rows, err := db.ExecuteQuery(sqlQuery, sqlParams...) + rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...) if err != nil { return CorrelatedSitesResponse{}, err } @@ -1101,21 +1049,8 @@ func parseCorrelatedSitesQuery(input string) (SearchQuery, error) { } // performNetInfoSearch performs a search for network information. -func performNetInfoSearch(query string, qType int) (NetInfoResponse, error) { - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return NetInfoResponse{}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err) - return NetInfoResponse{}, err - } - defer db.Close() - +func performNetInfoSearch(query string, qType int, db *cdb.Handler) (NetInfoResponse, error) { + var err error cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query) // Parse the user input @@ -1146,7 +1081,7 @@ func performNetInfoSearch(query string, qType int) (NetInfoResponse, error) { start := time.Now() // Execute the query - rows, err := db.ExecuteQuery(sqlQuery, sqlParams...) + rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...) if err != nil { return NetInfoResponse{}, err } @@ -1280,21 +1215,8 @@ func parseNetInfoQuery(input string) (SearchQuery, error) { } // performNetInfoSearch performs a search for network information. -func performHTTPInfoSearch(query string, qType int) (HTTPInfoResponse, error) { - // Initialize the database handler - db, err := cdb.NewHandler(config) - if err != nil { - return HTTPInfoResponse{}, err - } - - // Connect to the database - err = db.Connect(config) - if err != nil { - cmn.DebugMsg(cmn.DbgLvlError, dbConnErrorLabel, err) - return HTTPInfoResponse{}, err - } - defer db.Close() - +func performHTTPInfoSearch(query string, qType int, db *cdb.Handler) (HTTPInfoResponse, error) { + var err error cmn.DebugMsg(cmn.DbgLvlDebug, searchLabel, query) // Parse the user input @@ -1325,7 +1247,7 @@ func performHTTPInfoSearch(query string, qType int) (HTTPInfoResponse, error) { start := time.Now() // Execute the query - rows, err := db.ExecuteQuery(sqlQuery, sqlParams...) + rows, err := (*db).ExecuteQuery(sqlQuery, sqlParams...) if err != nil { return HTTPInfoResponse{}, err }