Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Let range search accept distance square as radius (#576)
Browse files Browse the repository at this point in the history
Signed-off-by: yudong.cai <[email protected]>

Signed-off-by: yudong.cai <[email protected]>
  • Loading branch information
cydrain authored Dec 2, 2022
1 parent 998786c commit a8abbf7
Show file tree
Hide file tree
Showing 18 changed files with 57 additions and 80 deletions.
2 changes: 0 additions & 2 deletions knowhere/archive/BruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ BruteForce::RangeSearch(const DatasetPtr base_dataset,
auto faiss_metric_type = GetFaissMetricType(metric_type);
switch (faiss_metric_type) {
case faiss::METRIC_L2:
low_bound *= low_bound;
high_bound *= high_bound;
faiss::range_search_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, high_bound, &res, bitset);
break;
case faiss::METRIC_INNER_PRODUCT:
Expand Down
3 changes: 1 addition & 2 deletions knowhere/index/vector_index/IndexBinaryIDMAP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ BinaryIDMAP::QueryByRangeImpl(int64_t n,

faiss::RangeSearchResult res(n);
index->range_search(n, data, high_bound, &res, bitset);
GetRangeSearchResult(res, (index->metric_type == faiss::METRIC_INNER_PRODUCT), n, low_bound, high_bound,
distances, labels, lims, bitset);
GetRangeSearchResult(res, false, n, low_bound, high_bound, distances, labels, lims, bitset);
}

} // namespace knowhere
3 changes: 1 addition & 2 deletions knowhere/index/vector_index/IndexBinaryIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,7 @@ BinaryIVF::QueryByRangeImpl(int64_t n,

faiss::RangeSearchResult res(n);
index_->range_search(n, data, high_bound, &res, bitset);
GetRangeSearchResult(res, (ivf_index->metric_type == faiss::METRIC_INNER_PRODUCT), n, low_bound, high_bound,
distances, labels, lims, bitset);
GetRangeSearchResult(res, false, n, low_bound, high_bound, distances, labels, lims, bitset);
}

} // namespace knowhere
12 changes: 4 additions & 8 deletions knowhere/index/vector_index/IndexDiskANN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,8 @@ IndexDiskANN<T>::QueryByRange(const DatasetPtr& dataset_ptr, const Config& confi
auto query_conf = DiskANNQueryByRangeConfig::Get(config);
auto low_bound = query_conf.radius_low_bound;
auto high_bound = query_conf.radius_high_bound;
bool is_L2 = (pq_flash_index_->get_metric() == diskann::Metric::L2);
if (is_L2) {
low_bound *= low_bound;
high_bound *= high_bound;
}
float radius = (is_L2 ? high_bound : low_bound);
bool is_ip = (pq_flash_index_->get_metric() == diskann::Metric::INNER_PRODUCT);
float radius = (is_ip ? low_bound : high_bound);

GET_TENSOR_DATA_DIM(dataset_ptr);
auto query = static_cast<const T*>(p_data);
Expand All @@ -457,7 +453,7 @@ IndexDiskANN<T>::QueryByRange(const DatasetPtr& dataset_ptr, const Config& confi
query_conf.search_list_and_k_ratio, bitset);

// filter range search result
FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], !is_L2, low_bound,
FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, low_bound,
high_bound);
}));
}
Expand All @@ -470,7 +466,7 @@ IndexDiskANN<T>::QueryByRange(const DatasetPtr& dataset_ptr, const Config& confi
int64_t* p_id = nullptr;
float* p_dist = nullptr;

GetRangeSearchResult(result_dist_array, result_id_array, !is_L2, rows, low_bound, high_bound, p_dist, p_id, p_lims);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, rows, low_bound, high_bound, p_dist, p_id, p_lims);

return GenResultDataset(p_id, p_dist, p_lims);
}
Expand Down
14 changes: 5 additions & 9 deletions knowhere/index/vector_index/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,8 @@ IndexHNSW::QueryByRangeImpl(int64_t n, const float* xq, float*& distances, int64

float low_bound = GetMetaRadiusLowBound(config);
float high_bound = GetMetaRadiusHighBound(config);
bool is_L2 = (index_->metric_type_ == 0); // L2: 0, InnerProduct: 1
if (is_L2) {
low_bound *= low_bound;
high_bound *= high_bound;
}
float radius = (is_L2 ? high_bound : 1.0f - low_bound);
bool is_ip = (index_->metric_type_ == 1); // L2: 0, InnerProduct: 1
float radius = (is_ip ? 1.0f - low_bound : high_bound);

std::vector<std::vector<int64_t>> result_id_array(n);
std::vector<std::vector<float>> result_dist_array(n);
Expand All @@ -309,13 +305,13 @@ IndexHNSW::QueryByRangeImpl(int64_t n, const float* xq, float*& distances, int64
result_id_array[index].resize(elem_cnt);
for (size_t j = 0; j < elem_cnt; j++) {
auto& p = rst[j];
result_dist_array[index][j] = (is_L2 ? p.first : (1 - p.first));
result_dist_array[index][j] = (is_ip ? (1 - p.first) : p.first);
result_id_array[index][j] = p.second;
}
result_size[index] = rst.size();

// filter range search result
FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], !is_L2, low_bound,
FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, low_bound,
high_bound);
}));
}
Expand All @@ -324,7 +320,7 @@ IndexHNSW::QueryByRangeImpl(int64_t n, const float* xq, float*& distances, int64
future.get();
}

GetRangeSearchResult(result_dist_array, result_id_array, !is_L2, n, low_bound, high_bound, distances, labels, lims);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, n, low_bound, high_bound, distances, labels, lims);
}

void
Expand Down
10 changes: 3 additions & 7 deletions knowhere/index/vector_index/IndexIDMAP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,12 @@ IDMAP::QueryByRangeImpl(int64_t n,
auto idmap_index = dynamic_cast<faiss::IndexFlat*>(index_.get());
float low_bound = GetMetaRadiusLowBound(config);
float high_bound = GetMetaRadiusHighBound(config);
bool is_L2 = (idmap_index->metric_type == faiss::METRIC_L2);
if (is_L2) {
low_bound *= low_bound;
high_bound *= high_bound;
}
float radius = (is_L2 ? high_bound : low_bound);
bool is_ip = (idmap_index->metric_type == faiss::METRIC_INNER_PRODUCT);
float radius = (is_ip ? low_bound : high_bound);

faiss::RangeSearchResult res(n);
idmap_index->range_search(n, reinterpret_cast<const float*>(data), radius, &res, bitset);
GetRangeSearchResult(res, !is_L2, n, low_bound, high_bound, distances, labels, lims, bitset);
GetRangeSearchResult(res, is_ip, n, low_bound, high_bound, distances, labels, lims, bitset);
}

} // namespace knowhere
10 changes: 3 additions & 7 deletions knowhere/index/vector_index/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,16 +367,12 @@ IVF::QueryByRangeImpl(int64_t n,

float low_bound = GetMetaRadiusLowBound(config);
float high_bound = GetMetaRadiusHighBound(config);
bool is_L2 = (ivf_index->metric_type == faiss::METRIC_L2);
if (is_L2) {
low_bound *= low_bound;
high_bound *= high_bound;
}
float radius = (is_L2 ? high_bound : low_bound);
bool is_ip = (ivf_index->metric_type == faiss::METRIC_INNER_PRODUCT);
float radius = (is_ip ? low_bound : high_bound);

faiss::RangeSearchResult res(n);
ivf_index->range_search_thread_safe(n, xq, radius, &res, params->nprobe, parallel_mode, max_codes, bitset);
GetRangeSearchResult(res, !is_L2, n, low_bound, high_bound, distances, labels, lims, bitset);
GetRangeSearchResult(res, is_ip, n, low_bound, high_bound, distances, labels, lims, bitset);
}

void
Expand Down
10 changes: 3 additions & 7 deletions knowhere/index/vector_offset_index/IndexIVF_NM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,17 +426,13 @@ IVF_NM::QueryByRangeImpl(int64_t n,

float low_bound = GetMetaRadiusLowBound(config);
float high_bound = GetMetaRadiusHighBound(config);
bool is_L2 = (ivf_index->metric_type == faiss::METRIC_L2);
if (is_L2) {
low_bound *= low_bound;
high_bound *= high_bound;
}
float radius = (is_L2 ? high_bound : low_bound);
bool is_ip = (ivf_index->metric_type == faiss::METRIC_INNER_PRODUCT);
float radius = (is_ip ? low_bound : high_bound);

faiss::RangeSearchResult res(n);
ivf_index->range_search_without_codes_thread_safe(n, xq, radius, &res, params->nprobe, parallel_mode, max_codes,
bitset);
GetRangeSearchResult(res, !is_L2, n, low_bound, high_bound, distances, labels, lims, bitset);
GetRangeSearchResult(res, is_ip, n, low_bound, high_bound, distances, labels, lims, bitset);
}

void
Expand Down
4 changes: 1 addition & 3 deletions unittest/benchmark/benchmark_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class Benchmark_base {
const size_t* lims,
int32_t nq) {
const float FLOAT_DIFF = 0.00001;
const bool is_L2 = (metric_type == "L2");
for (int32_t i = 0; i < nq; i++) {
std::unordered_set<int32_t> gt_ids_set(gt_ids_ + gt_lims_[i], gt_ids_ + gt_lims_[i + 1]);
std::unordered_map<int32_t, float> gt_map;
Expand All @@ -66,8 +65,7 @@ class Benchmark_base {
}
for (auto j = lims[i]; j < lims[i + 1]; j++) {
if (gt_ids_set.count(ids[j]) > 0) {
float dist = (is_L2) ? std::sqrt(distances[j]) : distances[j];
ASSERT_LT(std::abs(dist - gt_map[ids[j]]), FLOAT_DIFF);
ASSERT_LT(std::abs(distances[j] - gt_map[ids[j]]), FLOAT_DIFF);
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions unittest/benchmark/hdf5/benchmark_knowhere_float_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,14 @@ class Benchmark_knowhere_float_range : public Benchmark_knowhere, public ::testi
#if 0
TEST_F(Benchmark_knowhere_float_range, TEST_CREATE_HDF5) {
// set this radius to get about 1M result dataset for 10k nq
const float radius = 186.0;
const float low_bound = 0.0;
const float high_bound = 186.0 * 186.0;

std::vector<int64_t> golden_labels;
std::vector<float> golden_distances;
std::vector<size_t> golden_lims;
RunFloatRangeSearchBF<CMin<float>>(golden_labels, golden_distances, golden_lims, metric_type_,
(const float*)xb_, nb_, (const float*)xq_, nq_, dim_, radius, nullptr);
RunFloatRangeSearchBF(golden_labels, golden_distances, golden_lims, metric_type_,
(const float*)xb_, nb_, (const float*)xq_, nq_, dim_, low_bound, high_bound, nullptr);

// convert golden_lims and golden_ids to int32
std::vector<int32_t> golden_lims_int(nq_ + 1);
Expand All @@ -179,7 +180,7 @@ TEST_F(Benchmark_knowhere_float_range, TEST_CREATE_HDF5) {

assert(dim_ == 128);
assert(nq_ == 10000);
hdf5_write_range<false>("sift-128-euclidean-range.hdf5", dim_, xb_, nb_, xq_, nq_, radius,
hdf5_write_range<false>("sift-128-euclidean-range.hdf5", dim_, xb_, nb_, xq_, nq_, high_bound,
golden_lims_int.data(), golden_ids_int.data(), golden_distances.data());
}
#endif
Expand Down
13 changes: 8 additions & 5 deletions unittest/benchmark/hdf5/benchmark_knowhere_float_range_multi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,25 @@ class Benchmark_knowhere_float_range_multi : public Benchmark_knowhere, public :
// Following these steps:
// 1. set_ann_test_name, eg. "sift-128-euclidean" or "glove-200-angular"
// 2. use parse_ann_test_name() and load_hdf5_data<false>()
// 3. comment SetMetaRadius()
// 4. use the last element in the gt_dist_ as its radius for each nq
// 5. specify the hdf5 file name to generate
// 6. run this testcase
// 3. use the last element in the gt_dist_ as its radius for each nq
// 4. specify the hdf5 file name to generate
// 5. run this testcase
#if 0
TEST_F(Benchmark_knowhere_float_range_multi, TEST_CREATE_HDF5_WITH_MULTI_RADIUS) {
std::vector<float> golden_radius(nq_);
for (int32_t i = 0; i < nq_; i++) {
golden_radius[i] = gt_dist_[(i + 1) * gt_k_ - 1] + 0.01;
golden_radius[i] = std::pow(gt_dist_[(i + 1) * gt_k_ - 1], 2.0) + 0.01;
}

std::vector<int32_t> golden_lims(nq_ + 1);
for (int32_t i = 0; i <= nq_; i++) {
golden_lims[i] = i * gt_k_;
}

for (int32_t i = 0; i < nq_ * gt_k_; i++) {
gt_dist_[i] = std::pow(gt_dist_[i], 2.0);
}

assert(dim_ == 128);
assert(nq_ == 10000);
hdf5_write_range<false>("sift-128-euclidean-range-multi.hdf5", dim_, xb_, nb_, xq_, nq_,
Expand Down
2 changes: 1 addition & 1 deletion unittest/range_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ inline float float_vec_dist(
const size_t dim) {
assert(metric_type == knowhere::metric::L2 || metric_type == knowhere::metric::IP);
if (metric_type == knowhere::metric::L2) {
return std::sqrt(faiss::fvec_L2sqr_ref(pa, pb, dim));
return faiss::fvec_L2sqr_ref(pa, pb, dim);
} else {
return faiss::fvec_inner_product_ref(pa, pb, dim);
}
Expand Down
8 changes: 4 additions & 4 deletions unittest/test_bruteforce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,17 @@ TEST_P(BruteForceTest, float_range_search_l2) {
xb.data(), nb, xq.data(), nq, dim, low_bound, high_bound, bitset);

auto result = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, config, bitset);
CheckRangeSearchResult(result, metric_type, nq, low_bound * low_bound, high_bound * high_bound,
CheckRangeSearchResult(result, metric_type, nq, low_bound, high_bound,
golden_labels.data(), golden_lims.data(), true, bitset);
};

auto old_blas_threshold = knowhere::KnowhereConfig::GetBlasThreshold();
for (int64_t blas_threshold : {0, 20}) {
knowhere::KnowhereConfig::SetBlasThreshold(blas_threshold);
for (std::pair<float, float> range: {
std::make_pair<float, float>(0, 4.1f),
std::make_pair<float, float>(4.1f, 4.2f),
std::make_pair<float, float>(4.2f, 4.3f)}) {
std::make_pair<float, float>(0, 16.81f),
std::make_pair<float, float>(16.81f, 17.64f),
std::make_pair<float, float>(17.64f, 18.49f)}) {
test_range_search_l2(range.first, range.second, nullptr);
test_range_search_l2(range.first, range.second, *bitset);
}
Expand Down
5 changes: 2 additions & 3 deletions unittest/test_diskann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ constexpr float kMax = 100;
constexpr uint32_t kK = 10;
constexpr uint32_t kBigK = kNumRows * 2;
constexpr float kL2RadiusLowBound = 0;
constexpr float kL2RadiusHighBound = 550;
constexpr float kL2RadiusHighBound = 300000;
constexpr float kIPRadiusLowBound = 50000;
constexpr float kIPRadiusHighBound = std::numeric_limits<float>::max();
constexpr float kDisLossTolerance = 0.5;
Expand All @@ -62,7 +62,7 @@ constexpr uint32_t kLargeDimNumQueries = 10;
constexpr uint32_t kLargeDim = 5600;
constexpr uint32_t kLargeDimBigK = kLargeDimNumRows * 2;
constexpr float kLargeDimL2RadiusLowBound = 0;
constexpr float kLargeDimL2RadiusHighBound = 6000;
constexpr float kLargeDimL2RadiusHighBound = 36000000;
constexpr float kLargeDimIPRadiusLowBound = 400000;
constexpr float kLargeDimIPRadiusHighBound = std::numeric_limits<float>::max();

Expand Down Expand Up @@ -174,7 +174,6 @@ GenRangeSearchGrounTruth(const float* data_p, const float* query_p, const std::s
for (uint32_t dim = 0; dim < num_dims; ++dim) { // for every dim
dis += std::pow(xb[dim] - xq[dim], 2);
}
dis = std::sqrt(dis);
}
if (knowhere::distance_in_range(dis, radius_low_bound, radius_high_bound, is_ip)) {
ground_truth->at(query_index).emplace_back(row);
Expand Down
8 changes: 4 additions & 4 deletions unittest/test_hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ TEST_P(HNSWTest, hnsw_range_search_l2) {
ASSERT_TRUE(adapter->CheckRangeSearch(conf_, index_type_, index_mode_));

auto result = index_->QueryByRange(qd, conf_, bitset);
CheckRangeSearchResult(result, metric_type, nq, low_bound * low_bound, high_bound * high_bound,
CheckRangeSearchResult(result, metric_type, nq, low_bound, high_bound,
golden_labels.data(), golden_lims.data(), false, bitset);
};

for (std::pair<float, float> range: {
std::make_pair<float, float>(0, 4.1f),
std::make_pair<float, float>(4.1f, 4.2f),
std::make_pair<float, float>(4.2f, 4.3f)}) {
std::make_pair<float, float>(0, 16.81f),
std::make_pair<float, float>(16.81f, 17.64f),
std::make_pair<float, float>(17.64f, 18.49f)}) {
knowhere::SetMetaRadiusLowBound(conf_, range.first);
knowhere::SetMetaRadiusHighBound(conf_, range.second);
test_range_search_l2(range.first, range.second, nullptr);
Expand Down
8 changes: 4 additions & 4 deletions unittest/test_idmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,17 @@ TEST_P(IDMAPTest, idmap_range_search_l2) {
ASSERT_TRUE(adapter->CheckRangeSearch(conf_, index_type_, index_mode_));

auto result = index_->QueryByRange(qd, conf_, bitset);
CheckRangeSearchResult(result, metric_type, nq, low_bound * low_bound, high_bound * high_bound,
CheckRangeSearchResult(result, metric_type, nq, low_bound, high_bound,
golden_labels.data(), golden_lims.data(), true, bitset);
};

auto old_blas_threshold = knowhere::KnowhereConfig::GetBlasThreshold();
for (int64_t blas_threshold : {0, 20}) {
knowhere::KnowhereConfig::SetBlasThreshold(blas_threshold);
for (std::pair<float, float> range: {
std::make_pair<float, float>(0, 4.1f),
std::make_pair<float, float>(4.1f, 4.2f),
std::make_pair<float, float>(4.2f, 4.3f)}) {
std::make_pair<float, float>(0, 16.81f),
std::make_pair<float, float>(16.81f, 17.64f),
std::make_pair<float, float>(17.64f, 18.49f)}) {
knowhere::SetMetaRadiusLowBound(conf_, range.first);
knowhere::SetMetaRadiusHighBound(conf_, range.second);
test_range_search_l2(range.first, range.second, nullptr);
Expand Down
8 changes: 4 additions & 4 deletions unittest/test_ivf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,14 @@ TEST_P(IVFTest, ivf_range_search_l2) {
ASSERT_TRUE(adapter->CheckRangeSearch(conf_, index_type_, index_mode_));

auto result = index_->QueryByRange(qd, conf_, bitset);
CheckRangeSearchResult(result, metric_type, nq, low_bound * low_bound, high_bound * high_bound,
CheckRangeSearchResult(result, metric_type, nq, low_bound, high_bound,
golden_labels.data(), golden_lims.data(), false, bitset);
};

for (std::pair<float, float> range: {
std::make_pair<float, float>(0, 4.1f),
std::make_pair<float, float>(4.1f, 4.2f),
std::make_pair<float, float>(4.2f, 4.3f)}) {
std::make_pair<float, float>(0, 16.81f),
std::make_pair<float, float>(16.81f, 17.64f),
std::make_pair<float, float>(17.64f, 18.49f)}) {
knowhere::SetMetaRadiusLowBound(conf_, range.first);
knowhere::SetMetaRadiusHighBound(conf_, range.second);
test_range_search_l2(range.first, range.second, nullptr);
Expand Down
8 changes: 4 additions & 4 deletions unittest/test_ivf_nm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@ TEST_P(IVFNMTest, ivfnm_range_search_l2) {
ASSERT_TRUE(adapter->CheckRangeSearch(conf_, index_type_, index_mode_));

auto result = index_->QueryByRange(qd, conf_, bitset);
CheckRangeSearchResult(result, metric_type, nq, low_bound * low_bound, high_bound * high_bound,
CheckRangeSearchResult(result, metric_type, nq, low_bound, high_bound,
golden_labels.data(), golden_lims.data(), false, bitset);
};

for (std::pair<float, float> range: {
std::make_pair<float, float>(0, 4.1f),
std::make_pair<float, float>(4.1f, 4.2f),
std::make_pair<float, float>(4.2f, 4.3f)}) {
std::make_pair<float, float>(0, 16.81f),
std::make_pair<float, float>(16.81f, 17.64f),
std::make_pair<float, float>(17.64f, 18.49f)}) {
knowhere::SetMetaRadiusLowBound(conf_, range.first);
knowhere::SetMetaRadiusHighBound(conf_, range.second);
test_range_search_l2(range.first, range.second, nullptr);
Expand Down

0 comments on commit a8abbf7

Please sign in to comment.