Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wand optimize cursor class #968

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 73 additions & 67 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,7 @@ class InvertedIndex : public BaseInvertedIndex<T> {
vals.push_back(fabs(data[i][j].val));
}
}
auto pos = vals.begin() + static_cast<size_t>(drop_ratio_build * vals.size());
// pos may be vals.end() if drop_ratio_build is 1.0, in that case we use
// the largest value as the threshold.
if (pos == vals.end()) {
pos--;
}
std::nth_element(vals.begin(), pos, vals.end());

value_threshold_ = *pos;
value_threshold_ = get_threshold(vals, drop_ratio_build);
drop_during_build_ = true;
return Status::success;
}
Expand Down Expand Up @@ -413,9 +405,7 @@ class InvertedIndex : public BaseInvertedIndex<T> {
for (size_t i = 0; i < query.size(); ++i) {
values[i] = std::abs(query[i].val);
}
auto pos = values.begin() + static_cast<size_t>(drop_ratio_search * values.size());
std::nth_element(values.begin(), pos, values.end());
auto q_threshold = *pos;
auto q_threshold = get_threshold(values, drop_ratio_search);

// if no data was dropped during both build and search, no refinement is
// needed.
Expand Down Expand Up @@ -447,9 +437,7 @@ class InvertedIndex : public BaseInvertedIndex<T> {
for (size_t i = 0; i < query.size(); ++i) {
values[i] = std::abs(query[i].val);
}
auto pos = values.begin() + static_cast<size_t>(drop_ratio_search * values.size());
std::nth_element(values.begin(), pos, values.end());
auto q_threshold = *pos;
auto q_threshold = get_threshold(values, drop_ratio_search);
auto distances = compute_all_distances(query, q_threshold, computer);
for (size_t i = 0; i < distances.size(); ++i) {
if (bitset.empty() || !bitset.test(i)) {
Expand Down Expand Up @@ -512,6 +500,22 @@ class InvertedIndex : public BaseInvertedIndex<T> {
}

private:
// Given a vector of values, returns the threshold value.
// All values strictly smaller than the threshold will be ignored, so a
// returned threshold of 0 drops nothing when the caller passes absolute
// values.
//
// values will be modified in this function (std::nth_element partially
// reorders it).
//
// drop_ratio is expected to be in [0, 1). If it is 1.0, drop_count would
// equal values.size() and the selected position would be values.end(), which
// must not be dereferenced. In that case the largest value is used as the
// threshold.
inline T
get_threshold(std::vector<T>& values, float drop_ratio) const {
    auto drop_count = static_cast<size_t>(drop_ratio * values.size());
    if (drop_count == 0) {
        // Nothing to drop; this also covers an empty values vector.
        return 0;
    }
    if (drop_count >= values.size()) {
        // Clamp to the last valid position so that *pos is the maximum.
        drop_count = values.size() - 1;
    }
    auto pos = values.begin() + drop_count;
    std::nth_element(values.begin(), pos, values.end());
    return *pos;
}

size_t
n_rows_internal() const {
return raw_data_.size();
Expand Down Expand Up @@ -561,74 +565,77 @@ class InvertedIndex : public BaseInvertedIndex<T> {

// LUT supports size() and operator[] which returns an SparseIdVal.
template <typename LUT>
class Cursor {
struct Cursor {
public:
Cursor(const LUT& lut, size_t num_vec, float max_score, float q_value, const BitsetView bitset)
: lut_(lut), num_vec_(num_vec), max_score_(max_score), q_value_(q_value), bitset_(bitset) {
while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) {
: lut_(lut),
lut_size_(lut.size()),
total_num_vec_(num_vec),
max_score_(max_score),
q_value_(q_value),
bitset_(bitset) {
while (loc_ < lut_size_ && !bitset_.empty() && bitset_.test(lut_[loc_].id)) {
loc_++;
}
update_cur_vec_id();
}
Cursor(const Cursor& rhs) = delete;

void
next() {
loc_++;
while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) {
loc_++;
}
next_internal();
update_cur_vec_id();
}
// advance loc until cur_vec_id() >= vec_id

// advance loc until cur_vec_id_ >= vec_id
void
seek(table_t vec_id) {
while (loc_ < lut_.size() && cur_vec_id() < vec_id) {
next();
while (loc_ < lut_size_ && lut_[loc_].id < vec_id) {
next_internal();
}
update_cur_vec_id();
}
[[nodiscard]] table_t
cur_vec_id() const {
if (is_end()) {
return num_vec_;
}
return lut_[loc_].id;
}

T
cur_vec_val() const {
return lut_[loc_].val;
}
[[nodiscard]] bool
is_end() const {
return loc_ >= size();
}
[[nodiscard]] float
q_value() const {
return q_value_;
}
[[nodiscard]] size_t
size() const {
return lut_.size();
}
[[nodiscard]] float
max_score() const {
return max_score_;
}

private:
const LUT& lut_;
const size_t lut_size_;
size_t loc_ = 0;
size_t num_vec_ = 0;
size_t total_num_vec_ = 0;
float max_score_ = 0.0f;
float q_value_ = 0.0f;
const BitsetView bitset_;
}; // class Cursor
table_t cur_vec_id_ = 0;

private:
inline void
update_cur_vec_id() {
if (loc_ >= lut_size_) {
cur_vec_id_ = total_num_vec_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just double-checking that this is not total_num_vec_ - 1. According to standard naming conventions, id_ is assumed to be something within the [0, num) range.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is not total_num_vec_ - 1. When id_ is in the [0, num) range, it is still in range and the cursor is pointing at some doc; loc_ >= lut_size_ means it is out of range and the cursor points at no doc.

} else {
cur_vec_id_ = lut_[loc_].id;
}
}

inline void
next_internal() {
loc_++;
while (loc_ < lut_size_ && !bitset_.empty() && bitset_.test(lut_[loc_].id)) {
loc_++;
}
}
}; // struct Cursor

// any value in q_vec that is smaller than q_threshold will be ignored.
void
search_wand(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset,
const DocValueComputer<T>& computer) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<const typename decltype(inverted_lut_)::value_type&>>> cursors(q_dim);
auto valid_q_dim = 0;
size_t valid_q_dim = 0;
for (size_t i = 0; i < q_dim; ++i) {
auto [idx, val] = q_vec[i];
auto dim_id = dim_map_.find(idx);
Expand All @@ -644,48 +651,47 @@ class InvertedIndex : public BaseInvertedIndex<T> {
}
cursors.resize(valid_q_dim);
auto sort_cursors = [&cursors] {
std::sort(cursors.begin(), cursors.end(),
[](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); });
std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; });
};
sort_cursors();
auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; };
while (true) {
float threshold = heap.full() ? heap.top().val : 0;
float upper_bound = 0;
size_t pivot;
bool found_pivot = false;
for (pivot = 0; pivot < cursors.size(); ++pivot) {
if (cursors[pivot]->is_end()) {
for (pivot = 0; pivot < valid_q_dim; ++pivot) {
if (cursors[pivot]->loc_ >= cursors[pivot]->lut_size_) {
break;
}
upper_bound += cursors[pivot]->max_score();
if (score_above_threshold(upper_bound)) {
upper_bound += cursors[pivot]->max_score_;
if (upper_bound > threshold) {
found_pivot = true;
break;
}
}
if (!found_pivot) {
break;
}
table_t pivot_id = cursors[pivot]->cur_vec_id();
if (pivot_id == cursors[0]->cur_vec_id()) {
table_t pivot_id = cursors[pivot]->cur_vec_id_;
if (pivot_id == cursors[0]->cur_vec_id_) {
float score = 0;
for (auto& cursor : cursors) {
if (cursor->cur_vec_id() != pivot_id) {
if (cursor->cur_vec_id_ != pivot_id) {
break;
}
T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id()) : 0;
score += cursor->q_value() * computer(cursor->cur_vec_val(), cur_vec_sum);
T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0;
score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum);
cursor->next();
}
heap.push(pivot_id, score);
sort_cursors();
} else {
size_t next_list = pivot;
for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) {
for (; cursors[next_list]->cur_vec_id_ == pivot_id; --next_list) {
}
cursors[next_list]->seek(pivot_id);
for (size_t i = next_list + 1; i < cursors.size(); ++i) {
if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) {
for (size_t i = next_list + 1; i < valid_q_dim; ++i) {
if (cursors[i]->cur_vec_id_ >= cursors[i - 1]->cur_vec_id_) {
break;
}
std::swap(cursors[i], cursors[i - 1]);
Expand Down