From 6dec4cfe98d2a7bbc84f0fa6f5d34b649464e6f2 Mon Sep 17 00:00:00 2001 From: chasingegg Date: Fri, 3 Jan 2025 12:04:00 +0800 Subject: [PATCH] Support MV only for HNSW Signed-off-by: chasingegg --- include/knowhere/bitsetview.h | 32 + include/knowhere/bitsetview_idselector.h | 17 + include/knowhere/comp/index_param.h | 1 + src/index/hnsw/faiss_hnsw.cc | 1279 ++++++++--- src/index/hnsw/impl/IndexBruteForceWrapper.cc | 37 +- .../hnsw/impl/IndexConditionalWrapper.cc | 8 +- src/index/hnsw/impl/IndexHNSWWrapper.cc | 81 +- tests/ut/test_faiss_hnsw.cc | 1901 ++++++++++------- tests/ut/test_index_check.cc | 14 +- tests/ut/utils.h | 46 + thirdparty/faiss/faiss/impl/index_read.cpp | 23 + thirdparty/faiss/faiss/impl/index_write.cpp | 14 + thirdparty/faiss/faiss/index_io.h | 9 + 13 files changed, 2306 insertions(+), 1156 deletions(-) diff --git a/include/knowhere/bitsetview.h b/include/knowhere/bitsetview.h index 464bf774b..6cf94b07e 100644 --- a/include/knowhere/bitsetview.h +++ b/include/knowhere/bitsetview.h @@ -95,6 +95,38 @@ class BitsetView { return ret; } + size_t + get_first_valid_index() const { + size_t ret = 0; + auto len_uint8 = byte_size(); + auto len_uint64 = len_uint8 >> 3; + + uint64_t* p_uint64 = (uint64_t*)bits_; + for (size_t i = 0; i < len_uint64; i++) { + uint64_t value = (~(*p_uint64)); + if (value == 0) { + p_uint64++; + continue; + } + ret = __builtin_ctzll(value); + return i * 64 + ret; + } + + // calculate remainder + uint8_t* p_uint8 = (uint8_t*)bits_ + (len_uint64 << 3); + for (size_t i = 0; i < len_uint8 - (len_uint64 << 3); i++) { + uint8_t value = (~(*p_uint8)); + if (value == 0) { + p_uint8++; + continue; + } + ret = __builtin_ctz(value); + return len_uint64 * 64 + i * 8 + ret; + } + + return num_bits_; + } + std::string to_string(size_t from, size_t to) const { if (empty()) { diff --git a/include/knowhere/bitsetview_idselector.h b/include/knowhere/bitsetview_idselector.h index 39f6ff1a8..49531e0e4 100644 --- a/include/knowhere/bitsetview_idselector.h +++ b/include/knowhere/bitsetview_idselector.h @@ -32,4 +32,21 @@ struct BitsetViewIDSelector final : faiss::IDSelector { } }; +struct BitsetViewWithMappingIDSelector final : faiss::IDSelector { + const BitsetView bitset_view; + const uint32_t* out_id_mapping; + const size_t id_offset; + + inline BitsetViewWithMappingIDSelector(BitsetView bitset_view, const uint32_t* out_id_mapping, + const size_t offset = 0) + : bitset_view{bitset_view}, out_id_mapping(out_id_mapping), id_offset(offset) { + } + + inline bool + is_member(faiss::idx_t id) const override final { + // it is by design that out_id_mapping == nullptr is not tested here + return (!bitset_view.test(out_id_mapping[id + id_offset])); + } +}; + } // namespace knowhere diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 364b66c11..51fa452dd 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -95,6 +95,7 @@ constexpr const char* JSON_ID_SET = "json_id_set"; constexpr const char* TRACE_ID = "trace_id"; constexpr const char* SPAN_ID = "span_id"; constexpr const char* TRACE_FLAGS = "trace_flags"; +constexpr const char* SCALAR_INFO = "scalar_info"; constexpr const char* MATERIALIZED_VIEW_SEARCH_INFO = "materialized_view_search_info"; constexpr const char* MATERIALIZED_VIEW_OPT_FIELDS_PATH = "opt_fields_path"; constexpr const char* MAX_EMPTY_RESULT_BUCKETS = "max_empty_result_buckets"; diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 2f933bfdd..ab8a2a8e1 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -30,6 +30,7 @@ #include "faiss/IndexHNSW.h" #include "faiss/IndexRefine.h" #include "faiss/impl/ScalarQuantizer.h" +#include "faiss/impl/mapped_io.h" #include "faiss/index_io.h" #include "index/hnsw/faiss_hnsw_config.h" #include "index/hnsw/hnsw.h" @@ -66,6 +67,11 @@ class BaseFaissIndexNode : public IndexNode { search_pool = ThreadPool::GetGlobalSearchThreadPool(); } + bool + IsAdditionalScalarSupported() const override { + return true; + } + // Status Train(const DataSetPtr dataset, std::shared_ptr cfg) override { @@ -159,21 +165,33 @@ is_faiss_fourcc_error(const char* what) { class BaseFaissRegularIndexNode : public BaseFaissIndexNode { public: BaseFaissRegularIndexNode(const int32_t& version, const Object& object) - : BaseFaissIndexNode(version, object), index{nullptr} { + : BaseFaissIndexNode(version, object), indexes(1, nullptr) { } Status Serialize(BinarySet& binset) const override { - if (index == nullptr) { + if (isIndexEmpty()) { return Status::empty_index; } try { MemoryIOWriter writer; - faiss::write_index(index.get(), &writer); + if (indexes.size() > 1) { + // this is a hack for compatibility, faiss index has 4-byte header to indicate index category + // create a new one to distinguish MV faiss hnsw from faiss hnsw + faiss::write_mv(&writer); + writeHeader(&writer); + for (const auto& index : indexes) { + faiss::write_index(index.get(), &writer); + } - std::shared_ptr data(writer.data()); - binset.Append(Type(), data, writer.tellg()); + std::shared_ptr data(writer.data()); + binset.Append(Type(), data, writer.tellg()); + } else { + faiss::write_index(indexes[0].get(), &writer); + std::shared_ptr data(writer.data()); + binset.Append(Type(), data, writer.tellg()); + } } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -192,8 +210,23 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { MemoryIOReader reader(binary->data.get(), binary->size); try { - auto read_index = std::unique_ptr(faiss::read_index(&reader)); - index.reset(read_index.release()); + // this is a hack for compatibility, faiss index has 4-byte header to indicate index category + // create a new one to distinguish MV faiss hnsw from faiss hnsw + bool is_mv = faiss::read_is_mv(&reader); + if (is_mv) { + LOG_KNOWHERE_INFO_ << "start to load index by mv"; + uint32_t v = readHeader(&reader); + indexes.resize(v); + LOG_KNOWHERE_INFO_ << "read " << v << " mvs"; + for (auto i = 0; i < v; ++i) { + auto read_index = std::unique_ptr(faiss::read_index(&reader)); + indexes[i].reset(read_index.release()); + } + } else { + reader.reset(); + auto read_index = std::unique_ptr(faiss::read_index(&reader)); + indexes[0].reset(read_index.release()); + } } catch (const std::exception& e) { if (is_faiss_fourcc_error(e.what())) { LOG_KNOWHERE_WARNING_ << "faiss does not recognize the input index: " << e.what(); @@ -217,8 +250,34 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { } try { - auto read_index = std::unique_ptr(faiss::read_index(filename.data(), io_flags)); - index.reset(read_index.release()); + // this is a hack for compatibility, faiss index has 4-byte header to indicate index category + // create a new one to distinguish MV faiss hnsw from faiss hnsw + bool is_mv = faiss::read_is_mv(filename.data()); + if (is_mv) { + auto read_index = [&](faiss::IOReader* r) { + LOG_KNOWHERE_INFO_ << "start to load index by mv"; + read_is_mv(r); + uint32_t v = readHeader(r); + LOG_KNOWHERE_INFO_ << "read " << v << " mvs"; + indexes.resize(v); + for (auto i = 0; i < v; ++i) { + auto read_index = std::unique_ptr(faiss::read_index(r, io_flags)); + indexes[i].reset(read_index.release()); + } + }; + if ((io_flags & faiss::IO_FLAG_MMAP_IFC) == faiss::IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(filename.data()); + faiss::MappedFileIOReader reader(owner); + read_index(&reader); + } else { + faiss::FileIOReader reader(filename.data()); + read_index(&reader); + } + } else { + auto read_index = std::unique_ptr(faiss::read_index(filename.data(), io_flags)); + indexes[0].reset(read_index.release()); + } } catch (const std::exception& e) { if (is_faiss_fourcc_error(e.what())) { LOG_KNOWHERE_WARNING_ << "faiss does not recognize the input index: " << e.what(); @@ -235,32 +294,38 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { // int64_t Dim() const override { - if (index == nullptr) { + if (isIndexEmpty()) { return -1; } - return index->d; + return indexes[0]->d; } int64_t Count() const override { - if (index == nullptr) { + if (isIndexEmpty()) { return -1; } + int64_t count = 0; + for (const auto& index : indexes) { + count += index->ntotal; + } // total number of indexed vectors - return index->ntotal; + return count; } int64_t Size() const override { - if (index == nullptr) { + if (isIndexEmpty()) { return 0; } // a temporary yet expensive workaround faiss::cppcontrib::knowhere::CountSizeIOWriter writer; - faiss::write_index(index.get(), &writer); + for (const auto& index : indexes) { + faiss::write_index(index.get(), &writer); + } // todo return writer.total_size; @@ -269,25 +334,80 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { protected: // it is std::shared_ptr, not std::unique_ptr, because it can be // shared with FaissHnswIterator - std::shared_ptr index; - - Status - AddInternal(const DataSetPtr dataset, const Config&) override { - if (this->index == nullptr) { - LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; - return Status::empty_index; + std::vector> indexes; + // each index's out ids(label), can be shared with FaissHnswIterator + std::vector>> labels; + + // index rows, help to locate index id by offset + std::vector index_rows_sum; + // label to locate internal offset + std::vector label_to_internal_offset; + + int + getIndexToSearchByScalarInfo(const FaissHnswConfig& config, const BitsetView& bitset) const { + if (indexes.size() == 1) { + return 0; } - - auto data = dataset->GetTensor(); - auto rows = dataset->GetRows(); - try { - this->index->add(rows, reinterpret_cast(data)); - } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); - return Status::faiss_inner_error; + if (bitset.empty()) { + LOG_KNOWHERE_WARNING_ << "partition key value not correctly set"; + return -1; + } + // all data is filtered, just pick the first one + // this will not happen combined with milvus, which will not call knowhere and just return + if (bitset.count() == bitset.size()) { + return 0; + } + size_t first_valid_index = bitset.get_first_valid_index(); + auto it = std::lower_bound(index_rows_sum.begin(), index_rows_sum.end(), + label_to_internal_offset[first_valid_index] + 1); + if (it == index_rows_sum.end()) { + LOG_KNOWHERE_WARNING_ << "can not find vector of offset " << label_to_internal_offset[first_valid_index]; + return -1; } + return std::distance(index_rows_sum.begin(), it) - 1; + } - return Status::success; + void + writeHeader(faiss::IOWriter* f) const { + uint32_t version = 0; + faiss::write_value(version, f); + uint32_t size = indexes.size(); + faiss::write_value(size, f); + uint32_t cluster_size = labels.size(); + faiss::write_value(cluster_size, f); + for (const auto& label : labels) { + faiss::write_vector(*label, f); + } + faiss::write_vector(index_rows_sum, f); + faiss::write_vector(label_to_internal_offset, f); + } + + uint32_t + readHeader(faiss::IOReader* f) { + [[maybe_unused]] uint32_t version = faiss::read_value(f); + uint32_t size = faiss::read_value(f); + uint32_t cluster_size = faiss::read_value(f); + labels.resize(cluster_size); + for (auto j = 0; j < cluster_size; ++j) { + labels[j] = std::make_shared>(); + faiss::read_vector(*labels[j], f); + } + faiss::read_vector(index_rows_sum, f); + faiss::read_vector(label_to_internal_offset, f); + return size; + } + + bool + isIndexEmpty() const { + if (indexes.empty()) { + return true; + } + for (const auto& index : indexes) { + if (index == nullptr) { + return true; + } + } + return false; } }; @@ -319,6 +439,48 @@ static constexpr DataFormatEnum datatype_v = DataType2EnumHelper::value; namespace { +bool +convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, + const DataFormatEnum src_data_format, const uint32_t* const __restrict offsets, const size_t nrows, + const size_t dim) { + if (src_data_format == DataFormatEnum::fp16) { + const knowhere::fp16* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else if (src_data_format == DataFormatEnum::bf16) { + const knowhere::bf16* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else if (src_data_format == DataFormatEnum::fp32) { + const knowhere::fp32* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else if (src_data_format == DataFormatEnum::int8) { + const knowhere::int8* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else { + // unknown + return false; + } +} + bool convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, const DataFormatEnum src_data_format, const size_t start_row, const size_t nrows, @@ -437,6 +599,39 @@ add_to_index(faiss::Index* const __restrict index, const DataSetPtr& dataset, co return Status::success; } +Status +add_partial_dataset_to_index(faiss::Index* const __restrict index, const DataSetPtr& dataset, + const DataFormatEnum data_format, const std::vector& ids) { + const auto* data = dataset->GetTensor(); + + if (ids.size() > dataset->GetRows()) { + LOG_KNOWHERE_ERROR_ << "partial ids size larger than whole dataset size"; + return Status::invalid_args; + } + const int64_t rows = ids.size(); + const auto dim = dataset->GetDim(); + + // convert data into float in pieces and add to the index + constexpr int64_t n_tmp_rows = 4096; + std::unique_ptr tmp = std::make_unique(n_tmp_rows * dim); + + for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { + const int64_t start_row = irow; + const int64_t end_row = std::min(rows, start_row + n_tmp_rows); + const int64_t count_rows = end_row - start_row; + + if (!convert_rows_to_fp32(data, tmp.get(), data_format, ids.data() + start_row, count_rows, dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + + // add + index->add(count_rows, tmp.get()); + } + + return Status::success; +} + // IndexFlat and IndexFlatCosine contain raw fp32 data // IndexScalarQuantizer and IndexScalarQuantizerCosine may contain rar bf16 and fp16 data // @@ -484,6 +679,46 @@ storage_distance_computer(const faiss::Index* storage) { } } +// there are chances that each partition split by scalar distribution is too small that we could not even train pq on it +// bcz 256 points are needed for a 8-bit pq training in faiss +// combine some small partitions to get a bigger one +// for example: scalar_info= {{1,2}, {3,4,5}, {1}}, base_rows = 3 +// we will get {{2, 0}, {1}}, which means we combine the scalar id 0 and 2 together +std::vector> +combine_partitions(const std::vector>& scalar_info, const int64_t base_rows) { + auto scalar_size = scalar_info.size(); + std::vector indices(scalar_size); + std::iota(indices.begin(), indices.end(), 0); + std::vector sizes; + sizes.reserve(scalar_size); + for (const auto& id_list : scalar_info) { + sizes.emplace_back(id_list.size()); + } + std::sort(indices.begin(), indices.end(), [&sizes](size_t i1, size_t i2) { return sizes[i1] < sizes[i2]; }); + std::vector> res; + std::vector cur; + int64_t cur_size = 0; + for (auto i : indices) { + cur_size += sizes[i]; + cur.push_back(i); + if (cur_size >= base_rows) { + res.push_back(cur); + cur.clear(); + cur_size = 0; + } + } + // tail + if (!cur.empty()) { + if (res.empty()) { + res.push_back(cur); + return res; + } else { + res[res.size() - 1].insert(res[res.size() - 1].end(), cur.begin(), cur.end()); + } + } + return res; +} + } // namespace // Contains an iterator state @@ -532,10 +767,11 @@ struct FaissHnswIteratorWorkspace { // Contains an iterator logic class FaissHnswIterator : public IndexIterator { public: - FaissHnswIterator(const std::shared_ptr& index_in, std::unique_ptr&& query_in, + FaissHnswIterator(const std::shared_ptr& index_in, + const std::shared_ptr>& labels_in, std::unique_ptr&& query_in, const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer, const float refine_ratio = 0.5f, bool use_knowhere_search_pool = true) - : IndexIterator(larger_is_closer, use_knowhere_search_pool, refine_ratio), index{index_in} { + : IndexIterator(larger_is_closer, use_knowhere_search_pool, refine_ratio), index{index_in}, labels{labels_in} { workspace.accumulated_alpha = (bitset_in.count() >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold)) ? std::numeric_limits::max() @@ -740,6 +976,12 @@ class FaissHnswIterator : public IndexIterator { } } + if (labels != nullptr) { + for (auto& p : workspace.dists) { + p.id = p.id < 0 ? p.id : labels->operator[](p.id); + } + } + // pass back to the handler batch_handler(workspace.dists); @@ -754,10 +996,15 @@ class FaissHnswIterator : public IndexIterator { filter_type sel; next_batch(batch_handler, sel); - } else { + } else if (labels == nullptr) { using filter_type = knowhere::BitsetViewIDSelector; filter_type sel(workspace.bitset); + next_batch(batch_handler, sel); + } else { + using filter_type = knowhere::BitsetViewWithMappingIDSelector; + filter_type sel(workspace.bitset, labels->data()); + next_batch(batch_handler, sel); } } @@ -770,6 +1017,7 @@ class FaissHnswIterator : public IndexIterator { private: std::shared_ptr index; + std::shared_ptr> labels; FaissHnswIteratorWorkspace workspace; }; @@ -783,31 +1031,41 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { bool HasRawData(const std::string& metric_type) const override { - if (this->index == nullptr) { + if (indexes.empty()) { return false; } // check whether there is an index to reconstruct a raw data from - return (GetIndexToReconstructRawDataFrom() != nullptr); + // only check one is enough + return (GetIndexToReconstructRawDataFrom(0) != nullptr); } expected GetVectorByIds(const DataSetPtr dataset) const override { - if (index == nullptr) { + if (indexes.empty()) { return expected::Err(Status::empty_index, "index not loaded"); } - if (!index->is_trained) { - return expected::Err(Status::index_not_trained, "index not trained"); + for (const auto& index : indexes) { + if (index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } } // an index that is used for reconstruction - const faiss::Index* index_to_reconstruct_from = GetIndexToReconstructRawDataFrom(); - - // check whether raw data is available - if (index_to_reconstruct_from == nullptr) { - return expected::Err( - Status::invalid_index_error, - "The index does not contain a raw data, cannot proceed with GetVectorByIds"); + std::vector indexes_to_reconstruct_from(indexes.size()); + for (auto i = 0; i < indexes.size(); ++i) { + const faiss::Index* index_to_reconstruct_from = GetIndexToReconstructRawDataFrom(i); + + // check whether raw data is available + if (index_to_reconstruct_from == nullptr) { + return expected::Err( + Status::invalid_index_error, + "The index does not contain a raw data, cannot proceed with GetVectorByIds"); + } + indexes_to_reconstruct_from[i] = index_to_reconstruct_from; } // perform reconstruction @@ -815,6 +1073,22 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto rows = dataset->GetRows(); auto ids = dataset->GetIds(); + auto get_vector = [&](int64_t id, float* result) -> bool { + if (indexes.size() == 1) { + indexes_to_reconstruct_from[0]->reconstruct(id, result); + } else { + auto it = + std::lower_bound(index_rows_sum.begin(), index_rows_sum.end(), label_to_internal_offset[id] + 1); + if (it == index_rows_sum.end()) { + return false; + } + auto index_id = std::distance(index_rows_sum.begin(), it) - 1; + indexes_to_reconstruct_from[index_id]->reconstruct( + label_to_internal_offset[id] - index_rows_sum[index_id], result); + } + return true; + }; + try { if (data_format == DataFormatEnum::fp32) { // perform a direct reconstruction for fp32 data @@ -822,7 +1096,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, data.get() + i * dim); + if (!get_vector(id, data.get() + i * dim)) { + return expected::Err(Status::invalid_index_error, + "index inner error, cannot proceed with GetVectorByIds"); + } } return GenResultDataSet(rows, dim, std::move(data)); } else if (data_format == DataFormatEnum::fp16) { @@ -832,8 +1109,11 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto tmp = std::make_unique(dim); for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; - assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, tmp.get()); + assert(id >= 0 && id < Count()); + if (!get_vector(id, tmp.get())) { + return expected::Err(Status::invalid_index_error, + "index inner error, cannot proceed with GetVectorByIds"); + } if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } @@ -846,8 +1126,11 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto tmp = std::make_unique(dim); for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; - assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, tmp.get()); + assert(id >= 0 && id < Count()); + if (!get_vector(id, tmp.get())) { + return expected::Err(Status::invalid_index_error, + "index inner error, cannot proceed with GetVectorByIds"); + } if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } @@ -860,8 +1143,11 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto tmp = std::make_unique(dim); for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; - assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, tmp.get()); + assert(id >= 0 && id < Count()); + if (!get_vector(id, tmp.get())) { + return expected::Err(Status::invalid_index_error, + "index inner error, cannot proceed with GetVectorByIds"); + } if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } @@ -878,11 +1164,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { expected Search(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { - if (this->index == nullptr) { + if (this->indexes.empty()) { return expected::Err(Status::empty_index, "index not loaded"); } - if (!this->index->is_trained) { - return expected::Err(Status::index_not_trained, "index not trained"); + for (const auto& index : indexes) { + if (index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } } const auto dim = dataset->GetDim(); @@ -891,7 +1182,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto hnsw_cfg = static_cast(*cfg); const auto k = hnsw_cfg.k.value(); - + auto index_id = getIndexToSearchByScalarInfo(hnsw_cfg, bitset); + if (index_id < 0) { + return expected::Err(Status::invalid_args, "partition key value not correctly set"); + } feder::hnsw::FederResultUniq feder_result; if (hnsw_cfg.trace_visit.value()) { if (rows != 1) { @@ -901,7 +1195,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { } // check for brute-force search - auto whether_bf_search = WhetherPerformBruteForceSearch(index.get(), hnsw_cfg, bitset); + auto whether_bf_search = WhetherPerformBruteForceSearch(indexes[index_id].get(), hnsw_cfg, bitset); if (!whether_bf_search.has_value()) { return expected::Err(Status::invalid_args, "k parameter is missing"); @@ -912,7 +1206,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // set up an index wrapper auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper( - index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); + indexes[index_id].get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); if (index_wrapper == nullptr) { return expected::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW"); @@ -934,9 +1228,9 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { hnsw_search_params.kAlpha = bitset.filter_ratio() * 0.7f; // set up a selector - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewWithMappingIDSelector bw_idselector(bitset, + labels.empty() ? nullptr : labels[index_id].get()->data()); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - hnsw_search_params.sel = id_selector; // run @@ -948,39 +1242,45 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { futs.reserve(rows); for (int64_t i = 0; i < rows; ++i) { - futs.emplace_back( - search_pool->push([&, idx = i, is_refined = is_refined, index_wrapper_ptr = index_wrapper_ptr] { - // 1 thread per element - ThreadPool::ScopedSearchOmpSetter setter(1); - - // set up a query - const float* cur_query = nullptr; - - std::vector cur_query_tmp(dim); - if (data_format == DataFormatEnum::fp32) { - cur_query = (const float*)data + idx * dim; - } else { - convert_rows_to_fp32(data, cur_query_tmp.data(), data_format, idx, 1, dim); - cur_query = cur_query_tmp.data(); - } + futs.emplace_back(search_pool->push([&, idx = i, is_refined = is_refined, + index_wrapper_ptr = index_wrapper_ptr] { + // 1 thread per element + ThreadPool::ScopedSearchOmpSetter setter(1); + + // set up a query + const float* cur_query = nullptr; + + std::vector cur_query_tmp(dim); + if (data_format == DataFormatEnum::fp32) { + cur_query = (const float*)data + idx * dim; + } else { + convert_rows_to_fp32(data, cur_query_tmp.data(), data_format, idx, 1, dim); + cur_query = cur_query_tmp.data(); + } - // set up local results - faiss::idx_t* const __restrict local_ids = ids.get() + k * idx; - float* const __restrict local_distances = distances.get() + k * idx; - - // perform the search - if (is_refined) { - faiss::IndexRefineSearchParameters refine_params; - refine_params.k_factor = hnsw_cfg.refine_k.value_or(1); - // a refine procedure itself does not need to care about filtering - refine_params.sel = nullptr; - refine_params.base_index_params = &hnsw_search_params; - - index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &refine_params); - } else { - index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &hnsw_search_params); + // set up local results + faiss::idx_t* const __restrict local_ids = ids.get() + k * idx; + float* const __restrict local_distances = distances.get() + k * idx; + + // perform the search + if (is_refined) { + faiss::IndexRefineSearchParameters refine_params; + refine_params.k_factor = hnsw_cfg.refine_k.value_or(1); + // a refine procedure itself does not need to care about filtering + refine_params.sel = nullptr; + refine_params.base_index_params = &hnsw_search_params; + + index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &refine_params); + } else { + index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &hnsw_search_params); + } + + if (!labels.empty()) { + for (auto j = 0; j < k; ++j) { + local_ids[j] = local_ids[j] < 0 ? local_ids[j] : labels[index_id]->operator[](local_ids[j]); } - })); + } + })); } // wait for the completion @@ -1006,11 +1306,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { expected RangeSearch(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { - if (this->index == nullptr) { + if (this->indexes.empty()) { return expected::Err(Status::empty_index, "index not loaded"); } - if (!this->index->is_trained) { - return expected::Err(Status::index_not_trained, "index not trained"); + for (const auto& index : indexes) { + if (index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } } const auto dim = dataset->GetDim(); @@ -1018,8 +1323,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto* data = dataset->GetTensor(); const auto hnsw_cfg = static_cast(*cfg); + auto index_id = getIndexToSearchByScalarInfo(hnsw_cfg, bitset); + if (index_id < 0) { + return expected::Err(Status::invalid_args, "partition key value not correctly set"); + } - const bool is_similarity_metric = faiss::is_similarity_metric(index->metric_type); + const bool is_similarity_metric = faiss::is_similarity_metric(indexes[index_id]->metric_type); const float radius = hnsw_cfg.radius.value(); const float range_filter = hnsw_cfg.range_filter.value(); @@ -1033,7 +1342,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { } // check for brute-force search - auto whether_bf_search = WhetherPerformBruteForceRangeSearch(index.get(), hnsw_cfg, bitset); + auto whether_bf_search = WhetherPerformBruteForceRangeSearch(indexes[index_id].get(), hnsw_cfg, bitset); if (!whether_bf_search.has_value()) { return expected::Err(Status::invalid_args, "ef parameter is missing"); @@ -1044,7 +1353,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // set up an index wrapper auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper( - index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); + indexes[index_id].get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); if (index_wrapper == nullptr) { return expected::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW"); @@ -1067,9 +1376,9 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { hnsw_search_params.kAlpha = bitset.filter_ratio() * 0.7f; // set up a selector - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewWithMappingIDSelector bw_idselector(bitset, + labels.empty() ? nullptr : labels[index_id].get()->data()); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - hnsw_search_params.sel = id_selector; //////////////////////////////////////////////////////////////// @@ -1122,9 +1431,17 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { result_dist_array[idx].resize(elem_cnt); result_id_array[idx].resize(elem_cnt); - for (size_t j = 0; j < elem_cnt; j++) { - result_dist_array[idx][j] = res.distances[j]; - result_id_array[idx][j] = res.labels[j]; + if (labels.empty()) { + for (size_t j = 0; j < elem_cnt; j++) { + result_dist_array[idx][j] = res.distances[j]; + result_id_array[idx][j] = res.labels[j]; + } + } else { + for (size_t j = 0; j < elem_cnt; j++) { + result_dist_array[idx][j] = res.distances[j]; + result_id_array[idx][j] = + res.labels[j] < 0 ? res.labels[j] : labels[index_id]->operator[](res.labels[j]); + } } if (hnsw_cfg.range_filter.value() != defaultRangeFilter) { @@ -1147,22 +1464,53 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { protected: DataFormatEnum data_format; + std::vector> tmp_combined_scalar_ids; + Status AddInternal(const DataSetPtr dataset, const Config&) override { - if (index == nullptr) { + if (isIndexEmpty()) { LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; return Status::empty_index; } auto rows = dataset->GetRows(); - try { - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - auto status = add_to_index(index.get(), dataset, data_format); - if (status != Status::success) { + const std::unordered_map>>& scalar_info_map = + dataset->Get>>>(meta::SCALAR_INFO); + if (scalar_info_map.empty()) { + try { + LOG_KNOWHERE_INFO_ << "Adding " << rows << " rows to HNSW Index"; + + auto status = add_to_index(indexes[0].get(), dataset, data_format); return status; + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; } + } + + if (scalar_info_map.size() > 1) { + LOG_KNOWHERE_WARNING_ << "vector index build with multiple scalar info is not supported"; + return Status::invalid_args; + } + LOG_KNOWHERE_INFO_ << "Add data to Index with Scalar Info"; + try { + for (const auto& [field_id, scalar_info] : scalar_info_map) { + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + for (auto j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto id = tmp_combined_scalar_ids[i][j]; + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to HNSW Index"; + + auto status = + add_partial_dataset_to_index(indexes[i].get(), dataset, data_format, scalar_info[id]); + if (status != Status::success) { + return status; + } + } + } + } } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -1172,8 +1520,11 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { } const faiss::Index* - GetIndexToReconstructRawDataFrom() const { - if (index == nullptr) { + GetIndexToReconstructRawDataFrom(int i) const { + if (indexes.size() <= i) { + return nullptr; + } + if (indexes[i] == nullptr) { return nullptr; } @@ -1181,12 +1532,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const faiss::Index* index_to_reconstruct_from = nullptr; // check whether our index uses refine - auto index_refine = dynamic_cast(index.get()); + auto index_refine = dynamic_cast(indexes[i].get()); if (index_refine == nullptr) { // non-refined index // cast as IndexHNSW - auto index_hnsw = dynamic_cast(index.get()); + auto index_hnsw = dynamic_cast(indexes[i].get()); if (index_hnsw == nullptr) { // this is unexpected, we expect IndexHNSW return nullptr; @@ -1215,13 +1566,53 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return index_to_reconstruct_from; } + Status + TrainIndexByScalarInfo(std::function train_index, + const std::vector>& scalar_info, const void* data, const int64_t rows, + const int64_t dim) { + label_to_internal_offset.resize(rows); + index_rows_sum.resize(tmp_combined_scalar_ids.size() + 1); + labels.resize(tmp_combined_scalar_ids.size()); + indexes.resize(tmp_combined_scalar_ids.size()); + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + size_t partition_size = 0; + for (int j : tmp_combined_scalar_ids[i]) { + partition_size += scalar_info[j].size(); + } + std::unique_ptr tmp_data = std::make_unique(dim * partition_size); + labels[i] = std::make_shared>(partition_size); + index_rows_sum[i + 1] = index_rows_sum[i] + partition_size; + size_t cur_size = 0; + + for (size_t j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto scalar_id = tmp_combined_scalar_ids[i][j]; + if (!convert_rows_to_fp32(data, tmp_data.get() + dim * cur_size, data_format, + scalar_info[scalar_id].data(), scalar_info[scalar_id].size(), dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + for (size_t m = 0; m < scalar_info[scalar_id].size(); ++m) { + labels[i]->operator[](cur_size + m) = scalar_info[scalar_id][m]; + label_to_internal_offset[scalar_info[scalar_id][m]] = index_rows_sum[i] + cur_size + m; + } + cur_size += scalar_info[scalar_id].size(); + } + + Status s = train_index((const float*)(tmp_data.get()), i, partition_size); + if (s != Status::success) { + return s; + } + } + return Status::success; + } + public: // expected> AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, bool use_knowhere_search_pool) const override { - if (index == nullptr) { - LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; + if (isIndexEmpty()) { + LOG_KNOWHERE_ERROR_ << "creating iterator on empty index"; return expected>::Err(Status::empty_index, "index not loaded"); } @@ -1239,6 +1630,11 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto vec = std::vector(n_queries, nullptr); const FaissHnswConfig& hnsw_cfg = static_cast(*cfg); + int index_id = getIndexToSearchByScalarInfo(hnsw_cfg, bitset); + if (index_id < 0) { + return expected>::Err(Status::invalid_args, + "partition key value not correctly set"); + } const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::COSINE); const bool larger_is_closer = (IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::IP) || is_cosine); @@ -1263,7 +1659,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { throw; } - const bool should_use_refine = (dynamic_cast(index.get()) != nullptr); + const bool should_use_refine = + (dynamic_cast(indexes[index_id].get()) != nullptr); const float iterator_refine_ratio = should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0; @@ -1271,8 +1668,9 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // create an iterator and initialize it // refine is not needed for flat // hnsw_cfg.iterator_refine_ratio.value_or(0.5f) - auto it = std::make_shared(index, std::move(cur_query), bitset, ef, larger_is_closer, - iterator_refine_ratio, use_knowhere_search_pool); + auto it = std::make_shared( + indexes[index_id], labels.empty() ? nullptr : labels[index_id], std::move(cur_query), bitset, ef, + larger_is_closer, iterator_refine_ratio, use_knowhere_search_pool); // store vec[i] = it; } @@ -1329,54 +1727,75 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); std::unique_ptr hnsw_index; - if (is_cosine) { - if (data_format == DataFormatEnum::fp32) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); - } else if (data_format == DataFormatEnum::fp16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, - hnsw_cfg.M.value()); - } else if (data_format == DataFormatEnum::bf16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, - hnsw_cfg.M.value()); - } else if (data_format == DataFormatEnum::int8) { - hnsw_index = std::make_unique( - dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value()); - } else { - LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); - return Status::invalid_metric_type; - } - } else { - if (data_format == DataFormatEnum::fp32) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); - } else if (data_format == DataFormatEnum::fp16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, - hnsw_cfg.M.value(), metric.value()); - } else if (data_format == DataFormatEnum::bf16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, - hnsw_cfg.M.value(), metric.value()); - } else if (data_format == DataFormatEnum::int8) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, - hnsw_cfg.M.value(), metric.value()); + auto train_index = [&](const float* data, const int i, const int64_t rows) { + if (is_cosine) { + if (data_format == DataFormatEnum::fp32) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::fp16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, + hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::bf16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, + hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::int8) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value()); + } else { + LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } } else { - LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); - return Status::invalid_metric_type; + if (data_format == DataFormatEnum::fp32) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::fp16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, + hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::bf16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, + hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::int8) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value(), metric.value()); + } else { + LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } } + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + // train + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + // this function does nothing for the given parameters and indices. + // as a result, I'm just keeping it to have is_trained set to true. + // WARNING: this may cause problems if ->train() performs some action + // based on the data in the future. Otherwise, data needs to be + // converted into float*. + hnsw_index->train(rows, data); + + // done + indexes[i] = std::move(hnsw_index); + return Status::success; + }; + + const std::unordered_map>>& scalar_info_map = + dataset->Get>>>(meta::SCALAR_INFO); + if (scalar_info_map.size() > 1) { + LOG_KNOWHERE_WARNING_ << "vector index build with multiple scalar info is not supported"; + return Status::invalid_args; + } + for (const auto& [field_id, scalar_info] : scalar_info_map) { + tmp_combined_scalar_ids = + scalar_info.size() > 1 ? combine_partitions(scalar_info, 128) : std::vector>(); } - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - - // train - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - - // this function does nothing for the given parameters and indices. - // as a result, I'm just keeping it to have is_trained set to true. - // WARNING: this may cause problems if ->train() performs some action - // based on the data in the future. Otherwise, data needs to be - // converted into float*. - hnsw_index->train(rows, (const float*)data); + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info_map.empty() || tmp_combined_scalar_ids.size() <= 1) { + return train_index((const float*)(data), 0, rows); + } - // done - index = std::move(hnsw_index); + LOG_KNOWHERE_INFO_ << "Train HNSW index with Scalar Info"; + for (const auto& [field_id, scalar_info] : scalar_info_map) { + return TrainIndexByScalarInfo(train_index, scalar_info, data, rows, dim); + } return Status::success; } }; @@ -1409,6 +1828,15 @@ class HNSWIndexNodeWithFallback : public IndexNode { } } + bool + IsAdditionalScalarSupported() const override { + if (use_base_index) { + return base_index->IsAdditionalScalarSupported(); + } else { + return fallback_search_index->IsAdditionalScalarSupported(); + } + } + Status Train(const DataSetPtr dataset, std::shared_ptr cfg) override { if (use_base_index) { @@ -1805,6 +2233,8 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { auto rows = dataset->GetRows(); // dimensionality of the data auto dim = dataset->GetDim(); + // data + const void* data = dataset->GetTensor(); // config auto hnsw_cfg = static_cast(cfg); @@ -1825,48 +2255,70 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { // create an index const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); - std::unique_ptr hnsw_index; - if (is_cosine) { - hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value()); - } else { - hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value(), metric.value()); - } - - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - // should refine be used? std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { - // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); - if (!final_index_cnd.has_value()) { - return Status::invalid_args; + + auto train_index = [&](const float* data, const int i, const int64_t rows) { + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value()); + } else { + hnsw_index = + std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value(), metric.value()); } - // assign - final_index = std::move(final_index_cnd.value()); - } else { - // no refine + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - // assign - final_index = std::move(hnsw_index); - } + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { + // yes + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } - // we have to convert the data to float, unfortunately, which costs extra RAM - auto float_ds_ptr = convert_ds_to_float(dataset, data_format); - if (float_ds_ptr == nullptr) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; - } + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine - // train - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + // assign + final_index = std::move(hnsw_index); + } - final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + // train + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // done - index = std::move(final_index); + final_index->train(rows, data); + // done + indexes[i] = std::move(final_index); + return Status::success; + }; + + const std::unordered_map>>& scalar_info_map = + dataset->Get>>>(meta::SCALAR_INFO); + if (scalar_info_map.size() > 1) { + LOG_KNOWHERE_WARNING_ << "vector index build with multiple scalar info is not supported"; + return Status::invalid_args; + } + for (const auto& [field_id, scalar_info] : scalar_info_map) { + tmp_combined_scalar_ids = + scalar_info.size() > 1 ? combine_partitions(scalar_info, 128) : std::vector>(); + } + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info_map.empty() || tmp_combined_scalar_ids.size() <= 1) { + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + return train_index(reinterpret_cast(float_ds_ptr->GetTensor()), 0, rows); + } + LOG_KNOWHERE_INFO_ << "Train HNSWSQ Index with Scalar Info"; + for (const auto& [field_id, scalar_info] : scalar_info_map) { + return TrainIndexByScalarInfo(train_index, scalar_info, data, rows, dim); + } return Status::success; } }; @@ -1914,7 +2366,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { } protected: - std::unique_ptr tmp_index_pq; + std::vector> tmp_index_pq; Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { @@ -1922,10 +2374,18 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { auto rows = dataset->GetRows(); // dimensionality of the data auto dim = dataset->GetDim(); + // data + const void* data = dataset->GetTensor(); // config auto hnsw_cfg = static_cast(cfg); + if (rows < (1 << hnsw_cfg.nbits.value())) { + LOG_KNOWHERE_ERROR_ << rows << " rows not enough, needs at least " << (1 << hnsw_cfg.nbits.value()) + << " rows"; + return Status::faiss_inner_error; + } + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); if (!metric.has_value()) { LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); @@ -1937,105 +2397,114 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // HNSW + PQ index yields BAD recall somewhy. // Let's build HNSW+FLAT index, then replace FLAT with PQ + auto train_index = [&](const float* data, const int i, const int64_t rows) { + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } - std::unique_ptr hnsw_index; - if (is_cosine) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); - } else { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); - } + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + + // pq + std::unique_ptr pq_index; + if (is_cosine) { + pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); + } else { + pq_index = + std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); + } - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + // should refine be used? + std::unique_ptr final_index; + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { + // yes + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } - // pq - std::unique_ptr pq_index; - if (is_cosine) { - pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); - } else { - pq_index = - std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); - } + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine - // should refine be used? - std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { - // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); - if (!final_index_cnd.has_value()) { - return Status::invalid_args; + // assign + final_index = std::move(hnsw_index); } - // assign - final_index = std::move(final_index_cnd.value()); - } else { - // no refine - - // assign - final_index = std::move(hnsw_index); - } + // train hnswflat + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // we have to convert the data to float, unfortunately, which costs extra RAM - auto float_ds_ptr = convert_ds_to_float(dataset, data_format); - if (float_ds_ptr == nullptr) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; - } + final_index->train(rows, data); - // train hnswflat - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + // train pq + LOG_KNOWHERE_INFO_ << "Training PQ Index"; - final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + pq_index->train(rows, data); + pq_index->pq.compute_sdc_table(); - // train pq - LOG_KNOWHERE_INFO_ << "Training PQ Index"; + // done + indexes[i] = std::move(final_index); + tmp_index_pq[i] = std::move(pq_index); + return Status::success; + }; - pq_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); - pq_index->pq.compute_sdc_table(); + const std::unordered_map>>& scalar_info_map = + dataset->Get>>>(meta::SCALAR_INFO); + if (scalar_info_map.size() > 1) { + LOG_KNOWHERE_WARNING_ << "vector index build with multiple scalar info is not supported"; + return Status::invalid_args; + } + for (const auto& [field_id, scalar_info] : scalar_info_map) { + tmp_combined_scalar_ids = scalar_info.size() > 1 + ? combine_partitions(scalar_info, (1 << hnsw_cfg.nbits.value())) + : std::vector>(); + } - // done - index = std::move(final_index); - tmp_index_pq = std::move(pq_index); + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info_map.empty() || tmp_combined_scalar_ids.size() <= 1) { + tmp_index_pq.resize(1); + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + return train_index((const float*)(float_ds_ptr->GetTensor()), 0, rows); + } + LOG_KNOWHERE_INFO_ << "Train HNSWPQ Index with Scalar Info"; + tmp_index_pq.resize(tmp_combined_scalar_ids.size()); + for (const auto& [field_id, scalar_info] : scalar_info_map) { + return TrainIndexByScalarInfo(train_index, scalar_info, data, rows, dim); + } return Status::success; } Status AddInternal(const DataSetPtr dataset, const Config&) override { - if (this->index == nullptr) { + if (isIndexEmpty()) { LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; return Status::empty_index; } auto rows = dataset->GetRows(); - try { - // hnsw - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - - auto status_reg = add_to_index(index.get(), dataset, data_format); - if (status_reg != Status::success) { - return status_reg; - } - - // pq - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to PQ Index"; - - auto status_pq = add_to_index(tmp_index_pq.get(), dataset, data_format); - if (status_pq != Status::success) { - return status_pq; - } + auto finalize_index = [&](int i) { // we're done. // throw away flat and replace it with pq // check if we have a refine available. faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); + faiss::IndexRefine* const index_refine = dynamic_cast(indexes[i].get()); if (index_refine != nullptr) { index_hnsw = dynamic_cast(index_refine->base_index); } else { - index_hnsw = dynamic_cast(index.get()); + index_hnsw = dynamic_cast(indexes[i].get()); } // recreate hnswpq @@ -2057,14 +2526,70 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { index_hnsw_pq->storage = nullptr; // replace storage - index_hnsw_pq->storage = tmp_index_pq.release(); + index_hnsw_pq->storage = tmp_index_pq[i].release(); // replace if refine if (index_refine != nullptr) { delete index_refine->base_index; index_refine->base_index = index_hnsw_pq.release(); } else { - index = std::move(index_hnsw_pq); + indexes[i] = std::move(index_hnsw_pq); + } + return Status::success; + }; + try { + const std::unordered_map>>& scalar_info_map = + dataset->Get>>>(meta::SCALAR_INFO); + + if (scalar_info_map.empty() || tmp_combined_scalar_ids.size() <= 1) { + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + auto status_reg = add_to_index(indexes[0].get(), dataset, data_format); + if (status_reg != Status::success) { + return status_reg; + } + + // pq + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to PQ Index"; + + auto status_pq = add_to_index(tmp_index_pq[0].get(), dataset, data_format); + if (status_pq != Status::success) { + return status_pq; + } + return finalize_index(0); + } + if (scalar_info_map.size() > 1) { + LOG_KNOWHERE_WARNING_ << "vector index build with multiple scalar info is not supported"; + return Status::invalid_args; + } + LOG_KNOWHERE_INFO_ << "Add data to Index with Scalar Info"; + + for (const auto& [field_id, scalar_info] : scalar_info_map) { + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + for (auto j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto id = tmp_combined_scalar_ids[i][j]; + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to HNSW Index"; + + auto status_reg = + add_partial_dataset_to_index(indexes[i].get(), dataset, data_format, scalar_info[id]); + if (status_reg != Status::success) { + return status_reg; + } + + // pq + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to PQ Index"; + + auto status_pq = + add_partial_dataset_to_index(tmp_index_pq[i].get(), dataset, data_format, scalar_info[id]); + + if (status_pq != Status::success) { + return status_pq; + } + } + finalize_index(i); + } } } catch (const std::exception& e) { @@ -2113,7 +2638,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { } protected: - std::unique_ptr tmp_index_prq; + std::vector> tmp_index_prq; Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { @@ -2121,10 +2646,18 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { auto rows = dataset->GetRows(); // dimensionality of the data auto dim = dataset->GetDim(); + // data + const void* data = dataset->GetTensor(); // config auto hnsw_cfg = static_cast(cfg); + if (rows < (1 << hnsw_cfg.nbits.value())) { + LOG_KNOWHERE_ERROR_ << rows << " rows not enough, needs at least " << (1 << hnsw_cfg.nbits.value()) + << " rows"; + return Status::faiss_inner_error; + } + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); if (!metric.has_value()) { LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); @@ -2136,110 +2669,121 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // HNSW + PRQ index yields BAD recall somewhy. // Let's build HNSW+FLAT index, then replace FLAT with PRQ + auto train_index = [&](const float* data, const int i, const int64_t rows) { + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } - std::unique_ptr hnsw_index; - if (is_cosine) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); - } else { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); - } + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + // prq + faiss::AdditiveQuantizer::Search_type_t prq_search_type = + (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) + ? faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm + : faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; + + std::unique_ptr prq_index; + if (is_cosine) { + prq_index = std::make_unique( + dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), prq_search_type); + } else { + prq_index = std::make_unique( + dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), metric.value(), + prq_search_type); + } - // prq - faiss::AdditiveQuantizer::Search_type_t prq_search_type = - (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) - ? faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm - : faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; + // should refine be used? + std::unique_ptr final_index; + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { + // yes + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } - std::unique_ptr prq_index; - if (is_cosine) { - prq_index = std::make_unique( - dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), prq_search_type); - } else { - prq_index = std::make_unique( - dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), metric.value(), prq_search_type); - } + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine - // should refine be used? - std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { - // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); - if (!final_index_cnd.has_value()) { - return Status::invalid_args; + // assign + final_index = std::move(hnsw_index); } - // assign - final_index = std::move(final_index_cnd.value()); - } else { - // no refine + // train hnswflat + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // assign - final_index = std::move(hnsw_index); - } + final_index->train(rows, data); - // we have to convert the data to float, unfortunately, which costs extra RAM - auto float_ds_ptr = convert_ds_to_float(dataset, data_format); - if (float_ds_ptr == nullptr) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; - } + // train prq + LOG_KNOWHERE_INFO_ << "Training ProductResidualQuantizer Index"; - // train hnswflat - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + prq_index->train(rows, data); - final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + // done + indexes[i] = std::move(final_index); + tmp_index_prq[i] = std::move(prq_index); - // train prq - LOG_KNOWHERE_INFO_ << "Training ProductResidualQuantizer Index"; + return Status::success; + }; + const std::unordered_map>>& scalar_info_map = + dataset->Get>>>(meta::SCALAR_INFO); + if (scalar_info_map.size() > 1) { + LOG_KNOWHERE_WARNING_ << "vector index build with multiple scalar info is not supported"; + return Status::invalid_args; + } + for (const auto& [field_id, scalar_info] : scalar_info_map) { + tmp_combined_scalar_ids = scalar_info.size() > 1 + ? combine_partitions(scalar_info, (1 << hnsw_cfg.nbits.value())) + : std::vector>(); + } - prq_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info_map.empty() || tmp_combined_scalar_ids.size() <= 1) { + tmp_index_prq.resize(1); + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); - // done - index = std::move(final_index); - tmp_index_prq = std::move(prq_index); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + return train_index((const float*)(float_ds_ptr->GetTensor()), 0, rows); + } + LOG_KNOWHERE_INFO_ << "Train HNSWPRQ Index with Scalar Info"; + tmp_index_prq.resize(tmp_combined_scalar_ids.size()); + for (const auto& [field_id, scalar_info] : scalar_info_map) { + return TrainIndexByScalarInfo(train_index, scalar_info, data, rows, dim); + } return Status::success; } Status AddInternal(const DataSetPtr dataset, const Config&) override { - if (this->index == nullptr) { + if (isIndexEmpty()) { LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; return Status::empty_index; } auto rows = dataset->GetRows(); - try { - // hnsw - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - - auto status_reg = add_to_index(index.get(), dataset, data_format); - if (status_reg != Status::success) { - return status_reg; - } - - // prq - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to ProductResidualQuantizer Index"; - - auto status_prq = add_to_index(tmp_index_prq.get(), dataset, data_format); - if (status_prq != Status::success) { - return status_prq; - } + auto finalize_index = [&](int i) { // we're done. // throw away flat and replace it with prq // check if we have a refine available. faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); + faiss::IndexRefine* const index_refine = dynamic_cast(indexes[i].get()); if (index_refine != nullptr) { index_hnsw = dynamic_cast(index_refine->base_index); } else { - index_hnsw = dynamic_cast(index.get()); + index_hnsw = dynamic_cast(indexes[i].get()); } // recreate hnswprq @@ -2261,14 +2805,71 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { index_hnsw_prq->storage = nullptr; // replace storage - index_hnsw_prq->storage = tmp_index_prq.release(); + index_hnsw_prq->storage = tmp_index_prq[i].release(); // replace if refine if (index_refine != nullptr) { delete index_refine->base_index; index_refine->base_index = index_hnsw_prq.release(); } else { - index = std::move(index_hnsw_prq); + indexes[i] = std::move(index_hnsw_prq); + } + return Status::success; + }; + try { + const std::unordered_map>>& scalar_info_map = + dataset->Get>>>(meta::SCALAR_INFO); + + if (scalar_info_map.empty() || tmp_combined_scalar_ids.size() <= 1) { + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + auto status_reg = add_to_index(indexes[0].get(), dataset, data_format); + if (status_reg != Status::success) { + return status_reg; + } + + // prq + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to ProductResidualQuantizer Index"; + + auto status_prq = add_to_index(tmp_index_prq[0].get(), dataset, data_format); + if (status_prq != Status::success) { + return status_prq; + } + return finalize_index(0); + } + + if (scalar_info_map.size() > 1) { + LOG_KNOWHERE_WARNING_ << "vector index build with multiple scalar info is not supported"; + return Status::invalid_args; + } + LOG_KNOWHERE_INFO_ << "Add data to Index with Scalar Info"; + + for (const auto& [field_id, scalar_info] : scalar_info_map) { + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + for (auto j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto id = tmp_combined_scalar_ids[i][j]; + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to HNSW Index"; + + auto status_reg = + add_partial_dataset_to_index(indexes[i].get(), dataset, data_format, scalar_info[id]); + if (status_reg != Status::success) { + return status_reg; + } + + // prq + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to PQ Index"; + + auto status_prq = + add_partial_dataset_to_index(tmp_index_prq[i].get(), dataset, data_format, scalar_info[id]); + + if (status_prq != Status::success) { + return status_prq; + } + } + finalize_index(i); + } } } catch (const std::exception& e) { @@ -2294,7 +2895,6 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS } }; -// MV is only for compatibility #ifdef KNOWHERE_WITH_CARDINAL KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_DEPRECATED, BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback, @@ -2307,13 +2907,16 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNo #endif KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, - knowhere::feature::MMAP) -KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, knowhere::feature::MMAP) + knowhere::feature::MMAP | knowhere::feature::MV) +KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, + knowhere::feature::MMAP | knowhere::feature::MV) KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, - knowhere::feature::MMAP) -KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, knowhere::feature::MMAP) + knowhere::feature::MMAP | knowhere::feature::MV) +KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, + knowhere::feature::MMAP | knowhere::feature::MV) KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, - knowhere::feature::MMAP) -KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, knowhere::feature::MMAP) + knowhere::feature::MMAP | knowhere::feature::MV) +KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, + knowhere::feature::MMAP | knowhere::feature::MV) } // namespace knowhere diff --git a/src/index/hnsw/impl/IndexBruteForceWrapper.cc b/src/index/hnsw/impl/IndexBruteForceWrapper.cc index 748bc8263..0b12fe525 100644 --- a/src/index/hnsw/impl/IndexBruteForceWrapper.cc +++ b/src/index/hnsw/impl/IndexBruteForceWrapper.cc @@ -44,6 +44,21 @@ struct BitsetViewIDSelectorWrapper final { } }; +struct BitsetViewWithMappingIDSelectorWrapper final { + const BitsetView bitset_view; + const uint32_t* out_id_mapping; + + inline BitsetViewWithMappingIDSelectorWrapper(BitsetView bitset_view, const uint32_t* out_id_mapping) + : bitset_view{bitset_view}, out_id_mapping{out_id_mapping} { + } + + [[nodiscard]] inline bool + is_member(faiss::idx_t id) const { + // it is by design that bitset_view.empty() and out_id_mapping == nullptr is not tested here + return (!bitset_view.test(out_id_mapping[id])); + } +}; + // IndexBruteForceWrapper::IndexBruteForceWrapper(faiss::Index* underlying_index) : faiss::cppcontrib::knowhere::IndexWrapper{underlying_index} { @@ -70,8 +85,8 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss: faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel; // try knowhere-specific filter - const knowhere::BitsetViewIDSelector* __restrict bw_idselector = - dynamic_cast(sel); + const knowhere::BitsetViewWithMappingIDSelector* __restrict bw_idselector = + dynamic_cast(sel); if (is_similarity_metric(index->metric_type)) { using C = faiss::CMin; @@ -80,12 +95,19 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss: faiss::IDSelectorAll sel_all; faiss::cppcontrib::knowhere::brute_force_search_impl( index->ntotal, *dis, sel_all, k, local_distances, local_ids); - } else { + } else if (bw_idselector->out_id_mapping == nullptr) { BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view); faiss::cppcontrib::knowhere::brute_force_search_impl( index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); + } else { + BitsetViewWithMappingIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view, + bw_idselector->out_id_mapping); + + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); } } else { using C = faiss::CMax; @@ -94,12 +116,19 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss: faiss::IDSelectorAll sel_all; faiss::cppcontrib::knowhere::brute_force_search_impl( index->ntotal, *dis, sel_all, k, local_distances, local_ids); - } else { + } else if (bw_idselector->out_id_mapping == nullptr) { BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view); faiss::cppcontrib::knowhere::brute_force_search_impl( index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); + } else { + BitsetViewWithMappingIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view, + bw_idselector->out_id_mapping); + + faiss::cppcontrib::knowhere::brute_force_search_impl( + index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids); } } } diff --git a/src/index/hnsw/impl/IndexConditionalWrapper.cc b/src/index/hnsw/impl/IndexConditionalWrapper.cc index e3ab59e82..69410cbdf 100644 --- a/src/index/hnsw/impl/IndexConditionalWrapper.cc +++ b/src/index/hnsw/impl/IndexConditionalWrapper.cc @@ -51,8 +51,8 @@ WhetherPerformBruteForceSearch(const faiss::Index* index, const BaseConfig& cfg, double ratio = ((double)filtered_out_num) / bitset.size(); knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio); #endif - if (filtered_out_num >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold) || - k >= (index->ntotal - filtered_out_num) * HnswSearchThresholds::kHnswSearchBFTopkThreshold) { + if (filtered_out_num >= (bitset.size() * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold) || + k >= (bitset.size() - filtered_out_num) * HnswSearchThresholds::kHnswSearchBFTopkThreshold) { return true; } } @@ -84,8 +84,8 @@ WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswCo double ratio = ((double)filtered_out_num) / bitset.size(); knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio); #endif - if (filtered_out_num >= (index->ntotal * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) || - ef >= (index->ntotal - filtered_out_num) * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) { + if (filtered_out_num >= (bitset.size() * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) || + ef >= (bitset.size() - filtered_out_num) * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) { return true; } } diff --git a/src/index/hnsw/impl/IndexHNSWWrapper.cc b/src/index/hnsw/impl/IndexHNSWWrapper.cc index 6cd55e569..ceca0e37a 100644 --- a/src/index/hnsw/impl/IndexHNSWWrapper.cc +++ b/src/index/hnsw/impl/IndexHNSWWrapper.cc @@ -135,8 +135,8 @@ IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __r faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel; // try knowhere-specific filter - const knowhere::BitsetViewIDSelector* __restrict bw_idselector = - dynamic_cast(sel); + const knowhere::BitsetViewWithMappingIDSelector* __restrict bw_idselector = + dynamic_cast(sel); if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { // no filter @@ -166,9 +166,10 @@ IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __r local_stats = searcher.search(k, distances + i * k, labels + i * k); } - } else { - // with filter + } else if (bw_idselector->out_id_mapping == nullptr) { + // with filter, no mapping + knowhere::BitsetViewIDSelector bw_idselector_no_mapping(bw_idselector->bitset_view); // feder templating is important, bcz it removes an unneeded 'CALL' instruction. if (feder == nullptr) { // no feder @@ -179,6 +180,36 @@ IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __r faiss::cppcontrib::knowhere::Bitset, knowhere::BitsetViewIDSelector>; + searcher_type searcher{ + hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, bw_idselector_no_mapping, kAlpha, params}; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } else { + // use feder + FederVisitor graph_visitor(feder); + + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; + + searcher_type searcher{ + hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, bw_idselector_no_mapping, kAlpha, params}; + + local_stats = searcher.search(k, distances + i * k, labels + i * k); + } + } else { + // with filter + // feder templating is important, bcz it removes an unneeded 'CALL' instruction. + if (feder == nullptr) { + // no feder + DummyVisitor graph_visitor; + + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; + searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, *bw_idselector, kAlpha, params}; @@ -190,7 +221,7 @@ IndexHNSWWrapper::search(idx_t n, const float* __restrict x, idx_t k, float* __r using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher; + knowhere::BitsetViewWithMappingIDSelector>; searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, *bw_idselector, kAlpha, params}; @@ -301,8 +332,8 @@ IndexHNSWWrapper::range_search(idx_t n, const float* __restrict x, float radius_ faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel; // try knowhere-specific filter - const knowhere::BitsetViewIDSelector* __restrict bw_idselector = - dynamic_cast(sel); + const knowhere::BitsetViewWithMappingIDSelector* __restrict bw_idselector = + dynamic_cast(sel); if (bw_idselector == nullptr || bw_idselector->bitset_view.empty()) { // no filter @@ -330,6 +361,38 @@ IndexHNSWWrapper::range_search(idx_t n, const float* __restrict x, float radius_ searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, sel_all, kAlpha, params}; + local_stats = searcher.range_search(radius, &res_min); + } + } else if (bw_idselector->out_id_mapping == nullptr) { + // with filter, no mapping + + knowhere::BitsetViewIDSelector bw_selector_no_mapping(bw_idselector->bitset_view); + // feder templating is important, bcz it removes an unneeded 'CALL' instruction. + if (feder == nullptr) { + // no feder + DummyVisitor graph_visitor; + + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; + + searcher_type searcher{ + hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, bw_selector_no_mapping, kAlpha, params}; + + local_stats = searcher.range_search(radius, &res_min); + } else { + // use feder + FederVisitor graph_visitor(feder); + + using searcher_type = + faiss::cppcontrib::knowhere::v2_hnsw_searcher; + + searcher_type searcher{ + hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, bw_selector_no_mapping, kAlpha, params}; + local_stats = searcher.range_search(radius, &res_min); } } else { @@ -343,7 +406,7 @@ IndexHNSWWrapper::range_search(idx_t n, const float* __restrict x, float radius_ using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher; + knowhere::BitsetViewWithMappingIDSelector>; searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, *bw_idselector, kAlpha, params}; @@ -356,7 +419,7 @@ IndexHNSWWrapper::range_search(idx_t n, const float* __restrict x, float radius_ using searcher_type = faiss::cppcontrib::knowhere::v2_hnsw_searcher; + knowhere::BitsetViewWithMappingIDSelector>; searcher_type searcher{hnsw, *(dis.get()), graph_visitor, bitset_visited_nodes, *bw_idselector, kAlpha, params}; diff --git a/tests/ut/test_faiss_hnsw.cc b/tests/ut/test_faiss_hnsw.cc index 33c2d9474..d207e4c25 100644 --- a/tests/ut/test_faiss_hnsw.cc +++ b/tests/ut/test_faiss_hnsw.cc @@ -190,7 +190,7 @@ read_index(knowhere::Index& index, const std::string& filen template knowhere::Index create_index(const std::string& index_type, const std::string& index_file_name, - const knowhere::DataSetPtr& default_ds_ptr, const knowhere::Json& conf, + const knowhere::DataSetPtr& default_ds_ptr, const knowhere::Json& conf, const bool mv_only_enable, const std::optional& additional_name = std::nullopt) { std::string additional_name_s = additional_name.value_or(""); @@ -208,6 +208,12 @@ create_index(const std::string& index_type, const std::string& index_file_name, auto base = knowhere::ConvertToDataTypeIfNeeded(default_ds_ptr); + if (mv_only_enable) { + base->Set(knowhere::meta::SCALAR_INFO, + default_ds_ptr->Get>>>( + knowhere::meta::SCALAR_INFO)); + } + StopWatch sw; index.value().Build(base, conf); double elapsed = sw.elapsed(); @@ -252,10 +258,10 @@ index_support_int8(const knowhere::Json& conf) { // template -void +std::string test_hnsw(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr& query_ds_ptr, const knowhere::DataSetPtr& golden_result, const std::vector& index_params, - const knowhere::Json& conf, const knowhere::BitsetView bitset_view) { + const knowhere::Json& conf, const bool mv_only_enable, const knowhere::BitsetView bitset_view) { const std::string index_type = conf[knowhere::meta::INDEX_TYPE].get(); // load indices @@ -266,10 +272,10 @@ test_hnsw(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr // our index // first, we create an index and save it - auto index = create_index(index_type, index_file_name, default_ds_ptr, conf); + auto index = create_index(index_type, index_file_name, default_ds_ptr, conf, mv_only_enable); // then, we force it to be loaded in order to test load & save - auto index_loaded = create_index(index_type, index_file_name, default_ds_ptr, conf); + auto index_loaded = create_index(index_type, index_file_name, default_ds_ptr, conf, mv_only_enable); // query auto query_t_ds_ptr = knowhere::ConvertToDataTypeIfNeeded(query_ds_ptr); @@ -313,14 +319,15 @@ test_hnsw(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr match_datasets(default_t_ds_ptr, vectors.value(), ids); } + return index_file_name; } // template -void +std::string test_hnsw_range(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr& query_ds_ptr, const knowhere::DataSetPtr& golden_result, const std::vector& index_params, - const knowhere::Json& conf, const knowhere::BitsetView bitset_view) { + const knowhere::Json& conf, const bool mv_only_enable, const knowhere::BitsetView bitset_view) { const std::string index_type = conf[knowhere::meta::INDEX_TYPE].get(); // load indices @@ -331,10 +338,10 @@ test_hnsw_range(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::Data // our index // first, we create an index and save it - auto index = create_index(index_type, index_file_name, default_ds_ptr, conf); + auto index = create_index(index_type, index_file_name, default_ds_ptr, conf, mv_only_enable); // then, we force it to be loaded in order to test load & save - auto index_loaded = create_index(index_type, index_file_name, default_ds_ptr, conf); + auto index_loaded = create_index(index_type, index_file_name, default_ds_ptr, conf, mv_only_enable); // query auto query_t_ds_ptr = knowhere::ConvertToDataTypeIfNeeded(query_ds_ptr); @@ -381,6 +388,7 @@ test_hnsw_range(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::Data match_datasets(default_t_ds_ptr, vectors.value(), ids); } + return index_file_name; } } // namespace @@ -403,6 +411,8 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { const int32_t NQ = 16; const int32_t TOPK = 16; + const std::vector MV_ONLYs = {false, true}; + const std::vector SQ_TYPES = {"SQ6", "SQ8", "BF16", "FP16"}; // random bitset rates @@ -465,7 +475,7 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { std::string golden_index_file_name = get_index_name(ann_test_name_, golden_index_type, golden_params); - create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf, + create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf, false, "golden "); } } @@ -504,55 +514,97 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf, "golden "); - - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + default_ds_ptr, conf, false, "golden "); + + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value + + for (const bool mv_only_enable : MV_ONLYs) { +#ifdef KNOWHERE_WITH_CARDINAL + if (mv_only_enable) { + continue; + } +#endif + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf, bitset_view); + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // fp32 candidate - printf( - "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf, bitset_view); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + // fp32 candidate + printf( + "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, %d%% points filtered " + "out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // fp16 candidate - printf( - "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + // fp16 candidate + printf( + "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " + "out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // bf16 candidate - printf( - "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); - if (index_support_int8(conf)) { - // int8 candidate + // bf16 candidate printf( - "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, %d%% points filtered " + "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " "out\n", DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + if (index_support_int8(conf)) { + // int8 candidate + printf( + "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + if (index_support_int8(conf)) { + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + } + } + for (auto index : index_files) { + std::remove(index.c_str()); } } } @@ -588,155 +640,193 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf_golden, "golden "); - - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); - } + default_ds_ptr, conf_golden, false, "golden "); - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value - // go SQ - for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); + } + std::vector index_files; + std::string index_file; - const std::string sq_type = SQ_TYPES[i_sq_type]; - conf[knowhere::indexparam::SQ_TYPE] = sq_type; + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // fp32 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, %d%% points filtered " - "out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + // go SQ + for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " - "out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + const std::string sq_type = SQ_TYPES[i_sq_type]; + conf[knowhere::indexparam::SQ_TYPE] = sq_type; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " - "out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); - if (index_support_int8(conf)) { - // int8 candidate + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + + // fp16 candidate printf( - "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, %d%% points " "filtered " "out\n", sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - } - // test refines for FP32 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // bf16 candidate + printf( + "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - std::vector params_refine = {(int)distance_type, dim, nb, (int)i_sq_type, - (int)allowed_ref_idx}; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - // fp32 candidate + if (index_support_int8(conf)) { + // int8 candidate printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); } - } + // test refines for FP32 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - // test refines for FP16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; - std::vector params_refine = {(int)distance_type, dim, nb, (int)i_sq_type, - (int)allowed_ref_idx}; + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), + dim, nb, int(bitset_rate * 100)); - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + } - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + // test refines for FP16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // fp16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), + dim, nb, int(bitset_rate * 100)); + + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } } - } - // test refines for BF16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test refines for BF16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - std::vector params_refine = {(int)distance_type, dim, nb, (int)i_sq_type, - (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + // bf16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), + dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -771,146 +861,183 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf_golden, "golden "); - - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); - } + default_ds_ptr, conf_golden, false, "golden "); - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value - // go PQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int pq_m = 8; - - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = pq_m; + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); + } + std::vector index_files; + std::string index_file; - std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - // test fp32 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // go PQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int pq_m = 8; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = pq_m; - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); - if (index_support_int8(conf)) { - // test int8 candidate + // test fp32 candidate printf( - "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " "filtered out\n", pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - } - // test refines for fp32 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test fp16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - // test fp32 candidate + // test bf16 candidate printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d, " - "%d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - // test refines for fp16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + // test refines for fp32 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, nrows=%d, " - "%d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, " + "%d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - // test refines for bf16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test refines for fp16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, nrows=%d, " - "%d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + // test fp16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, " + "%d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + + // test refines for bf16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + + // test bf16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, " + "%d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -945,152 +1072,190 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf_golden, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + default_ds_ptr, conf_golden, false, "golden "); + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value + + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - // go PRQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int prq_m = 4; - const int prq_num = 2; + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = prq_m; - conf[knowhere::indexparam::PRQ_NUM] = prq_num; + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, - (int)nbits_type}; + // go PRQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int prq_m = 4; + const int prq_num = 2; - // test fp32 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = prq_m; + conf[knowhere::indexparam::PRQ_NUM] = prq_num; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, + (int)nbits_type}; - // test fp16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - // test bf16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // test fp16 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // test int8 candidate + // test bf16 candidate printf( - "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " "filtered out\n", prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - } - // test fp32 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% " + "points " + "filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, + nb, int(bitset_rate * 100)); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + // test fp32 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // test a candidate - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + // + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // test fp16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test a candidate + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test fp16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // test a candidate - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + // + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + + // test a candidate + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - // test bf16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test bf16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + // + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // test a candidate - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + // test a candidate + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -1123,6 +1288,8 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra const std::vector NBS = {256}; const int32_t NQ = 16; + const std::vector MV_ONLYs = {false, true}; + const std::vector SQ_TYPES = {"SQ6", "SQ8", "BF16", "FP16"}; // random bitset rates @@ -1184,7 +1351,7 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra std::string golden_index_file_name = get_index_name(ann_test_name_, golden_index_type, golden_params); - create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf, + create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf, false, "golden "); } } @@ -1192,11 +1359,10 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra } // I'd like to have a sequential process here, because every item in the loop - // is parallelized on its own + // is parallelized on its own SECTION("FLAT") { const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW; - // const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW; const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { @@ -1237,62 +1403,95 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf, "golden "); + default_ds_ptr, conf, false, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { +#ifdef KNOWHERE_WITH_CARDINAL + if (mv_only_enable) { + continue; } +#endif + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); + } + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf, bitset_view); - - // fp32 candidate - printf( - "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf, bitset_view); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); + // fp32 candidate + printf( + "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); - // fp16 candidate - printf( - "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + index_file = + test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); + // fp16 candidate + printf( + "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); - // bf16 candidate - printf( - "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + index_file = + test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - if (index_support_int8(conf)) { - // int8 candidate + // bf16 candidate printf( - "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " "range_filter=%f, %d%% points filtered out\n", DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = + test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + if (index_support_int8(conf)) { + // int8 candidate + printf( + "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + } + for (auto index : index_files) { + std::remove(index.c_str()); } } } @@ -1342,165 +1541,204 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf_golden, "golden "); + default_ds_ptr, conf_golden, false, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - - // go SQ - for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - - const std::string sq_type = SQ_TYPES[i_sq_type]; - conf[knowhere::indexparam::SQ_TYPE] = sq_type; + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; + // go SQ + for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; - // fp32 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + const std::string sq_type = SQ_TYPES[i_sq_type]; + conf[knowhere::indexparam::SQ_TYPE] = sq_type; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + // fp16 candidate + printf( + "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // int8 candidate + // bf16 candidate printf( - "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " "range_filter=%f, %d%% points filtered out\n", sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); - } - - // test refines for FP32 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + if (index_support_int8(conf)) { + // int8 candidate + printf( + "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); + + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, + conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - std::vector params_refine = {(int)distance_type, dim, nb, - (int)i_sq_type, (int)allowed_ref_idx}; + // test refines for FP32 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + } - // fp32 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, radius, range_filter, int(bitset_rate * 100)); + // test refines for FP16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // fp16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } + } - test_hnsw_range(default_ds_ptr, query_ds_ptr, - golden_result.value(), params_refine, - conf_refine, bitset_view); + // test refines for BF16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // bf16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } } } + } - // test refines for FP16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; - - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; - - std::vector params_refine = {(int)distance_type, dim, nb, - (int)i_sq_type, (int)allowed_ref_idx}; - - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, radius, range_filter, int(bitset_rate * 100)); - - test_hnsw_range(default_ds_ptr, query_ds_ptr, - golden_result.value(), params_refine, - conf_refine, bitset_view); - } - } - - // test refines for BF16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; - - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; - - std::vector params_refine = {(int)distance_type, dim, nb, - (int)i_sq_type, (int)allowed_ref_idx}; - - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, radius, range_filter, int(bitset_rate * 100)); - - test_hnsw_range(default_ds_ptr, query_ds_ptr, - golden_result.value(), params_refine, - conf_refine, bitset_view); - } - } - } - } - } - } - } - } - } + for (auto index : index_files) { + std::remove(index.c_str()); + } + } + } + } + } + } + } SECTION("PQ") { const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_PQ; @@ -1543,151 +1781,185 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf_golden, "golden "); + default_ds_ptr, conf_golden, false, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - - // go PQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int pq_m = 8; - - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = pq_m; + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; + // go PQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int pq_m = 8; - // test fp32 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = pq_m; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); - - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); - - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // test int8 candidate + // test fp16 candidate printf( - "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " "range_filter=%f, %d%% points filtered out\n", pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); - } + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - // test refines for fp32 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test bf16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + radius, range_filter, int(bitset_rate * 100)); + + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, + conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - // test fp32 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + // test refines for fp32 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test refines for fp16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test fp32 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + // test refines for fp16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - // test refines for bf16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test fp16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test refines for bf16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + // test bf16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -1736,158 +2008,199 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra get_index_name(ann_test_name_, golden_index_type, golden_params); auto golden_index = create_index(golden_index_type, golden_index_file_name, - default_ds_ptr, conf_golden, "golden "); + default_ds_ptr, conf_golden, false, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::unordered_map>> scalar_info = + GenerateScalarInfo(nb); + auto partition_size = scalar_info[0][0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - - // go PRQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int prq_m = 4; - const int prq_num = 2; + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0][0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = prq_m; - conf[knowhere::indexparam::PRQ_NUM] = prq_num; + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, - (int)nbits_type}; + // go PRQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int prq_m = 4; + const int prq_num = 2; - // test fp32 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - radius, range_filter, int(bitset_rate * 100)); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = prq_m; + conf[knowhere::indexparam::PRQ_NUM] = prq_num; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, + (int)nbits_type}; - // test fp16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - radius, range_filter, int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, + nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - // test bf16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - radius, range_filter, int(bitset_rate * 100)); + // test fp16 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, + nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // test int8 candidate + // test bf16 candidate printf( - "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, " + "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, " "radius=%f, " "range_filter=%f, %d%% points filtered out\n", prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); - } + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, conf, + mv_only_enable, bitset_view); + index_files.emplace_back(index_file); - // test fp32 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), + dim, nb, radius, range_filter, int(bitset_rate * 100)); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, + conf, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, - (int)allowed_ref_idx}; + // test fp32 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test a candidate - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // test fp16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + // test a candidate + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test fp16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, - (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // test a candidate - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); - // test bf16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test a candidate + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_files.emplace_back(index_file); + } - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, - (int)allowed_ref_idx}; + // test bf16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test a candidate - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; + + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + // test a candidate + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, mv_only_enable, bitset_view); + index_files.emplace_back(index_file); + } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -1946,8 +2259,8 @@ TEST_CASE("hnswlib to FAISS HNSW for HNSW_FLAT", "Check search fallback") { std::string hnswlib_index_file_name = get_index_name(ann_test_name_, hnswlib_index_type, hnswlib_params); - auto hnswlib_index = - create_index(hnswlib_index_type, hnswlib_index_file_name, default_ds_ptr, conf, "hnswlib "); + auto hnswlib_index = create_index(hnswlib_index_type, hnswlib_index_file_name, default_ds_ptr, + conf, false, "hnswlib"); // perform an hnswlib search auto hnswlib_result = hnswlib_index.Search(query_ds_ptr, conf, nullptr); @@ -1959,7 +2272,7 @@ TEST_CASE("hnswlib to FAISS HNSW for HNSW_FLAT", "Check search fallback") { // load index back, but as a faiss index auto faiss_index = create_index(faiss_index_type, hnswlib_index_file_name, default_ds_ptr, - faiss_conf, "hnswlib "); + faiss_conf, false, "hnswlib "); // perform a faiss search auto faiss_result = faiss_index.Search(query_ds_ptr, faiss_conf, nullptr); diff --git a/tests/ut/test_index_check.cc b/tests/ut/test_index_check.cc index 276051b88..5a7971676 100644 --- a/tests/ut/test_index_check.cc +++ b/tests/ut/test_index_check.cc @@ -472,8 +472,13 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") { } SECTION("Check MV") { - // Only HNSW supports Materialized View + // Only HNSW family supports Materialized View REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::MV)); // All other indexes do not support MV REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_IDMAP, knowhere::feature::MV)); @@ -483,12 +488,7 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") { REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_SCANN, knowhere::feature::MV)); REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_BIN_IDMAP, knowhere::feature::MV)); REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_BIN_IVFFLAT, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); - REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::MV)); + REQUIRE_FALSE( IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_SPARSE_INVERTED_INDEX, knowhere::feature::MV)); REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_SPARSE_WAND, knowhere::feature::MV)); diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 3b62ba7ef..b09f7d2f7 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -283,6 +283,52 @@ CheckDistanceInScope(const knowhere::DataSet& result, float low_bound, float hig return true; } +inline std::unordered_map>> +GenerateScalarInfo(size_t n) { + std::vector> scalar_info; + scalar_info.reserve(2); + std::vector scalar1; + scalar1.reserve(n / 2); + std::vector scalar2; + scalar2.reserve(n - n / 2); + for (size_t i = 0; i < n; ++i) { + if (i % 2 == 0) { + scalar2.emplace_back(i); + } else { + scalar1.emplace_back(i); + } + } + scalar_info.emplace_back(std::move(scalar1)); + scalar_info.emplace_back(std::move(scalar2)); + std::unordered_map>> scalar_map; + scalar_map[0] = std::move(scalar_info); + return scalar_map; +} + +inline std::vector +GenerateBitsetByScalarInfoAndFirstTBits(const std::vector& scalar, size_t n, size_t t) { + assert(scalar.size() <= n); + assert(t >= 0 && t <= n - scalar.size()); + std::vector data((n + 8 - 1) / 8, 0); + // set bits by scalar info + for (size_t i = 0; i < scalar.size(); ++i) { + data[scalar[i] >> 3] |= (0x1 << (scalar[i] & 0x7)); + } + size_t count = 0; + for (size_t i = 0; i < n; ++i) { + if (count == t) { + break; + } + // already set, skip + if (data[i >> 3] & (0x1 << (i & 0x7))) { + continue; + } + data[i >> 3] |= (0x1 << (i & 0x7)); + ++count; + } + return data; +} + // Return a n-bits bitset data with first t bits set to true inline std::vector GenerateBitsetWithFirstTbitsSet(size_t n, size_t t) { diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index c2bf84b0c..7504114e8 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -65,6 +65,29 @@ namespace faiss { +uint32_t read_value(IOReader* f) { + uint32_t h; + READ1(h) + return h; +} + +void read_vector(std::vector& v, IOReader* f) { + READVECTOR(v); +} + +// "IHMV" is a special header for faiss hnsw to indicate whether mv or not +bool read_is_mv(IOReader* f) { + uint32_t h; + READ1(h); + return h == fourcc("IHMV"); +} + +bool read_is_mv(const char* fname) { + FileIOReader f(fname); + return read_is_mv(&f); +} + + template void read_vector(VectorT& target, IOReader* f) { // is it a mmap-enabled reader? diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index 6a73729ed..19e2746c8 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -1102,6 +1102,20 @@ void write_index(const Index* idx, const char* fname, int io_flags) { write_index(idx, &writer, io_flags); } +void write_value(uint32_t v, IOWriter* f) { + WRITE1(v); +} + +void write_vector(const std::vector& v, IOWriter* f) { + WRITEVECTOR(v); +} + +// "IHMV" is a special header for faiss hnsw to indicate whether mv or not +void write_mv(IOWriter* f) { + uint32_t h = fourcc("IHMV"); + WRITE1(h); +} + // write index for offset-only index void write_index_nm(const Index *idx, IOWriter *f) { if(const IndexIVFFlat * ivfl = diff --git a/thirdparty/faiss/faiss/index_io.h b/thirdparty/faiss/faiss/index_io.h index 8e08746ca..2f68a0bbb 100644 --- a/thirdparty/faiss/faiss/index_io.h +++ b/thirdparty/faiss/faiss/index_io.h @@ -97,6 +97,15 @@ InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0); // for backward compatibility Index *read_index_nm(IOReader *f, int io_flags = 0); void write_index_nm(const Index* idx, IOWriter* writer); + +// additional helper function for knowhere +bool read_is_mv(IOReader* reader); +bool read_is_mv(const char* fname); +void write_vector(const std::vector& v, IOWriter* writer); +void read_vector(std::vector& v, IOReader* f); +uint32_t read_value(IOReader *f); +void write_value(uint32_t v, IOWriter* writer); +void write_mv(IOWriter* writer); } // namespace faiss #endif