Skip to content

Commit

Permalink
fix cardinal ci (#999)
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy authored Dec 20, 2024
1 parent c24fd21 commit 9e514f3
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 79 deletions.
12 changes: 7 additions & 5 deletions include/knowhere/index/index_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,13 @@ class IndexFactory {
// Please review carefully and select with caution

// register vector index supporting ALL_TYPE(binary, bf16, fp16, fp32, sparse_float32) data types
#define KNOWHERE_SIMPLE_REGISTER_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__);
#define KNOWHERE_SIMPLE_REGISTER_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::SPARSE_FLOAT32), \
##__VA_ARGS__);

// register vector index supporting sparse_float32
#define KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(name, index_node, features, ...) \
Expand Down
2 changes: 2 additions & 0 deletions tests/ut/test_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ checkBuildConfig(knowhere::IndexType indexType, knowhere::Json& json) {
json, msg) == knowhere::Status::success);
CHECK(msg.empty());
}
#ifndef KNOWHERE_WITH_CARDINAL
if (knowhere::IndexFactory::Instance().FeatureCheck(indexType, knowhere::feature::INT8)) {
CHECK(knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(
indexType, knowhere::Version::GetCurrentVersion().VersionNumber(), json, msg) ==
knowhere::Status::success);
CHECK(msg.empty());
}
#endif
}

TEST_CASE("Test config json parse", "[config]") {
Expand Down
161 changes: 90 additions & 71 deletions tests/ut/test_faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/knowhere_config.h"
#include "knowhere/dataset.h"
#include "knowhere/index/index_factory.h"
#include "utils.h"

namespace {
Expand Down Expand Up @@ -243,6 +244,12 @@ get_index_name(const std::string& ann_test_name, const std::string& index_type,
//
const std::string ann_test_name_ = "faiss_hnsw";

bool
index_support_int8(const knowhere::Json& conf) {
const std::string index_type = conf[knowhere::meta::INDEX_TYPE].get<std::string>();
return knowhere::IndexFactory::Instance().FeatureCheck(index_type, knowhere::feature::INT8);
}

//
template <typename T>
void
Expand Down Expand Up @@ -537,14 +544,16 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {

test_hnsw<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);
if (index_support_int8(conf)) {
// int8 candidate
printf(
"\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, %d%% points filtered "
"out\n",
DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100));

// int8 candidate
printf(
"\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, %d%% points filtered out\n",
DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100));

test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);
test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);
}
}
}
}
Expand Down Expand Up @@ -635,17 +644,18 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {

test_hnsw<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);

// int8 candidate
printf(
"\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, %d%% points filtered "
"out\n",
sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb,
int(bitset_rate * 100));

test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);

if (index_support_int8(conf)) {
// int8 candidate
printf(
"\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, %d%% points "
"filtered "
"out\n",
sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb,
int(bitset_rate * 100));

test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params,
conf, bitset_view);
}
// test refines for FP32
{
const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type];
Expand Down Expand Up @@ -818,17 +828,17 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {

test_hnsw<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);
if (index_support_int8(conf)) {
// test int8 candidate
printf(
"\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points "
"filtered out\n",
pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb,
int(bitset_rate * 100));

// test int8 candidate
printf(
"\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points "
"filtered out\n",
pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb,
int(bitset_rate * 100));

test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);

test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params,
conf, bitset_view);
}
// test refines for fp32
for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size();
allowed_ref_idx++) {
Expand Down Expand Up @@ -995,16 +1005,17 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {
test_hnsw<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);

// test int8 candidate
printf(
"\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points "
"filtered out\n",
prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb,
int(bitset_rate * 100));

test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf,
bitset_view);
if (index_support_int8(conf)) {
// test int8 candidate
printf(
"\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points "
"filtered out\n",
prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb,
int(bitset_rate * 100));

test_hnsw<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params,
conf, bitset_view);
}
// test fp32 refines
for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size();
allowed_ref_idx++) {
Expand Down Expand Up @@ -1272,16 +1283,17 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra

test_hnsw_range<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(), params,
conf, bitset_view);
if (index_support_int8(conf)) {
// int8 candidate
printf(
"\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter,
int(bitset_rate * 100));

// int8 candidate
printf(
"\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter,
int(bitset_rate * 100));

test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(), params,
conf, bitset_view);
test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);
}
}
}
}
Expand Down Expand Up @@ -1387,15 +1399,17 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra
test_hnsw_range<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);

// int8 candidate
printf(
"\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius,
range_filter, int(bitset_rate * 100));
if (index_support_int8(conf)) {
// int8 candidate
printf(
"\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius,
range_filter, int(bitset_rate * 100));

test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);
test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);
}

// test refines for FP32
{
Expand Down Expand Up @@ -1587,15 +1601,17 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra
test_hnsw_range<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);

// test int8 candidate
printf(
"\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius,
range_filter, int(bitset_rate * 100));
if (index_support_int8(conf)) {
// test int8 candidate
printf(
"\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius,
range_filter, int(bitset_rate * 100));

test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);
test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);
}

// test refines for fp32
for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size();
Expand Down Expand Up @@ -1781,15 +1797,18 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra
test_hnsw_range<knowhere::bf16>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);

// test int8 candidate
printf(
"\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb,
radius, range_filter, int(bitset_rate * 100));

test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);
if (index_support_int8(conf)) {
// test int8 candidate
printf(
"\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, "
"radius=%f, "
"range_filter=%f, %d%% points filtered out\n",
prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim,
nb, radius, range_filter, int(bitset_rate * 100));

test_hnsw_range<knowhere::int8>(default_ds_ptr, query_ds_ptr, golden_result.value(),
params, conf, bitset_view);
}

// test fp32 refines
for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size();
Expand Down
11 changes: 8 additions & 3 deletions tests/ut/test_index_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") {
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::BF16));
#ifndef KNOWHERE_WITH_CARDINAL
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::INT8));
#endif

#ifdef KNOWHERE_WITH_CARDINAL
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::BINARY));
Expand All @@ -352,18 +354,21 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") {
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::BF16));
#ifndef KNOWHERE_WITH_CARDINAL
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::INT8));

#endif
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::BF16));
#ifndef KNOWHERE_WITH_CARDINAL
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::INT8));

#endif
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::BF16));
#ifndef KNOWHERE_WITH_CARDINAL
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::INT8));

#endif
// Sparse Indexes
REQUIRE_FALSE(
IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_SPARSE_INVERTED_INDEX, knowhere::feature::FLOAT32));
Expand Down

0 comments on commit 9e514f3

Please sign in to comment.