From f2b7d47f354f0a189c7b10c445a634a033ee5d36 Mon Sep 17 00:00:00 2001 From: crozzy Date: Mon, 9 Dec 2024 14:08:33 -0800 Subject: [PATCH] datastore: fix get query caching This change replaces caching the DB pgx.Rows, which is an iterator, with the actual potential vulnerabilities. Previously, subsequent records manifesting in the same query were given a pgx.Rows object that had been exhausted. Signed-off-by: crozzy --- datastore/postgres/get.go | 130 +++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 57 deletions(-) diff --git a/datastore/postgres/get.go b/datastore/postgres/get.go index 2a1c9b313..5f3b399af 100644 --- a/datastore/postgres/get.go +++ b/datastore/postgres/get.go @@ -51,7 +51,7 @@ func (s *MatcherStore) Get(ctx context.Context, records []*claircore.IndexRecord defer tx.Rollback(ctx) // start a batch batch := &pgx.Batch{} - resCache := map[string]pgx.Rows{} + resCache := map[string][]*claircore.Vulnerability{} rqs := []*recordQuery{} for _, record := range records { query, err := buildGetQuery(record, &opts) @@ -72,7 +72,6 @@ func (s *MatcherStore) Get(ctx context.Context, records []*claircore.IndexRecord resCache[query] = nil } // send the batch - start := time.Now() res := tx.SendBatch(ctx, batch) // Can't just defer the close, because the batch must be fully handled @@ -83,72 +82,89 @@ func (s *MatcherStore) Get(ctx context.Context, records []*claircore.IndexRecord results := make(map[string][]*claircore.Vulnerability) vulnSet := make(map[string]map[string]struct{}) for _, rq := range rqs { - rows, ok := resCache[rq.query] + rid := rq.record.Package.ID + vulns, ok := resCache[rq.query] if !ok { return nil, fmt.Errorf("unexpected vulnerability query: %s", rq.query) } - if rows == nil { - rows, err = res.Query() - if err != nil { - res.Close() - return nil, err + if vulns != nil { // We already have results we don't need to go back to the DB. + if _, ok := vulnSet[rid]; !ok { + vulnSet[rid] = make(map[string]struct{}) } - resCache[rq.query] = rows - } - - // unpack all returned rows into claircore.Vulnerability structs - for rows.Next() { - // fully allocate vuln struct - v := &claircore.Vulnerability{ - Package: &claircore.Package{}, - Dist: &claircore.Distribution{}, - Repo: &claircore.Repository{}, + for _, v := range vulns { + if _, ok := vulnSet[rid][v.ID]; !ok { + vulnSet[rid][v.ID] = struct{}{} + results[rid] = append(results[rid], v) + } } - - var id int64 - err := rows.Scan( - &id, - &v.Name, - &v.Description, - &v.Issued, - &v.Links, - &v.Severity, - &v.NormalizedSeverity, - &v.Package.Name, - &v.Package.Version, - &v.Package.Module, - &v.Package.Arch, - &v.Package.Kind, - &v.Dist.DID, - &v.Dist.Name, - &v.Dist.Version, - &v.Dist.VersionCodeName, - &v.Dist.VersionID, - &v.Dist.Arch, - &v.Dist.CPE, - &v.Dist.PrettyName, - &v.ArchOperation, - &v.Repo.Name, - &v.Repo.Key, - &v.Repo.URI, - &v.FixedInVersion, - &v.Updater, - ) - v.ID = strconv.FormatInt(id, 10) + continue + } + results[rid] = []*claircore.Vulnerability{} + err := func() error { + rows, err := res.Query() if err != nil { res.Close() - return nil, fmt.Errorf("failed to scan vulnerability: %v", err) + return fmt.Errorf("error getting rows: %w", err) } + defer rows.Close() + // unpack all returned rows into claircore.Vulnerability structs + for rows.Next() { + // fully allocate vuln struct + v := &claircore.Vulnerability{ + Package: &claircore.Package{}, + Dist: &claircore.Distribution{}, + Repo: &claircore.Repository{}, + } - rid := rq.record.Package.ID - if _, ok := vulnSet[rid]; !ok { - vulnSet[rid] = make(map[string]struct{}) - } - if _, ok := vulnSet[rid][v.ID]; !ok { - vulnSet[rid][v.ID] = struct{}{} - results[rid] = append(results[rid], v) + var id int64 + err := rows.Scan( + &id, + &v.Name, + &v.Description, + &v.Issued, + &v.Links, + &v.Severity, + &v.NormalizedSeverity, + &v.Package.Name, + &v.Package.Version, + &v.Package.Module, + &v.Package.Arch, + &v.Package.Kind, + &v.Dist.DID, + &v.Dist.Name, + &v.Dist.Version, + &v.Dist.VersionCodeName, + &v.Dist.VersionID, + &v.Dist.Arch, + &v.Dist.CPE, + &v.Dist.PrettyName, + &v.ArchOperation, + &v.Repo.Name, + &v.Repo.Key, + &v.Repo.URI, + &v.FixedInVersion, + &v.Updater, + ) + v.ID = strconv.FormatInt(id, 10) + if err != nil { + res.Close() + return fmt.Errorf("failed to scan vulnerability: %w", err) + } + + if _, ok := vulnSet[rid]; !ok { + vulnSet[rid] = make(map[string]struct{}) + } + if _, ok := vulnSet[rid][v.ID]; !ok { + vulnSet[rid][v.ID] = struct{}{} + results[rid] = append(results[rid], v) + } } + return nil + }() + if err != nil { + return nil, err } + resCache[rq.query] = results[rid] } if err := res.Close(); err != nil { return nil, fmt.Errorf("some weird batch error: %v", err)