Skip to content

Commit

Permalink
Storages: Optimize vector search in scenarios with updates (#9597)
Browse files Browse the repository at this point in the history
close #9599, ref #9600

Improve 75% the performance of vector search in scenarios with updates.

Signed-off-by: Lloyd-Pottiger <[email protected]>
  • Loading branch information
Lloyd-Pottiger authored Nov 13, 2024
1 parent 15ed3cf commit 9dc8374
Show file tree
Hide file tree
Showing 29 changed files with 457 additions and 236 deletions.
1 change: 0 additions & 1 deletion dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#include <Common/Exception.h>
#include <DataStreams/PushingToViewsBlockOutputStream.h>
#include <DataStreams/SquashingBlockInputStream.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Storages/IStorage.h>
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/DataStreams/SquashingBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

namespace DB
{

SquashingBlockInputStream::SquashingBlockInputStream(
const BlockInputStreamPtr & src,
size_t min_block_size_rows,
Expand All @@ -28,7 +29,7 @@ SquashingBlockInputStream::SquashingBlockInputStream(
}


Block SquashingBlockInputStream::readImpl()
Block SquashingBlockInputStream::read()
{
if (all_read)
return {};
Expand Down
7 changes: 3 additions & 4 deletions dbms/src/DataStreams/SquashingBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

#pragma once

#include <DataStreams/IProfilingBlockInputStream.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataStreams/SquashingTransform.h>


namespace DB
{
/** Merging consecutive blocks of stream to specified minimum size.
*/
class SquashingBlockInputStream : public IProfilingBlockInputStream
class SquashingBlockInputStream : public IBlockInputStream
{
static constexpr auto NAME = "Squashing";

Expand All @@ -37,8 +37,7 @@ class SquashingBlockInputStream : public IProfilingBlockInputStream

Block getHeader() const override { return children.at(0)->getHeader(); }

protected:
Block readImpl() override;
Block read() override;

private:
const LoggerPtr log;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
namespace DB::DM
{

ColumnFileSetInputStreamPtr ColumnFileSetWithVectorIndexInputStream::tryBuild(
SkippableBlockInputStreamPtr ColumnFileSetWithVectorIndexInputStream::tryBuild(
const DMContext & context,
const ColumnFileSetSnapshotPtr & delta_snap,
const ColumnDefinesPtr & col_defs,
const RowKeyRange & segment_range_,
const IColumnFileDataProviderPtr & data_provider,
const RSOperatorPtr & rs_operator,
const ANNQueryInfoPtr & ann_query_info,
const BitmapFilterPtr & bitmap_filter,
size_t offset,
ReadTag read_tag_)
Expand All @@ -36,28 +36,7 @@ ColumnFileSetInputStreamPtr ColumnFileSetWithVectorIndexInputStream::tryBuild(
return std::make_shared<ColumnFileSetInputStream>(context, delta_snap, col_defs, segment_range_, read_tag_);
};

if (rs_operator == nullptr || bitmap_filter == nullptr)
return fallback();

auto filter_with_ann = std::dynamic_pointer_cast<WithANNQueryInfo>(rs_operator);
if (filter_with_ann == nullptr)
return fallback();

auto ann_query_info = filter_with_ann->ann_query_info;
if (!ann_query_info)
return fallback();

// Fast check: ANNQueryInfo is available in the whole read path. However we may not reading vector column now.
bool is_matching_ann_query = false;
for (const auto & cd : *col_defs)
{
if (cd.id == ann_query_info->column_id())
{
is_matching_ann_query = true;
break;
}
}
if (!is_matching_ann_query)
if (!bitmap_filter || !ann_query_info)
return fallback();

std::optional<ColumnDefine> vec_cd;
Expand Down Expand Up @@ -140,16 +119,8 @@ Block ColumnFileSetWithVectorIndexInputStream::readImpl(FilterPtr & res_filter)
if (tiny_readers[current_file_index] != nullptr)
{
const auto file_rows = column_files[current_file_index]->getRows();
auto selected_row_begin = std::lower_bound(
selected_rows.cbegin(),
selected_rows.cend(),
read_rows,
[](const auto & row, UInt32 offset) { return row.key < offset; });
auto selected_row_end = std::lower_bound(
selected_row_begin,
selected_rows.cend(),
read_rows + file_rows,
[](const auto & row, UInt32 offset) { return row.key < offset; });
auto selected_row_begin = std::lower_bound(sorted_results.cbegin(), sorted_results.cend(), read_rows);
auto selected_row_end = std::lower_bound(selected_row_begin, sorted_results.cend(), read_rows + file_rows);
size_t selected_rows = std::distance(selected_row_begin, selected_row_end);
// If all rows are filtered out, skip this file.
if (selected_rows == 0)
Expand Down Expand Up @@ -184,7 +155,7 @@ Block ColumnFileSetWithVectorIndexInputStream::readImpl(FilterPtr & res_filter)
{
filter.clear();
filter.resize_fill(file_rows, 0);
for (const auto & [rowid, _] : file_selected_rows)
for (const auto rowid : file_selected_rows)
filter[rowid - read_rows] = 1;
res_filter = &filter;
}
Expand All @@ -211,13 +182,14 @@ Block ColumnFileSetWithVectorIndexInputStream::readImpl(FilterPtr & res_filter)
return {};
}

void ColumnFileSetWithVectorIndexInputStream::load()
std::vector<VectorIndexViewer::SearchResult> ColumnFileSetWithVectorIndexInputStream::load()
{
if (loaded)
return;
return {};

tiny_readers.reserve(column_files.size());
UInt32 precedes_rows = 0;
std::vector<VectorIndexViewer::SearchResult> search_results;
for (const auto & column_file : column_files)
{
if (auto * tiny_file = column_file->tryToTinyFile();
Expand All @@ -233,7 +205,7 @@ void ColumnFileSetWithVectorIndexInputStream::load()
auto sr = tiny_reader->load();
for (auto & row : sr)
row.key += precedes_rows;
selected_rows.insert(selected_rows.end(), sr.begin(), sr.end());
search_results.insert(search_results.end(), sr.begin(), sr.end());
tiny_readers.push_back(tiny_reader);
// avoid virutal function call
precedes_rows += tiny_file->getRows();
Expand All @@ -245,18 +217,26 @@ void ColumnFileSetWithVectorIndexInputStream::load()
}
}
// Keep the top k minimum distances rows.
auto select_size = selected_rows.size() > ann_query_info->top_k() ? ann_query_info->top_k() : selected_rows.size();
auto top_k_end = selected_rows.begin() + select_size;
std::nth_element(selected_rows.begin(), top_k_end, selected_rows.end(), [](const auto & lhs, const auto & rhs) {
auto select_size
= search_results.size() > ann_query_info->top_k() ? ann_query_info->top_k() : search_results.size();
auto top_k_end = search_results.begin() + select_size;
std::nth_element(search_results.begin(), top_k_end, search_results.end(), [](const auto & lhs, const auto & rhs) {
return lhs.distance < rhs.distance;
});
selected_rows.resize(select_size);
search_results.resize(select_size);
// Sort by key again.
std::sort(selected_rows.begin(), selected_rows.end(), [](const auto & lhs, const auto & rhs) {
std::sort(search_results.begin(), search_results.end(), [](const auto & lhs, const auto & rhs) {
return lhs.key < rhs.key;
});

loaded = true;
return search_results;
}

void ColumnFileSetWithVectorIndexInputStream::setSelectedRows(const std::span<const UInt32> & selected_rows)
{
sorted_results.reserve(selected_rows.size());
std::copy(selected_rows.begin(), selected_rows.end(), std::back_inserter(sorted_results));
}

} // namespace DB::DM
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@
#include <Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.h>
#include <Storages/DeltaMerge/DMContext.h>
#include <Storages/DeltaMerge/Filter/RSOperator.h>
#include <Storages/DeltaMerge/VectorIndexBlockInputStream.h>


namespace DB::DM
{

class ColumnFileSetWithVectorIndexInputStream : public ColumnFileSetInputStream
class ColumnFileSetWithVectorIndexInputStream : public VectorIndexBlockInputStream
{
private:
ColumnFileSetReader reader;

std::vector<ColumnFileReaderPtr>::iterator cur_column_file_reader;
size_t read_rows = 0;

const IColumnFileDataProviderPtr data_provider;
const ANNQueryInfoPtr ann_query_info;
const BitmapFilterView valid_rows;
Expand All @@ -37,7 +43,7 @@ class ColumnFileSetWithVectorIndexInputStream : public ColumnFileSetInputStream
const ColumnDefinesPtr rest_col_defs;

// Set after load(). Top K search results in files with vector index.
std::vector<VectorIndexViewer::SearchResult> selected_rows;
std::vector<VectorIndexViewer::Key> sorted_results;
std::vector<ColumnFileTinyVectorIndexReaderPtr> tiny_readers;

ColumnFiles & column_files;
Expand All @@ -59,43 +65,54 @@ class ColumnFileSetWithVectorIndexInputStream : public ColumnFileSetInputStream
ColumnDefine && vec_cd_,
const ColumnDefinesPtr & rest_col_defs_,
ReadTag read_tag_)
: ColumnFileSetInputStream(context_, delta_snap_, col_defs_, segment_range_, read_tag_)
: reader(context_, delta_snap_, col_defs_, segment_range_, read_tag_)
, data_provider(data_provider_)
, ann_query_info(ann_query_info_)
, valid_rows(std::move(valid_rows_))
, vec_index_cache(context_.global_context.getVectorIndexCache())
, vec_cd(std::move(vec_cd_))
, rest_col_defs(rest_col_defs_)
, column_files(reader.snapshot->getColumnFiles())
, header(getHeader())
{}
, header(toEmptyBlock(*(reader.col_defs)))
{
cur_column_file_reader = reader.column_file_readers.begin();
}

static ColumnFileSetInputStreamPtr tryBuild(
static SkippableBlockInputStreamPtr tryBuild(
const DMContext & context,
const ColumnFileSetSnapshotPtr & delta_snap,
const ColumnDefinesPtr & col_defs,
const RowKeyRange & segment_range_,
const IColumnFileDataProviderPtr & data_provider,
const RSOperatorPtr & rs_operator,
const ANNQueryInfoPtr & ann_query_info,
const BitmapFilterPtr & bitmap_filter,
size_t offset,
ReadTag read_tag_);

String getName() const override { return "ColumnFileSetWithVectorIndex"; }
Block getHeader() const override { return header; }

Block read() override
{
FilterPtr filter = nullptr;
return read(filter, false);
}

// When all rows in block are not filtered out,
// `res_filter` will be set to null.
// The caller needs to do handle this situation.
Block read(FilterPtr & res_filter, bool return_filter) override;

std::vector<VectorIndexViewer::SearchResult> load() override;

void setSelectedRows(const std::span<const UInt32> & selected_rows) override;

private:
Block readImpl(FilterPtr & res_filter);

Block readOtherColumns();

void toNextFile(size_t current_file_index, size_t current_file_rows);

void load();
};

} // namespace DB::DM
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace DB::DM

void ColumnFileTinyVectorIndexReader::read(
MutableColumnPtr & vec_column,
const std::span<const VectorIndexViewer::SearchResult> & read_rowids,
const std::span<const VectorIndexViewer::Key> & read_rowids,
size_t rowid_start_offset,
size_t read_rows)
{
Expand All @@ -36,9 +36,10 @@ void ColumnFileTinyVectorIndexReader::read(
vec_column->reserve(read_rows);
std::vector<Float32> value;
size_t current_rowid = rowid_start_offset;
for (const auto & [rowid, _] : read_rowids)
for (const auto & rowid : read_rowids)
{
vec_index->get(rowid, value);
// Each ColomnFileTiny has its own vector index, rowid_start_offset is the offset of the ColmnFilePersistSet.
vec_index->get(rowid - rowid_start_offset, value);
if (rowid > current_rowid)
{
UInt32 nulls = rowid - current_rowid;
Expand Down Expand Up @@ -135,7 +136,7 @@ std::vector<VectorIndexViewer::SearchResult> ColumnFileTinyVectorIndexReader::lo
auto perf_begin = PerfContext::vector_search;
RUNTIME_CHECK(valid_rows.size() == tiny_file.getRows(), valid_rows.size(), tiny_file.getRows());

auto search_results = vec_index->searchWithDistance(ann_query_info, valid_rows);
auto search_results = vec_index->search(ann_query_info, valid_rows);
// Sort by key
std::sort(search_results.begin(), search_results.end(), [](const auto & lhs, const auto & rhs) {
return lhs.key < rhs.key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ColumnFileTinyVectorIndexReader
// others will be filled with default values.
void read(
MutableColumnPtr & vec_column,
const std::span<const VectorIndexViewer::SearchResult> & read_rowids,
const std::span<const VectorIndexViewer::Key> & read_rowids,
size_t rowid_start_offset,
size_t read_rows);

Expand Down
Loading

0 comments on commit 9dc8374

Please sign in to comment.