Skip to content

Commit

Permalink
Add large topk tests (#160)
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg authored Oct 23, 2023
1 parent f4c1757 commit 2f3b6e6
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {

const int64_t nb = 1000, nq = 10;
const int64_t dim = 128;
const int64_t topk = 5;

auto metric = GENERATE(as<std::string>{}, knowhere::metric::L2, knowhere::metric::COSINE);
auto topk = GENERATE(as<int64_t>{}, 5, 120);
auto version = GenTestVersionList();

auto base_gen = [=]() {
Expand Down Expand Up @@ -89,7 +89,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
knowhere::Json json = base_gen();
json[knowhere::indexparam::HNSW_M] = 128;
json[knowhere::indexparam::EFCONSTRUCTION] = 200;
json[knowhere::indexparam::EF] = 64;
json[knowhere::indexparam::EF] = 200;
return json;
};

Expand Down Expand Up @@ -270,9 +270,9 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") {

const int64_t nb = 1000, nq = 10;
const int64_t dim = 1024;
const int64_t topk = 5;

auto metric = GENERATE(as<std::string>{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD);
auto topk = GENERATE(as<int64_t>{}, 5, 120);
auto version = GenTestVersionList();
auto base_gen = [=]() {
knowhere::Json json;
Expand All @@ -288,15 +288,15 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") {
auto ivfflat_gen = [base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 16;
json[knowhere::indexparam::NPROBE] = 8;
json[knowhere::indexparam::NPROBE] = 14;
return json;
};

auto hnsw_gen = [base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::HNSW_M] = 128;
json[knowhere::indexparam::EFCONSTRUCTION] = 200;
json[knowhere::indexparam::EF] = 64;
json[knowhere::indexparam::EF] = 200;
return json;
};

Expand Down Expand Up @@ -377,11 +377,11 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") {
using Catch::Approx;

const int64_t nb = 1000, nq = 10;
const int64_t topk = 5;

auto dim = GENERATE(as<int64_t>{}, 8, 16, 32, 64, 128, 256, 512, 160);
auto version = GenTestVersionList();
auto metric = GENERATE(as<std::string>{}, knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE);
auto topk = GENERATE(as<int64_t>{}, 5, 100);

auto base_gen = [=]() {
knowhere::Json json;
Expand Down Expand Up @@ -441,11 +441,17 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") {
auto code_size = dim / 8;
for (int64_t i = 0; i < nq; i++) {
const uint8_t* query_vector = (const uint8_t*)query_ds->GetTensor() + i * code_size;
std::vector<int64_t> ids_v(ids + i * topk, ids + (i + 1) * topk);
auto ds = GenIdsDataSet(topk, ids_v);
// filter out -1 when the result num less than topk
int64_t real_topk = 0;
for (; real_topk < topk; real_topk++) {
if (ids[i * topk + real_topk] < 0)
break;
}
std::vector<int64_t> ids_v(ids + i * topk, ids + i * topk + real_topk);
auto ds = GenIdsDataSet(real_topk, ids_v);
auto gv_res = idx.GetVectorByIds(*ds);
REQUIRE(gv_res.has_value());
for (int64_t j = 0; j < topk; j++) {
for (int64_t j = 0; j < real_topk; j++) {
const uint8_t* res_vector = (const uint8_t*)gv_res.value()->GetTensor() + j * code_size;
if (metric == knowhere::metric::SUPERSTRUCTURE) {
REQUIRE(faiss::is_subset(res_vector, query_vector, code_size));
Expand Down

0 comments on commit 2f3b6e6

Please sign in to comment.