From 1dc4a8ad7feb06f5e6a3e7ec02dd3a6027079739 Mon Sep 17 00:00:00 2001 From: Yudong Cai Date: Tue, 19 Sep 2023 12:02:06 +0800 Subject: [PATCH] Support IVF_FLAT backward compatible when cosine Signed-off-by: Yudong Cai --- benchmark/hdf5/benchmark_knowhere.h | 8 ++++---- include/knowhere/config.h | 6 +++++- include/knowhere/utils.h | 5 ++++- src/common/utils.cc | 8 +++++++- src/index/ivf/ivf.cc | 3 ++- thirdparty/faiss/faiss/IndexIVFFlat.cpp | 2 +- .../faiss/faiss/invlists/InvertedLists.cpp | 18 +++++++++++++++++- .../faiss/faiss/invlists/InvertedLists.h | 4 +++- 8 files changed, 43 insertions(+), 11 deletions(-) diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index a6bc86128..cfc0cd4f9 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -45,7 +45,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { } void - read_index(knowhere::Index& index, const std::string& filename) { + read_index(knowhere::Index& index, const std::string& filename, const knowhere::Json& conf) { FileIOReader reader(filename); int64_t file_size = reader.size(); if (file_size < 0) { @@ -79,7 +79,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { bin->size = dim_ * nb_ * sizeof(float); binary_set.Append("RAW_DATA", bin); - index.Deserialize(binary_set); + index.Deserialize(binary_set, conf); } std::string @@ -98,7 +98,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { try { printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str()); - read_index(index_, index_file_name); + read_index(index_, index_file_name, conf); } catch (...) { printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = knowhere::GenDataSet(nb_, dim_, xb_); @@ -120,7 +120,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { try { printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str()); - read_index(golden_index_, golden_index_file_name); + read_index(golden_index_, golden_index_file_name, conf); } catch (...) { printf("[%.3f s] Building golden index on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = knowhere::GenDataSet(nb_, dim_, xb_); diff --git a/include/knowhere/config.h b/include/knowhere/config.h index f1510112d..5b6e6e1c2 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -507,7 +507,11 @@ class BaseConfig : public Config { CFG_BOOL enable_mmap; CFG_BOOL for_tuning; KNOHWERE_DECLARE_CONFIG(BaseConfig) { - KNOWHERE_CONFIG_DECLARE_FIELD(metric_type).set_default("L2").description("metric type").for_train_and_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(metric_type) + .set_default("L2") + .description("metric type") + .for_train_and_search() + .for_deserialize(); KNOWHERE_CONFIG_DECLARE_FIELD(k) .set_default(10) .description("search for top k similar vector.") diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index 43e5f6a0d..eba5437ad 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -67,6 +67,9 @@ round_down(const T value, const T align) { } extern void -ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const size_t raw_size); +ConvertIVFFlatIfNeeded(const BinarySet& binset, + const MetricType metric_type, + const uint8_t* raw_data, + const size_t raw_size); } // namespace knowhere diff --git a/src/common/utils.cc b/src/common/utils.cc index df62aa4db..d9e180dde 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -73,7 +73,10 @@ CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim) { } void -ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const size_t raw_size) { +ConvertIVFFlatIfNeeded(const BinarySet& binset, + const MetricType metric_type, + const uint8_t* raw_data, + const size_t raw_size) { std::vector names = {"IVF", // compatible with knowhere-1.x knowhere::IndexEnum::INDEX_FAISS_IVFFLAT}; auto binary = binset.GetByNames(names); @@ -92,6 +95,9 @@ ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const s faiss::read_ivf_header(ivfl.get(), &reader); ivfl->code_size = ivfl->d * sizeof(float); + // is_cosine is not defined in IVF_FLAT_NM, so mark it from config + ivfl->is_cosine = IsMetricType(metric_type, knowhere::metric::COSINE); + auto remains = binary->size - reader.tellg() - sizeof(uint32_t) - sizeof(ivfl->invlists->nlist) - sizeof(ivfl->invlists->code_size); auto invlist_size = sizeof(uint32_t) + sizeof(size_t) + ivfl->nlist * sizeof(size_t); diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 30b29ee71..cd626327d 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -692,7 +692,8 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { if constexpr (std::is_same::value) { auto raw_binary = binset.GetByName("RAW_DATA"); if (raw_binary != nullptr) { - ConvertIVFFlatIfNeeded(binset, raw_binary->data.get(), raw_binary->size); + const BaseConfig& base_cfg = static_cast(config); + ConvertIVFFlatIfNeeded(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size); // after conversion, binary size and data will be updated reader.data_ = binary->data.get(); reader.total_ = binary->size; diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.cpp b/thirdparty/faiss/faiss/IndexIVFFlat.cpp index 755db3247..a7d8552d8 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.cpp +++ b/thirdparty/faiss/faiss/IndexIVFFlat.cpp @@ -47,7 +47,7 @@ void IndexIVFFlat::restore_codes( const uint8_t* raw_data, const size_t raw_size) { auto ails = dynamic_cast(invlists); - ails->restore_codes(raw_data, raw_size); + ails->restore_codes(raw_data, raw_size, is_cosine); } void IndexIVFFlat::train(idx_t n, const float* x) { diff --git a/thirdparty/faiss/faiss/invlists/InvertedLists.cpp b/thirdparty/faiss/faiss/invlists/InvertedLists.cpp index ba3c0ff28..e3cd101a0 100644 --- a/thirdparty/faiss/faiss/invlists/InvertedLists.cpp +++ b/thirdparty/faiss/faiss/invlists/InvertedLists.cpp @@ -14,6 +14,7 @@ #include #include +#include #include //TODO: refactor to decouple dependency between CPU and Cuda, or upgrade faiss @@ -273,21 +274,36 @@ void ArrayInvertedLists::resize(size_t list_no, size_t new_size) { codes[list_no].resize(new_size * code_size); } +// temp code for IVF_FLAT_NM backward compatibility void ArrayInvertedLists::restore_codes( const uint8_t* raw_data, - const size_t raw_size) { + const size_t raw_size, + const bool is_cosine) { size_t total = 0; + with_norm = is_cosine; codes.resize(nlist); + if (is_cosine) { + code_norms.resize(nlist); + } for (size_t i = 0; i < nlist; i++) { auto list_size = ids[i].size(); total += list_size; codes[i].resize(list_size * code_size); + if (is_cosine) { + code_norms[i].resize(list_size); + } uint8_t* dst = codes[i].data(); for (size_t j = 0; j < list_size; j++) { const uint8_t* src = raw_data + code_size * ids[i][j]; std::copy_n(src, code_size, dst); dst += code_size; } + if (is_cosine) { + fvec_norms_L2(code_norms[i].data(), + (const float*)codes[i].data(), + code_size / sizeof(float), + list_size); + } } assert(total * code_size == raw_size); } diff --git a/thirdparty/faiss/faiss/invlists/InvertedLists.h b/thirdparty/faiss/faiss/invlists/InvertedLists.h index c18d7e180..54b98d628 100644 --- a/thirdparty/faiss/faiss/invlists/InvertedLists.h +++ b/thirdparty/faiss/faiss/invlists/InvertedLists.h @@ -299,7 +299,9 @@ struct ArrayInvertedLists : InvertedLists { const uint8_t* get_codes(size_t list_no) const override; const idx_t* get_ids(size_t list_no) const override; - void restore_codes(const uint8_t* raw_data, const size_t raw_size); + void restore_codes(const uint8_t* raw_data, + const size_t raw_size, + const bool is_cosine); const float* get_code_norms(size_t list_no, size_t offset) const override; void release_code_norms(size_t list_no, const float* codes) const override;