add bm25 metric to inverted index/wand (#688)
Signed-off-by: Buqian Zheng <[email protected]>
zhengbuqian authored Jul 10, 2024
1 parent 663a778 commit 756df75
Showing 10 changed files with 409 additions and 108 deletions.
6 changes: 6 additions & 0 deletions include/knowhere/comp/index_param.h
@@ -91,6 +91,11 @@ constexpr const char* TRACE_FLAGS = "trace_flags";
constexpr const char* MATERIALIZED_VIEW_SEARCH_INFO = "materialized_view_search_info";
constexpr const char* MATERIALIZED_VIEW_OPT_FIELDS_PATH = "opt_fields_path";
constexpr const char* MAX_EMPTY_RESULT_BUCKETS = "max_empty_result_buckets";
constexpr const char* BM25_K1 = "bm25_k1";
constexpr const char* BM25_B = "bm25_b";
// average document length
constexpr const char* BM25_AVGDL = "bm25_avgdl";
constexpr const char* WAND_BM25_MAX_SCORE_RATIO = "wand_bm25_max_score_ratio";
}; // namespace meta

namespace indexparam {
@@ -159,6 +164,7 @@ constexpr const char* JACCARD = "JACCARD";
constexpr const char* JACCARD = "JACCARD";
constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE";
constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE";
constexpr const char* BM25 = "BM25";
} // namespace metric

enum VecType {
37 changes: 36 additions & 1 deletion include/knowhere/config.h
@@ -646,13 +646,24 @@ class BaseConfig : public Config {
CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE materialized_view_search_info;
CFG_STRING opt_fields_path;
CFG_FLOAT iterator_refine_ratio;
/**
* k1, b, and avgdl are used only by the BM25 metric.
* - k1, b, avgdl must be provided at load time.
* - k1 and b can be overridden at search time for SPARSE_INVERTED_INDEX
* but not for SPARSE_WAND.
* - avgdl must always be provided at search time.
*/
CFG_FLOAT bm25_k1;
CFG_FLOAT bm25_b;
CFG_FLOAT bm25_avgdl;
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
.set_default("L2")
.description("metric type")
.for_train_and_search()
.for_iterator()
.for_deserialize();
.for_deserialize()
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(retrieve_friendly)
.description("whether the index holds raw data for fast retrieval")
.set_default(false)
@@ -739,6 +750,30 @@ class BaseConfig : public Config {
.description("refine ratio for iterator")
.for_iterator()
.for_range_search();
KNOWHERE_CONFIG_DECLARE_FIELD(bm25_k1)
.allow_empty_without_default()
.set_range(0.0, 3.0)
.description("BM25 k1 to tune the term frequency scaling factor")
.for_train_and_search()
.for_iterator()
.for_deserialize()
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(bm25_b)
.allow_empty_without_default()
.set_range(0.0, 1.0)
.description("BM25 beta to tune the document length scaling factor")
.for_train_and_search()
.for_iterator()
.for_deserialize()
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(bm25_avgdl)
.allow_empty_without_default()
.set_range(1, std::numeric_limits<CFG_FLOAT::value_type>::max())
.description("average document length")
.for_train_and_search()
.for_iterator()
.for_deserialize()
.for_deserialize_from_file();
}
};
} // namespace knowhere
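
The doc comment on bm25_k1 / bm25_b / bm25_avgdl above describes which parameters are needed at load time versus search time. A rough sketch of what that looks like in a config, using the key names added in index_param.h (the concrete values and the exact split shown here are illustrative only, not taken from this commit):

    // Load-time config: k1, b and avgdl must all be provided for BM25 (hypothetical values).
    knowhere::Json load_cfg;
    load_cfg["metric_type"] = "BM25";
    load_cfg["bm25_k1"] = 1.2f;
    load_cfg["bm25_b"] = 0.75f;
    load_cfg["bm25_avgdl"] = 100.0f;

    // Search-time config: avgdl is always required; k1 and b may be overridden
    // for SPARSE_INVERTED_INDEX but not for SPARSE_WAND.
    knowhere::Json search_cfg;
    search_cfg["metric_type"] = "BM25";
    search_cfg["bm25_avgdl"] = 100.0f;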
3 changes: 2 additions & 1 deletion include/knowhere/index/index_node.h
@@ -104,7 +104,8 @@ class IndexNode : public Object {
std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<float>> result_dist_array(nq);
const bool similarity_metric = IsMetricType(base_cfg.metric_type.value(), metric::IP) ||
IsMetricType(base_cfg.metric_type.value(), metric::COSINE);
IsMetricType(base_cfg.metric_type.value(), metric::COSINE) ||
IsMetricType(base_cfg.metric_type.value(), metric::BM25);
const bool has_range_filter = range_filter != defaultRangeFilter;
constexpr size_t k_min_num_consecutive_over_radius = 16;
const auto range_search_level = base_cfg.range_search_level.value();
32 changes: 30 additions & 2 deletions include/knowhere/sparse_utils.h
@@ -18,9 +18,11 @@
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <type_traits>
#include <vector>

#include "knowhere/expected.h"
#include "knowhere/object.h"
#include "knowhere/operands.h"

@@ -35,6 +37,28 @@ using label_t = int64_t;
template <typename T>
using SparseIdVal = IdVal<table_t, T>;

// DocValueComputer takes a value of a doc vector and returns a computed
// value that can be multiplied directly with the corresponding query value.
// The second parameter is the document length of the database vector, which
// is used in BM25.
template <typename T>
using DocValueComputer = std::function<float(const T&, const float)>;

template <typename T>
auto
GetDocValueOriginalComputer() {
static DocValueComputer<T> lambda = [](const T& right, const float) -> float { return right; };
return lambda;
}

template <typename T>
auto
GetDocValueBM25Computer(float k1, float b, float avgdl) {
return [k1, b, avgdl](const T& tf, const float doc_len) -> float {
return tf * (k1 + 1) / (tf + k1 * (1 - b + b * (doc_len / avgdl)));
};
}
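
// Illustrative use of the computers above (the numeric values are hypothetical,
// not part of this commit):
//   auto bm25 = GetDocValueBM25Computer<float>(/*k1=*/1.2f, /*b=*/0.75f, /*avgdl=*/100.0f);
//   float w = bm25(/*tf=*/3.0f, /*doc_len=*/120.0f);  // BM25-scaled doc-side term weight
//   auto ip = GetDocValueOriginalComputer<float>();
//   float w_ip = ip(3.0f, 0.0f);                       // plain IP: value returned unchanged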

template <typename T>
class SparseRow {
static_assert(std::is_same_v<T, fp32>, "SparseRow supports float only");
@@ -128,8 +152,12 @@ class SparseRow {
elem->value = value;
}

// In the case of asymmetric distance functions, this should be the query
// and the other should be the database vector. For example, with BM25, call
// query_vec.dot(doc_vec) instead of doc_vec.dot(query_vec).
template <typename Computer = DocValueComputer<T>>
float
dot(const SparseRow<T>& other) const {
dot(const SparseRow<T>& other, Computer computer = GetDocValueOriginalComputer<T>(), const T other_sum = 0) const {
float product_sum = 0.0f;
size_t i = 0;
size_t j = 0;
@@ -143,7 +171,7 @@
} else if (left->index > right->index) {
++j;
} else {
product_sum += left->value * right->value;
product_sum += left->value * computer(right->value, other_sum);
++i;
++j;
}
109 changes: 78 additions & 31 deletions src/common/comp/brute_force.cc
@@ -37,6 +37,30 @@ namespace knowhere {

class BruteForceConfig : public BaseConfig {};

namespace {

template <typename T>
expected<sparse::DocValueComputer<T>>
GetDocValueComputer(const BruteForceConfig& cfg) {
if (IsMetricType(cfg.metric_type.value(), metric::IP)) {
return sparse::GetDocValueOriginalComputer<T>();
} else if (IsMetricType(cfg.metric_type.value(), metric::BM25)) {
if (!cfg.bm25_k1.has_value() || !cfg.bm25_b.has_value() || !cfg.bm25_avgdl.has_value()) {
return expected<sparse::DocValueComputer<T>>::Err(
Status::invalid_args, "bm25_k1, bm25_b, bm25_avgdl must be set when searching for bm25 metric");
}
auto k1 = cfg.bm25_k1.value();
auto b = cfg.bm25_b.value();
auto avgdl = cfg.bm25_avgdl.value();
return sparse::GetDocValueBM25Computer<T>(k1, b, avgdl);
} else {
return expected<sparse::DocValueComputer<T>>::Err(Status::invalid_metric_type,
"metric type not supported for sparse vector");
}
}

} // namespace

template <typename DataType>
expected<DataSetPtr>
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
@@ -321,15 +345,24 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
#endif

std::string metric_str = cfg.metric_type.value();
auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
return expected<DataSetPtr>::Err(result.error(), result.what());
}
faiss::MetricType faiss_metric_type = result.value();
if (is_sparse && !IsMetricType(metric_str, metric::IP)) {
return expected<DataSetPtr>::Err(Status::invalid_metric_type,
"Invalid metric type for sparse float vector: " + metric_str);
const bool is_bm25 = IsMetricType(metric_str, metric::BM25);

faiss::MetricType faiss_metric_type;
sparse::DocValueComputer<float> sparse_computer;
if (!is_sparse) {
auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
return expected<DataSetPtr>::Err(result.error(), result.what());
}
faiss_metric_type = result.value();
} else {
auto computer_or = GetDocValueComputer<float>(cfg);
if (!computer_or.has_value()) {
return expected<DataSetPtr>::Err(computer_or.error(), computer_or.what());
}
sparse_computer = computer_or.value();
}

bool is_cosine = IsMetricType(metric_str, metric::COSINE);

auto radius = cfg.radius.value();
@@ -352,7 +385,14 @@
if (!bitset.empty() && bitset.test(j)) {
continue;
}
auto dist = cur_query->dot(xb_sparse[j]);
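// For BM25, the document length passed to the computer is the sum of the
// doc row's values (its term frequencies).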
float row_sum = 0;
if (is_bm25) {
for (size_t k = 0; k < xb_sparse[j].size(); ++k) {
auto [d, v] = xb_sparse[j][k];
row_sum += v;
}
}
auto dist = cur_query->dot(xb_sparse[j], sparse_computer, row_sum);
if (dist > radius && dist <= range_filter) {
result_id_array[index].push_back(j);
result_dist_array[index].push_back(dist);
@@ -425,7 +465,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
}

auto range_search_result =
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip || is_bm25, nq, radius, range_filter);
auto res = GenResultDataSet(nq, std::move(range_search_result));

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
@@ -469,15 +509,13 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
#endif

std::string metric_str = cfg.metric_type.value();
auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value();
return result.error();
}
if (!IsMetricType(metric_str, metric::IP)) {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value();
return Status::invalid_metric_type;
const bool is_bm25 = IsMetricType(metric_str, metric::BM25);

auto computer_or = GetDocValueComputer<float>(cfg);
if (!computer_or.has_value()) {
return computer_or.error();
}
auto computer = computer_or.value();

int topk = cfg.k.value();
std::fill(distances, distances + nq * topk, std::numeric_limits<float>::quiet_NaN());
@@ -500,7 +538,14 @@
if (!bitset.empty() && bitset.test(j)) {
continue;
}
float dist = row.dot(base[j]);
float row_sum = 0;
if (is_bm25) {
for (size_t k = 0; k < base[j].size(); ++k) {
auto [d, v] = base[j][k];
row_sum += v;
}
}
float dist = row.dot(base[j], computer, row_sum);
if (dist > 0) {
heap.push(j, dist);
}
@@ -673,18 +718,13 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
}
#endif

auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << metric_str;
return expected<std::vector<IndexNode::IteratorPtr>>::Err(
result.error(), "Failed to brute force search sparse for iterator: invalid metric type " + metric_str);
}
if (!IsMetricType(metric_str, metric::IP)) {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << metric_str;
return expected<std::vector<IndexNode::IteratorPtr>>::Err(
Status::invalid_metric_type,
"Failed to brute force search sparse for iterator: invalid metric type " + metric_str);
const bool is_bm25 = IsMetricType(metric_str, metric::BM25);

auto computer_or = GetDocValueComputer<float>(cfg);
if (!computer_or.has_value()) {
return expected<std::vector<IndexNode::IteratorPtr>>::Err(computer_or.error(), computer_or.what());
}
auto computer = computer_or.value();

auto pool = ThreadPool::GetGlobalSearchThreadPool();
auto vec = std::vector<IndexNode::IteratorPtr>(nq, nullptr);
@@ -699,7 +739,14 @@
if (!bitset.empty() && bitset.test(j)) {
continue;
}
auto dist = row.dot(base[j]);
float row_sum = 0;
if (is_bm25) {
for (size_t k = 0; k < base[j].size(); ++k) {
auto [d, v] = base[j][k];
row_sum += v;
}
}
auto dist = row.dot(base[j], computer, row_sum);
if (dist > 0) {
distances_ids.emplace_back(j, dist);
}