From 399ca142b06f38f183db85016341e6cfeffafe1d Mon Sep 17 00:00:00 2001 From: presburger Date: Tue, 19 Sep 2023 12:11:18 +0800 Subject: [PATCH 1/6] update cmake (#91) Signed-off-by: Yusheng.Ma --- ci/docker/builder/cpu/ubuntu20.04/Dockerfile | 4 ++-- ci/docker/builder/gpu/ubuntu20.04/Dockerfile | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/docker/builder/cpu/ubuntu20.04/Dockerfile b/ci/docker/builder/cpu/ubuntu20.04/Dockerfile index 46d70f06e..55552eb0a 100644 --- a/ci/docker/builder/cpu/ubuntu20.04/Dockerfile +++ b/ci/docker/builder/cpu/ubuntu20.04/Dockerfile @@ -1,7 +1,7 @@ FROM ubuntu:20.04 -ENV CMAKE_VERSION="v3.23" -ENV CMAKE_TAR="cmake-3.23.0-linux-x86_64.tar.gz" +ENV CMAKE_VERSION="v3.27" +ENV CMAKE_TAR="cmake-3.27.5-linux-x86_64.tar.gz" RUN apt-get update && apt-get install -y --no-install-recommends wget curl g++ gcc ca-certificates\ make ccache python3-dev gfortran python3-setuptools swig libopenblas-dev pip \ && apt-get remove --purge -y \ diff --git a/ci/docker/builder/gpu/ubuntu20.04/Dockerfile b/ci/docker/builder/gpu/ubuntu20.04/Dockerfile index 16dfee1d8..a7b35679c 100644 --- a/ci/docker/builder/gpu/ubuntu20.04/Dockerfile +++ b/ci/docker/builder/gpu/ubuntu20.04/Dockerfile @@ -1,7 +1,7 @@ FROM nvidia/cuda:11.6.0-devel-ubuntu20.04 -ENV CMAKE_VERSION="v3.23" -ENV CMAKE_TAR="cmake-3.23.1-linux-x86_64.tar.gz" +ENV CMAKE_VERSION="v3.27" +ENV CMAKE_TAR="cmake-3.27.5-linux-x86_64.tar.gz" RUN apt-get update && apt-get install -y --no-install-recommends wget curl g++ gcc ca-certificates\ make ccache python3-dev gfortran python3-setuptools swig libopenblas-dev pip \ && apt-get remove --purge -y \ From c02f58414ee29c30b187320ac8381840be7c7648 Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Tue, 19 Sep 2023 04:17:17 +0000 Subject: [PATCH 2/6] changes needed for clang-16 to be working (#72) Signed-off-by: Alexandr Guzhva --- conanfile.py | 2 +- include/knowhere/bitsetview.h | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/conanfile.py b/conanfile.py index 029c372c4..cac125c9a 100644 --- a/conanfile.py +++ b/conanfile.py @@ -81,7 +81,7 @@ def configure(self): self.options.rm_safe("fPIC") def requirements(self): - self.requires("boost/1.78.0") + self.requires("boost/1.83.0") self.requires("glog/0.4.0") self.requires("nlohmann_json/3.11.2") self.requires("openssl/1.1.1t") diff --git a/include/knowhere/bitsetview.h b/include/knowhere/bitsetview.h index 1dfa35243..40d0e6119 100644 --- a/include/knowhere/bitsetview.h +++ b/include/knowhere/bitsetview.h @@ -13,6 +13,7 @@ #define BITSET_H #include +#include #include #include From b45581635bd7c2ddf7d4bcd80292891a3278849e Mon Sep 17 00:00:00 2001 From: Patrick Weizhi Xu Date: Wed, 20 Sep 2023 11:15:17 +0800 Subject: [PATCH 3/6] Fix DiskANN LRU Set Invalid Medoid (#98) Signed-off-by: Patrick Weizhi Xu (cherry picked from commit 76723f98461265b281c69b6d010f6dc3d0d2f4b0) --- thirdparty/DiskANN/src/pq_flash_index.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/DiskANN/src/pq_flash_index.cpp b/thirdparty/DiskANN/src/pq_flash_index.cpp index 016734ee9..d40e0008e 100644 --- a/thirdparty/DiskANN/src/pq_flash_index.cpp +++ b/thirdparty/DiskANN/src/pq_flash_index.cpp @@ -1436,7 +1436,7 @@ namespace diskann { } } } - if (k_search > 0) { + if (k_search > 0 && indices[0] != -1) { lru_cache.put(vec_hash, indices[0]); } From 1dcd84b4794812690322baf3082b3a5a3855d65a Mon Sep 17 00:00:00 2001 From: foxspy Date: Wed, 20 Sep 2023 18:07:25 +0800 Subject: [PATCH 4/6] add version support (#38) Signed-off-by: xianliang --- .gitignore | 2 + benchmark/hdf5/benchmark_float_bitset.cpp | 3 +- .../hdf5/benchmark_float_range_bitset.cpp | 3 +- benchmark/hdf5/benchmark_knowhere.h | 7 +- include/knowhere/comp/index_param.h | 4 + include/knowhere/comp/knowhere_config.h | 6 ++ include/knowhere/comp/local_file_manager.h | 10 ++ include/knowhere/config.h | 16 ++++ include/knowhere/expected.h | 46 ++++++++- include/knowhere/factory.h | 6 +- include/knowhere/index.h | 10 ++ include/knowhere/utils.h | 3 + include/knowhere/version.h | 96 +++++++++++++++++++ python/knowhere/__init__.py | 4 +- python/knowhere/knowhere.i | 10 +- src/common/comp/knowhere_config.cc | 11 +++ src/common/config.cc | 1 + src/common/factory.cc | 9 +- src/common/utils.cc | 5 + src/index/diskann/diskann.cc | 36 ++++--- src/index/diskann/diskann_config.h | 9 -- src/index/flat/flat.cc | 15 +-- src/index/gpu/flat_gpu/flat_gpu.cc | 6 +- src/index/gpu/ivf_gpu/ivf_gpu.cc | 14 +-- src/index/hnsw/hnsw.cc | 6 +- src/index/ivf/ivf.cc | 53 +++++----- tests/python/test_diskann.py | 3 +- tests/python/test_index_load_and_save.py | 5 +- tests/python/test_index_random.py | 3 +- tests/python/test_index_with_random.py | 3 +- tests/python/test_index_with_sift.py | 3 +- tests/ut/test_diskann.cc | 33 ++++--- tests/ut/test_feder.cc | 6 +- tests/ut/test_get_vector.cc | 22 +++-- tests/ut/test_iterator.cc | 11 ++- tests/ut/test_ivfflat_cc.cc | 10 +- tests/ut/test_mmap.cc | 15 +-- tests/ut/test_search.cc | 25 ++--- tests/ut/test_simd.cc | 3 +- tests/ut/test_utils.cc | 18 ++++ tests/ut/utils.h | 9 ++ 41 files changed, 422 insertions(+), 138 deletions(-) create mode 100644 include/knowhere/version.h diff --git a/.gitignore b/.gitignore index 5b5a5e347..adf2c2de5 100644 --- a/.gitignore +++ b/.gitignore @@ -60,6 +60,8 @@ venv/ **/knowhere/swigknowhere.py wheelhouse/* +**/thirdparty/cardinal + *.bin diff --git a/benchmark/hdf5/benchmark_float_bitset.cpp b/benchmark/hdf5/benchmark_float_bitset.cpp index 66388aed8..9f322ed33 100644 --- a/benchmark/hdf5/benchmark_float_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_bitset.cpp @@ -235,7 +235,8 @@ TEST_F(Benchmark_float_bitset, TEST_DISKANN) { std::shared_ptr file_manager = std::make_shared(); auto diskann_index_pack = knowhere::Pack(file_manager); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, diskann_index_pack); + auto version = knowhere::Version::GetCurrentVersion().VersionCode(); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; index_.Build(*ds_ptr, conf); diff --git a/benchmark/hdf5/benchmark_float_range_bitset.cpp b/benchmark/hdf5/benchmark_float_range_bitset.cpp index ccad0d183..498a95383 100644 --- a/benchmark/hdf5/benchmark_float_range_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_range_bitset.cpp @@ -236,7 +236,8 @@ TEST_F(Benchmark_float_range_bitset, TEST_DISKANN) { std::shared_ptr file_manager = std::make_shared(); auto diskann_index_pack = knowhere::Pack(file_manager); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, diskann_index_pack); + auto version = knowhere::Version::GetCurrentVersion().VersionCode(); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; index_.Build(*ds_ptr, conf); diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index a6bc86128..1756147a7 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -20,6 +20,7 @@ #include "knowhere/config.h" #include "knowhere/factory.h" #include "knowhere/index.h" +#include "knowhere/version.h" class Benchmark_knowhere : public Benchmark_hdf5 { public: @@ -93,8 +94,9 @@ class Benchmark_knowhere : public Benchmark_hdf5 { knowhere::Index create_index(const std::string& index_file_name, const knowhere::Json& conf) { + auto version = knowhere::Version::GetCurrentVersion().VersionCode(); printf("[%.3f s] Creating index \"%s\"\n", get_time_diff(), index_type_.c_str()); - index_ = knowhere::IndexFactory::Instance().Create(index_type_); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version); try { printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str()); @@ -112,11 +114,12 @@ class Benchmark_knowhere : public Benchmark_hdf5 { knowhere::Index create_golden_index(const knowhere::Json& conf) { + auto version = knowhere::Version::GetCurrentVersion().VersionCode(); golden_index_type_ = knowhere::IndexEnum::INDEX_FAISS_IDMAP; std::string golden_index_file_name = ann_test_name_ + "_" + golden_index_type_ + "_GOLDEN" + ".index"; printf("[%.3f s] Creating golden index \"%s\"\n", get_time_diff(), golden_index_type_.c_str()); - golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_); + golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version); try { printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str()); diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index ce9b2ba7f..ca1275f01 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -48,6 +48,10 @@ constexpr const char* INDEX_DISKANN = "DISKANN"; namespace meta { constexpr const char* INDEX_TYPE = "index_type"; constexpr const char* METRIC_TYPE = "metric_type"; +constexpr const char* DATA_PATH = "data_path"; +constexpr const char* INDEX_PREFIX = "index_prefix"; +constexpr const char* INDEX_ENGINE_VERSION = "index_engine_version"; +constexpr const char* RETRIEVE_FRIENDLY = "retrieve_friendly"; constexpr const char* DIM = "dim"; constexpr const char* TENSOR = "tensor"; constexpr const char* ROWS = "rows"; diff --git a/include/knowhere/comp/knowhere_config.h b/include/knowhere/comp/knowhere_config.h index 6ea15844e..de845b6cf 100644 --- a/include/knowhere/comp/knowhere_config.h +++ b/include/knowhere/comp/knowhere_config.h @@ -84,6 +84,12 @@ class KnowhereConfig { static bool SetAioContextPool(size_t num_ctx); + static void + SetBuildThreadPoolSize(size_t num_threads); + + static void + SetSearchThreadPoolSize(size_t num_threads); + /** * init GPU Resource */ diff --git a/include/knowhere/comp/local_file_manager.h b/include/knowhere/comp/local_file_manager.h index 85bf35583..6c301799f 100644 --- a/include/knowhere/comp/local_file_manager.h +++ b/include/knowhere/comp/local_file_manager.h @@ -13,6 +13,16 @@ #include +#if __has_include() +#include +namespace fs = std::filesystem; +#elif __has_include() +#include +namespace fs = std::experimental::filesystem; +#else +error "Missing the header." +#endif + #include "knowhere/file_manager.h" namespace knowhere { /** diff --git a/include/knowhere/config.h b/include/knowhere/config.h index f1510112d..ca22fb021 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -501,6 +501,9 @@ class BaseConfig : public Config { CFG_STRING metric_type; CFG_INT k; CFG_INT num_build_thread; + CFG_BOOL retrieve_friendly; + CFG_STRING data_path; + CFG_STRING index_prefix; CFG_FLOAT radius; CFG_FLOAT range_filter; CFG_BOOL trace_visit; @@ -508,6 +511,19 @@ class BaseConfig : public Config { CFG_BOOL for_tuning; KNOHWERE_DECLARE_CONFIG(BaseConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(metric_type).set_default("L2").description("metric type").for_train_and_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(retrieve_friendly) + .description("whether the index holds raw data for fast retrieval") + .set_default(false) + .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(data_path) + .description("raw data path.") + .allow_empty_without_default() + .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(index_prefix) + .description("path prefix to load or save index.") + .allow_empty_without_default() + .for_train() + .for_deserialize(); KNOWHERE_CONFIG_DECLARE_FIELD(k) .set_default(10) .description("search for top k similar vector.") diff --git a/include/knowhere/expected.h b/include/knowhere/expected.h index 430a69dc6..6109818f9 100644 --- a/include/knowhere/expected.h +++ b/include/knowhere/expected.h @@ -34,13 +34,57 @@ enum class Status { hnsw_inner_error = 12, malloc_error = 13, diskann_inner_error = 14, - diskann_file_error = 15, + disk_file_error = 15, invalid_value_in_json = 16, arithmetic_overflow = 17, raft_inner_error = 18, invalid_binary_set = 19, }; +inline std::string +Status2String(knowhere::Status status) { + switch (status) { + case knowhere::Status::invalid_args: + return "invalid args"; + case knowhere::Status::invalid_param_in_json: + return "invalid param in json"; + case knowhere::Status::out_of_range_in_json: + return "out of range in json"; + case knowhere::Status::type_conflict_in_json: + return "type conflict in json"; + case knowhere::Status::invalid_metric_type: + return "invalid metric type"; + case knowhere::Status::empty_index: + return "empty index"; + case knowhere::Status::not_implemented: + return "not implemented"; + case knowhere::Status::index_not_trained: + return "index not trained"; + case knowhere::Status::index_already_trained: + return "index already trained"; + case knowhere::Status::faiss_inner_error: + return "faiss inner error"; + case knowhere::Status::hnsw_inner_error: + return "hnsw inner error"; + case knowhere::Status::malloc_error: + return "malloc error"; + case knowhere::Status::diskann_inner_error: + return "diskann inner error"; + case knowhere::Status::disk_file_error: + return "disk file error"; + case knowhere::Status::invalid_value_in_json: + return "invalid value in json"; + case knowhere::Status::arithmetic_overflow: + return "arithmetic overflow"; + case knowhere::Status::raft_inner_error: + return "raft inner error"; + case knowhere::Status::invalid_binary_set: + return "invalid binary set"; + default: + return "unexpected status"; + } +} + template class expected { public: diff --git a/include/knowhere/factory.h b/include/knowhere/factory.h index f53c7ffc4..aed461f78 100644 --- a/include/knowhere/factory.h +++ b/include/knowhere/factory.h @@ -22,14 +22,14 @@ namespace knowhere { class IndexFactory { public: Index - Create(const std::string& name, const Object& object = nullptr); + Create(const std::string& name, const std::string& version, const Object& object = nullptr); const IndexFactory& - Register(const std::string& name, std::function(const Object&)> func); + Register(const std::string& name, std::function(const std::string& version, const Object&)> func); static IndexFactory& Instance(); private: - typedef std::map(const Object&)>> FuncMap; + typedef std::map(const std::string&, const Object&)>> FuncMap; IndexFactory(); static FuncMap& MapInstance(); diff --git a/include/knowhere/index.h b/include/knowhere/index.h index 066e6e845..164eea88c 100644 --- a/include/knowhere/index.h +++ b/include/knowhere/index.h @@ -107,6 +107,16 @@ class Index { return *this; } + T1* + Node() { + return node; + } + + const T1* + Node() const { + return node; + } + template Index Cast() { diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index 43e5f6a0d..78ff18a1a 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -69,4 +69,7 @@ round_down(const T value, const T align) { extern void ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const size_t raw_size); +bool +UseDiskLoad(const std::string& index_type, const std::string& /*version*/); + } // namespace knowhere diff --git a/include/knowhere/version.h b/include/knowhere/version.h new file mode 100644 index 000000000..006cbde9d --- /dev/null +++ b/include/knowhere/version.h @@ -0,0 +1,96 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "log.h" + +namespace knowhere { +namespace { +static const std::regex version_regex(R"(^knowhere-v(\d+)$)"); +static constexpr const char* default_version = "knowhere-v0"; +static constexpr const char* minimal_vesion = "knowhere-v0"; +static constexpr const char* current_version = "knowhere-v0"; +} // namespace + +class Version { + public: + explicit Version(const std::string& version_code_) : version_code(version_code_) { + try { + std::smatch matches; + if (std::regex_match(version_code_, matches, version_regex)) { + version_ = std::stoi(matches[1]); + } else { + LOG_KNOWHERE_ERROR_ << "unexpected version code : " << version_code_; + } + } catch (std::exception& e) { + LOG_KNOWHERE_ERROR_ << "version code " << version_code_ << " parse failed : " << e.what(); + } + } + + bool + Valid() { + return version_ != unexpected_version_num; + }; + + const std::string& + VersionCode() const { + return version_code; + } + + static bool + VersionCheck(const std::string& version) { + try { + return std::regex_match(version.c_str(), version_regex); + } catch (std::regex_error& e) { + LOG_KNOWHERE_ERROR_ << "unexpected index version : " << version; + } + return false; + } + + // used when version is not set + static inline Version + GetDefaultVersion() { + return Version(default_version); + } + + // the current version (newest version support) + static inline Version + GetCurrentVersion() { + return Version(current_version); + } + + // the minimal version (oldest version support) + static inline Version + GetMinimalSupport() { + return Version(minimal_vesion); + } + + static inline bool + VersionSupport(const Version& version) { + return VersionCheck(version.version_code) && GetMinimalSupport() <= version && version <= GetCurrentVersion(); + } + + friend bool + operator<=(const Version& lhs, const Version& rhs) { + return lhs.version_ <= rhs.version_; + } + + private: + static constexpr int32_t unexpected_version_num = -1; + const std::string version_code; + int32_t version_ = unexpected_version_num; +}; + +} // namespace knowhere diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index 3dc156ba0..c27418b5b 100644 --- a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -3,8 +3,8 @@ from .swigknowhere import GetBinarySet, GetNullDataSet, GetNullBitSetView import numpy as np -def CreateIndex(name): - return swigknowhere.IndexWrap(name) +def CreateIndex(name, version): + return swigknowhere.IndexWrap(name, version) def CreateBitSet(bits_num): diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index f80592c7d..b20b03193 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -28,6 +28,8 @@ typedef uint64_t size_t; #endif #include #include +#include +#include #include using namespace knowhere; %} @@ -108,14 +110,14 @@ public: class IndexWrap { public: - IndexWrap(const std::string& name) { + IndexWrap(const std::string& name, const std::string& version) { GILReleaser rel; - if (name == std::string("DISKANN")) { + if (knowhere::UseDiskLoad(name, version)) { std::shared_ptr file_manager = std::make_shared(); auto diskann_pack = knowhere::Pack(file_manager); - idx = IndexFactory::Instance().Create(name, diskann_pack); + idx = IndexFactory::Instance().Create(name, version, diskann_pack); } else { - idx = IndexFactory::Instance().Create(name); + idx = IndexFactory::Instance().Create(name, version); } } diff --git a/src/common/comp/knowhere_config.cc b/src/common/comp/knowhere_config.cc index b04917a8b..126529d21 100644 --- a/src/common/comp/knowhere_config.cc +++ b/src/common/comp/knowhere_config.cc @@ -18,6 +18,7 @@ #endif #include "faiss/Clustering.h" #include "faiss/utils/distances.h" +#include "knowhere/comp/thread_pool.h" #include "knowhere/log.h" #ifdef KNOWHERE_WITH_GPU #include "index/gpu/gpu_res_mgr.h" @@ -132,6 +133,16 @@ KnowhereConfig::SetAioContextPool(size_t num_ctx) { return true; } +void +KnowhereConfig::SetBuildThreadPoolSize(size_t num_threads) { + knowhere::ThreadPool::InitGlobalBuildThreadPool(num_threads); +} + +void +KnowhereConfig::SetSearchThreadPoolSize(size_t num_threads) { + knowhere::ThreadPool::InitGlobalSearchThreadPool(num_threads); +} + void KnowhereConfig::InitGPUResource(int64_t gpu_id, int64_t res_num) { #ifdef KNOWHERE_WITH_GPU diff --git a/src/common/config.cc b/src/common/config.cc index a25664ba7..f54c705a3 100644 --- a/src/common/config.cc +++ b/src/common/config.cc @@ -48,6 +48,7 @@ static const std::unordered_set ext_legal_json_keys = {"metric_type "round_decimal", "offset", "for_tuning", + "index_engine_version", "reorder_k"}; Status diff --git a/src/common/factory.cc b/src/common/factory.cc index 36cf8efba..caad57498 100644 --- a/src/common/factory.cc +++ b/src/common/factory.cc @@ -14,15 +14,16 @@ namespace knowhere { Index -IndexFactory::Create(const std::string& name, const Object& object) { +IndexFactory::Create(const std::string& name, const std::string& version, const Object& object) { auto& func_mapping_ = MapInstance(); assert(func_mapping_.find(name) != func_mapping_.end()); - LOG_KNOWHERE_INFO_ << "create knowhere index " << name; - return func_mapping_[name](object); + LOG_KNOWHERE_INFO_ << "create knowhere index " << name << " with version " << version; + return func_mapping_[name](version, object); } const IndexFactory& -IndexFactory::Register(const std::string& name, std::function(const Object&)> func) { +IndexFactory::Register(const std::string& name, + std::function(const std::string& version, const Object&)> func) { auto& func_mapping_ = MapInstance(); func_mapping_[name] = func; return *this; diff --git a/src/common/utils.cc b/src/common/utils.cc index df62aa4db..2c5ba256f 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -119,4 +119,9 @@ ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const s } } +bool +UseDiskLoad(const std::string& index_type, const std::string& /*version*/) { + return !index_type.compare(IndexEnum::INDEX_DISKANN); +} + } // namespace knowhere diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index 73b60f8c8..d35db96f8 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -40,7 +40,7 @@ class DiskANNIndexNode : public IndexNode { static_assert(std::is_same_v, "DiskANN only support float"); public: - DiskANNIndexNode(const Object& object) : is_prepared_(false), dim_(-1), count_(-1) { + DiskANNIndexNode(const std::string& version, const Object& object) : is_prepared_(false), dim_(-1), count_(-1) { assert(typeid(object) == typeid(Pack>)); auto diskann_index_pack = dynamic_cast>*>(&object); assert(diskann_index_pack != nullptr); @@ -79,8 +79,8 @@ class DiskANNIndexNode : public IndexNode { Status Serialize(BinarySet& binset) const override { - LOG_KNOWHERE_ERROR_ << "DiskANN doesn't support Serialize."; - return Status::not_implemented; + LOG_KNOWHERE_INFO_ << "DiskANN does nothing for serialize"; + return Status::success; } Status @@ -182,7 +182,7 @@ TryDiskANNCall(std::function&& diskann_call) { return Status::success; } catch (const diskann::FileException& e) { LOG_KNOWHERE_ERROR_ << "DiskANN File Exception: " << e.what(); - return Status::diskann_file_error; + return Status::disk_file_error; } catch (const diskann::ANNException& e) { LOG_KNOWHERE_ERROR_ << "DiskANN Exception: " << e.what(); return Status::diskann_inner_error; @@ -263,13 +263,17 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << build_conf.metric_type.value(); return Status::invalid_metric_type; } + if (!(build_conf.index_prefix.has_value() && build_conf.data_path.has_value())) { + LOG_KNOWHERE_ERROR_ << "DiskANN file path for build is empty." << std::endl; + return Status::invalid_param_in_json; + } if (AnyIndexFileExist(build_conf.index_prefix.value())) { LOG_KNOWHERE_ERROR_ << "This index prefix already has index files." << std::endl; - return Status::diskann_file_error; + return Status::disk_file_error; } if (!LoadFile(build_conf.data_path.value())) { LOG_KNOWHERE_ERROR_ << "Failed load the raw data before building." << std::endl; - return Status::diskann_file_error; + return Status::disk_file_error; } auto& data_path = build_conf.data_path.value(); index_prefix_ = build_conf.index_prefix.value(); @@ -315,13 +319,13 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { for (auto& filename : GetNecessaryFilenames(index_prefix_, need_norm, true, true)) { if (!AddFile(filename)) { LOG_KNOWHERE_ERROR_ << "Failed to add file " << filename << "."; - return Status::diskann_file_error; + return Status::disk_file_error; } } for (auto& filename : GetOptionalFilenames(index_prefix_)) { if (file_exists(filename) && !AddFile(filename)) { LOG_KNOWHERE_ERROR_ << "Failed to add file " << filename << "."; - return Status::diskann_file_error; + return Status::disk_file_error; } } @@ -340,6 +344,10 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { if (is_prepared_.load()) { return Status::success; } + if (!(prep_conf.index_prefix.has_value())) { + LOG_KNOWHERE_ERROR_ << "DiskANN file path for deserialize is empty." << std::endl; + return Status::invalid_param_in_json; + } index_prefix_ = prep_conf.index_prefix.value(); bool is_ip = IsMetricType(prep_conf.metric_type.value(), knowhere::metric::IP); bool need_norm = IsMetricType(prep_conf.metric_type.value(), knowhere::metric::IP) || @@ -359,17 +367,17 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { index_prefix_, need_norm, prep_conf.search_cache_budget_gb.value() > 0 && !prep_conf.use_bfs_cache.value(), prep_conf.warm_up.value())) { if (!LoadFile(filename)) { - return Status::diskann_file_error; + return Status::disk_file_error; } } for (auto& filename : GetOptionalFilenames(index_prefix_)) { auto is_exist_op = file_manager_->IsExisted(filename); if (!is_exist_op.has_value()) { LOG_KNOWHERE_ERROR_ << "Failed to check existence of file " << filename << "."; - return Status::diskann_file_error; + return Status::disk_file_error; } if (is_exist_op.value() && !LoadFile(filename)) { - return Status::diskann_file_error; + return Status::disk_file_error; } } @@ -464,7 +472,7 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); }) != Status::success) { LOG_KNOWHERE_ERROR_ << "Failed to load warmup file for DiskANN."; - return Status::diskann_file_error; + return Status::disk_file_error; } std::vector warmup_result_ids_64(warmup_num, 0); std::vector warmup_result_dists(warmup_num, 0); @@ -704,5 +712,7 @@ DiskANNIndexNode::GetCachedNodeNum(const float cache_dram_budget, const uint6 return num_nodes_to_cache; } -KNOWHERE_REGISTER_GLOBAL(DISKANN, [](const Object& object) { return Index>::Create(object); }); +KNOWHERE_REGISTER_GLOBAL(DISKANN, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); +}); } // namespace knowhere diff --git a/src/index/diskann/diskann_config.h b/src/index/diskann/diskann_config.h index e25dedb68..94638ce1b 100644 --- a/src/index/diskann/diskann_config.h +++ b/src/index/diskann/diskann_config.h @@ -25,10 +25,6 @@ constexpr const CFG_INT::value_type kDefaultSearchListSizeForBuild = 128; class DiskANNConfig : public BaseConfig { public: - // Path prefix to load or save DiskANN - CFG_STRING index_prefix; - // The path to the raw data file. Raw data's format should be [row_num(4 bytes) | dim_num(4 bytes) | vectors]. - CFG_STRING data_path; // This is the degree of the graph index, typically between 60 and 150. Larger R will result in larger indices and // longer indexing times, but better search quality. CFG_INT max_degree; @@ -84,11 +80,6 @@ class DiskANNConfig : public BaseConfig { .description("metric type") .for_train_and_search() .for_deserialize(); - KNOWHERE_CONFIG_DECLARE_FIELD(index_prefix) - .description("path to load or save Diskann.") - .for_train() - .for_deserialize(); - KNOWHERE_CONFIG_DECLARE_FIELD(data_path).description("raw data path.").for_train(); KNOWHERE_CONFIG_DECLARE_FIELD(max_degree) .description("the degree of the graph index.") .set_default(48) diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index 4dc62f117..5aaa2aeaf 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -26,7 +26,7 @@ namespace knowhere { template class FlatIndexNode : public IndexNode { public: - FlatIndexNode(const Object&) : index_(nullptr) { + FlatIndexNode(const std::string& version, const Object& object) : index_(nullptr) { static_assert(std::is_same::value || std::is_same::value, "not support"); search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); @@ -349,13 +349,14 @@ class FlatIndexNode : public IndexNode { std::shared_ptr search_pool_; }; -KNOWHERE_REGISTER_GLOBAL(FLAT, - [](const Object& object) { return Index>::Create(object); }); -KNOWHERE_REGISTER_GLOBAL(BINFLAT, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(FLAT, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(BIN_FLAT, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(BINFLAT, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); +}); +KNOWHERE_REGISTER_GLOBAL(BIN_FLAT, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); } // namespace knowhere diff --git a/src/index/gpu/flat_gpu/flat_gpu.cc b/src/index/gpu/flat_gpu/flat_gpu.cc index 5f4ae930f..f0f4d0bd4 100644 --- a/src/index/gpu/flat_gpu/flat_gpu.cc +++ b/src/index/gpu/flat_gpu/flat_gpu.cc @@ -23,7 +23,7 @@ namespace knowhere { class GpuFlatIndexNode : public IndexNode { public: - GpuFlatIndexNode(const Object& object) : index_(nullptr) { + GpuFlatIndexNode(const std::string& version, const Object& object) : index_(nullptr) { } Status @@ -189,6 +189,8 @@ class GpuFlatIndexNode : public IndexNode { std::unique_ptr index_; }; -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_FLAT, [](const Object& object) { return Index::Create(object); }); +KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_FLAT, [](const std::string& version, const Object& object) { + return Index::Create(version, object); +}); } // namespace knowhere diff --git a/src/index/gpu/ivf_gpu/ivf_gpu.cc b/src/index/gpu/ivf_gpu/ivf_gpu.cc index c9e51bffc..53c6422bb 100644 --- a/src/index/gpu/ivf_gpu/ivf_gpu.cc +++ b/src/index/gpu/ivf_gpu/ivf_gpu.cc @@ -49,7 +49,7 @@ struct KnowhereConfigType { template class GpuIvfIndexNode : public IndexNode { public: - GpuIvfIndexNode(const Object& object) : index_(nullptr) { + GpuIvfIndexNode(const std::string& version, const Object& object) : index_(nullptr) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value); } @@ -273,14 +273,14 @@ class GpuIvfIndexNode : public IndexNode { std::unique_ptr index_; }; -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_FLAT, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_FLAT, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_PQ, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_PQ, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_SQ8, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_SQ8, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); } // namespace knowhere diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index a41da1535..d03494d1d 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -32,7 +32,7 @@ namespace knowhere { class HnswIndexNode : public IndexNode { public: - HnswIndexNode(const Object& object) : index_(nullptr) { + HnswIndexNode(const std::string& /*version*/, const Object& object) : index_(nullptr) { search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); } @@ -517,6 +517,8 @@ class HnswIndexNode : public IndexNode { std::shared_ptr search_pool_; }; -KNOWHERE_REGISTER_GLOBAL(HNSW, [](const Object& object) { return Index::Create(object); }); +KNOWHERE_REGISTER_GLOBAL(HNSW, [](const std::string& version, const Object& object) { + return Index::Create(version, object); +}); } // namespace knowhere diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 30b29ee71..f176b0dae 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -45,7 +45,7 @@ struct QuantizerT { template class IvfIndexNode : public IndexNode { public: - IvfIndexNode(const Object& object) : index_(nullptr) { + IvfIndexNode(const std::string& /*version*/, const Object& object) : index_(nullptr) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || @@ -732,36 +732,41 @@ IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& return Status::success; } -KNOWHERE_REGISTER_GLOBAL(IVFBIN, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(IVFBIN, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(BIN_IVF_FLAT, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(BIN_IVF_FLAT, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(IVFFLAT, - [](const Object& object) { return Index>::Create(object); }); -KNOWHERE_REGISTER_GLOBAL(IVF_FLAT, - [](const Object& object) { return Index>::Create(object); }); -KNOWHERE_REGISTER_GLOBAL(IVFFLATCC, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(IVFFLAT, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(IVF_FLAT_CC, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(IVF_FLAT, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(SCANN, - [](const Object& object) { return Index>::Create(object); }); -KNOWHERE_REGISTER_GLOBAL(IVFPQ, - [](const Object& object) { return Index>::Create(object); }); -KNOWHERE_REGISTER_GLOBAL(IVF_PQ, - [](const Object& object) { return Index>::Create(object); }); - -KNOWHERE_REGISTER_GLOBAL(IVFSQ, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(IVFFLATCC, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); -KNOWHERE_REGISTER_GLOBAL(IVF_SQ8, [](const Object& object) { - return Index>::Create(object); +KNOWHERE_REGISTER_GLOBAL(IVF_FLAT_CC, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); +}); +KNOWHERE_REGISTER_GLOBAL(SCANN, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); +}); +KNOWHERE_REGISTER_GLOBAL(IVFPQ, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); +}); +KNOWHERE_REGISTER_GLOBAL(IVF_PQ, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); +}); + +KNOWHERE_REGISTER_GLOBAL(IVFSQ, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); +}); +KNOWHERE_REGISTER_GLOBAL(IVF_SQ8, [](const std::string& version, const Object& object) { + return Index>::Create(version, object); }); } // namespace knowhere diff --git a/tests/python/test_diskann.py b/tests/python/test_diskann.py index 3db4d3a57..1f09aa8fc 100644 --- a/tests/python/test_diskann.py +++ b/tests/python/test_diskann.py @@ -15,6 +15,7 @@ def fbin_write(x, fname): x.tofile(f) def test_index(gen_data, faiss_ans, recall, error): + version = "knowhere-v2.2.0" index_name = "DISKANN" diskann_dir = "diskann_test" data_path = os.path.join(diskann_dir, "diskann_data") @@ -62,7 +63,7 @@ def test_index(gen_data, faiss_ans, recall, error): } print(index_name, diskann_config["build_config"]) - diskann = knowhere.CreateIndex(index_name) + diskann = knowhere.CreateIndex(index_name, version) build_status = diskann.Build( knowhere.GetNullDataSet(), json.dumps(diskann_config["build_config"]), diff --git a/tests/python/test_index_load_and_save.py b/tests/python/test_index_load_and_save.py index fdf991c94..7b76ab513 100644 --- a/tests/python/test_index_load_and_save.py +++ b/tests/python/test_index_load_and_save.py @@ -16,7 +16,8 @@ def test_save_and_load(gen_data, faiss_ans, recall, error, name, config): # simple load and save not work for ivf nm print(name, config) - build_idx = knowhere.CreateIndex(name) + version = "knowhere-v2.2.0" + build_idx = knowhere.CreateIndex(name, version) xb, xq = gen_data(10000, 100, 256) build_idx.Build( @@ -25,7 +26,7 @@ def test_save_and_load(gen_data, faiss_ans, recall, error, name, config): ) binset = knowhere.GetBinarySet() build_idx.Serialize(binset) - search_idx = knowhere.CreateIndex(name) + search_idx = knowhere.CreateIndex(name, version) search_idx.Deserialize(binset) ans, _ = search_idx.Search( knowhere.ArrayToDataSet(xq), diff --git a/tests/python/test_index_random.py b/tests/python/test_index_random.py index f833a2fb3..077fe1854 100644 --- a/tests/python/test_index_random.py +++ b/tests/python/test_index_random.py @@ -62,7 +62,8 @@ @pytest.mark.parametrize("name,config", test_data) def test_index(gen_data, faiss_ans, recall, error, name, config): print(name, config) - idx = knowhere.CreateIndex(name) + version = "knowhere-v2.2.0" + idx = knowhere.CreateIndex(name, version) xb, xq = gen_data(10000, 100, 256) idx.Build( diff --git a/tests/python/test_index_with_random.py b/tests/python/test_index_with_random.py index b0569dc13..eeaa7f0e3 100644 --- a/tests/python/test_index_with_random.py +++ b/tests/python/test_index_with_random.py @@ -84,8 +84,9 @@ @pytest.mark.parametrize("name,config", test_data) def test_index(gen_data, faiss_ans, recall, error, name, config): + version = "knowhere-v2.2.0" print(name, config) - idx = knowhere.CreateIndex(name) + idx = knowhere.CreateIndex(name, version) xb, xq = gen_data(10000, 100, 256) idx.Build( diff --git a/tests/python/test_index_with_sift.py b/tests/python/test_index_with_sift.py index d47e3183f..9018c4e62 100644 --- a/tests/python/test_index_with_sift.py +++ b/tests/python/test_index_with_sift.py @@ -110,12 +110,13 @@ def download_sift(): @pytest.mark.parametrize("name,config", test_data) def test_index_with_sift(recall, name, config): + version = "knowhere-v2.2.0" download_sift() xb = fvecs_read("/tmp/sift/sift_base.fvecs") xq = fvecs_read("/tmp/sift/sift_query.fvecs") ids_true = ivecs_read("/tmp/sift/sift_groundtruth.ivecs") - idx = knowhere.CreateIndex(name) + idx = knowhere.CreateIndex(name, version) idx.Build( knowhere.ArrayToDataSet(xb), json.dumps(config), diff --git a/tests/ut/test_diskann.cc b/tests/ut/test_diskann.cc index 636fca22f..891299697 100644 --- a/tests/ut/test_diskann.cc +++ b/tests/ut/test_diskann.cc @@ -70,6 +70,7 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { REQUIRE_NOTHROW(fs::create_directory(kL2IndexDir)); REQUIRE_NOTHROW(fs::create_directory(kIPIndexDir)); int rows_num = 10; + auto version = GenTestVersionList(); auto test_gen = [rows_num]() { knowhere::Json json; json["dim"] = kDim; @@ -96,7 +97,7 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { // build process SECTION("Invalid build params test") { knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); knowhere::Json test_json; knowhere::Status test_stat; // invalid metric type @@ -108,14 +109,16 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { test_json = test_gen(); test_json["data_path"] = kL2IndexPrefix + ".temp"; test_stat = diskann.Build(*ds_ptr, test_json); - REQUIRE(test_stat == knowhere::Status::diskann_file_error); + REQUIRE(test_stat == knowhere::Status::disk_file_error); } SECTION("Invalid search params test") { knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", diskann_index_pack); + auto binarySet = knowhere::BinarySet(); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann.Build(*ds_ptr, test_gen()); - diskann.Deserialize(knowhere::BinarySet(), test_gen()); + diskann.Serialize(binarySet); + diskann.Deserialize(binarySet, test_gen()); knowhere::Json test_json; auto query_ds = GenDataSet(kNumQueries, kDim, 42); @@ -151,6 +154,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { REQUIRE_NOTHROW(fs::create_directory(kCOSINEIndexDir)); auto metric_str = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::IP, knowhere::metric::COSINE); + auto version = GenTestVersionList(); std::unordered_map metric_dir_map = { {knowhere::metric::L2, kL2IndexPrefix}, @@ -243,14 +247,15 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { // build process { knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto build_json = build_gen().dump(); knowhere::Json json = knowhere::Json::parse(build_json); diskann.Build(*ds_ptr, json); + diskann.Serialize(binset); } { // knn search - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann.Deserialize(binset, deserialize_json); auto knn_search_json = knn_search_gen().dump(); @@ -264,15 +269,16 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { { std::string cached_nodes_file_path = std::string(build_gen()["index_prefix"]) + std::string("_cached_nodes.bin"); - REQUIRE(fs::exists(cached_nodes_file_path)); - fs::remove(cached_nodes_file_path); - auto diskann_tmp = knowhere::IndexFactory::Instance().Create("DISKANN", diskann_index_pack); + if (fs::exists(cached_nodes_file_path)) { + fs::remove(cached_nodes_file_path); + } + auto diskann_tmp = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann_tmp.Deserialize(binset, deserialize_json); auto knn_search_json = knn_search_gen().dump(); knowhere::Json knn_json = knowhere::Json::parse(knn_search_json); auto res = diskann_tmp.Search(*query_ds, knn_json, nullptr); REQUIRE(res.has_value()); - REQUIRE(GetKNNRecall(*knn_gt_ptr, *res.value()) == knn_recall); + REQUIRE(GetKNNRecall(*knn_gt_ptr, *res.value()) >= kKnnRecall); } // knn search with bitset @@ -314,6 +320,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { // This test case only check L2 TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { + auto version = GenTestVersionList(); for (const uint32_t dim : {kDim, kLargeDim}) { fs::remove_all(kDir); fs::remove(kDir); @@ -321,6 +328,7 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { auto base_gen = [&] { knowhere::Json json; + json[knowhere::meta::RETRIEVE_FRIENDLY] = true; json["dim"] = dim; json["metric_type"] = knowhere::metric::L2; json["k"] = kK; @@ -347,11 +355,12 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { auto diskann_index_pack = knowhere::Pack(file_manager); knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto build_json = build_gen().dump(); knowhere::Json json = knowhere::Json::parse(build_json); diskann.Build(*ds_ptr, json); knowhere::BinarySet binset; + diskann.Serialize(binset); { std::vector cache_sizes = {0, 1.0f * sizeof(float) * dim * kNumRows * 0.125 / (1024 * 1024 * 1024)}; for (const auto cache_size : cache_sizes) { @@ -362,7 +371,7 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { return json; }; knowhere::Json deserialize_json = knowhere::Json::parse(deserialize_gen().dump()); - auto index = knowhere::IndexFactory::Instance().Create("DISKANN", diskann_index_pack); + auto index = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto ret = index.Deserialize(binset, deserialize_json); REQUIRE(ret == knowhere::Status::success); std::vector ids_sizes = {1, kNumRows * 0.2, kNumRows * 0.7, kNumRows}; diff --git a/tests/ut/test_feder.cc b/tests/ut/test_feder.cc index 87bb4e0b4..905597edc 100644 --- a/tests/ut/test_feder.cc +++ b/tests/ut/test_feder.cc @@ -142,6 +142,8 @@ TEST_CASE("Test Feder", "[feder]") { int64_t dim = 128; int64_t seed = 42; + auto version = GenTestVersionList(); + auto base_gen = [&]() { knowhere::Json json; json[knowhere::meta::DIM] = dim; @@ -175,7 +177,7 @@ TEST_CASE("Test Feder", "[feder]") { SECTION("Test HNSW Feder") { auto name = knowhere::IndexEnum::INDEX_HNSW; - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); REQUIRE(idx.Type() == name); auto json = hnsw_gen(); @@ -199,7 +201,7 @@ TEST_CASE("Test Feder", "[feder]") { SECTION("Test IVF_FLAT Feder") { auto name = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); REQUIRE(idx.Type() == name); auto json = ivfflat_gen(); diff --git a/tests/ut/test_get_vector.cc b/tests/ut/test_get_vector.cc index aca22189b..dc974ab44 100644 --- a/tests/ut/test_get_vector.cc +++ b/tests/ut/test_get_vector.cc @@ -25,6 +25,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") { const int64_t dim = 128; const auto metric_type = knowhere::metric::HAMMING; + auto version = GenTestVersionList(); auto base_bin_gen = [&]() { knowhere::Json json; @@ -58,10 +59,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, bin_ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, bin_hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); - if (!idx.HasRawData(metric_type)) { - return; - } + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -69,11 +67,14 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") { auto ids_ds = GenIdsDataSet(nb, nq); REQUIRE(idx.Type() == name); auto res = idx.Build(*train_ds, json); + if (!idx.HasRawData(metric_type)) { + return; + } REQUIRE(res == knowhere::Status::success); knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_new = knowhere::IndexFactory::Instance().Create(name); + auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); idx_new.Deserialize(bs); auto results = idx_new.GetVectorByIds(*ids_ds); REQUIRE(results.has_value()); @@ -101,6 +102,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { const int64_t dim = 128; auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto version = GenTestVersionList(); auto base_gen = [&]() { knowhere::Json json; @@ -159,10 +161,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); - if (!idx.HasRawData(metric)) { - return; - } + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -171,11 +170,14 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { auto ids_ds = GenIdsDataSet(nb, nq); REQUIRE(idx.Type() == name); auto res = idx.Build(*train_ds, json); + if (!idx.HasRawData(metric)) { + return; + } REQUIRE(res == knowhere::Status::success); knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_new = knowhere::IndexFactory::Instance().Create(name); + auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); idx_new.Deserialize(bs); auto results = idx_new.GetVectorByIds(*ids_ds); REQUIRE(results.has_value()); diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index cd0537c9d..1f440f886 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -56,6 +56,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto topk = GENERATE(5, 10, 20); auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto version = GenTestVersionList(); auto base_gen = [&]() { knowhere::Json json; @@ -89,7 +90,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -113,7 +114,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -143,7 +144,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -179,6 +180,8 @@ TEST_CASE("Test Iterator Mem Index With Binary Vector", "[float metrics]") { const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD); + auto version = GenTestVersionList(); + auto base_gen = [&]() { knowhere::Json json; json[knowhere::meta::DIM] = dim; @@ -208,7 +211,7 @@ TEST_CASE("Test Iterator Mem Index With Binary Vector", "[float metrics]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_ivfflat_cc.cc b/tests/ut/test_ivfflat_cc.cc index 94cc056d4..b67bb8163 100644 --- a/tests/ut/test_ivfflat_cc.cc +++ b/tests/ut/test_ivfflat_cc.cc @@ -24,6 +24,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { using Catch::Approx; auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto version = GenTestVersionList(); int64_t nb = 10000, nq = 1000; int64_t dim = 128; @@ -140,7 +141,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -182,8 +183,9 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { SECTION("Test Build & Search Correctness") { using std::make_tuple; - auto ivf_flat = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT); - auto ivf_flat_cc = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC); + auto ivf_flat = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, version); + auto ivf_flat_cc = + knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, version); knowhere::Json ivf_flat_json = knowhere::Json::parse(ivfflat_gen().dump()); knowhere::Json ivf_flat_cc_json = knowhere::Json::parse(ivfflatcc_gen().dump()); @@ -240,7 +242,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_mmap.cc b/tests/ut/test_mmap.cc index a37f76dc0..84734e575 100644 --- a/tests/ut/test_mmap.cc +++ b/tests/ut/test_mmap.cc @@ -50,6 +50,7 @@ TEST_CASE("Search mmap", "[float metrics]") { const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto version = GenTestVersionList(); auto base_gen = [&]() { knowhere::Json json; @@ -126,7 +127,7 @@ TEST_CASE("Search mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -153,7 +154,7 @@ TEST_CASE("Search mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -177,7 +178,7 @@ TEST_CASE("Search mmap", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -213,6 +214,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD); + auto version = GenTestVersionList(); auto base_gen = [&]() { knowhere::Json json; json[knowhere::meta::DIM] = dim; @@ -268,7 +270,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -290,7 +292,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -316,6 +318,7 @@ TEST_CASE("Search binary mmap", "[bool metrics]") { auto dim = GENERATE(as{}, 8, 16, 32, 64, 128, 256, 512, 160); auto metric = GENERATE(as{}, knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE); + auto version = GenTestVersionList(); auto base_gen = [&]() { knowhere::Json json; @@ -371,7 +374,7 @@ TEST_CASE("Search binary mmap", "[bool metrics]") { std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index 0d0519256..d93f33ef8 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -35,6 +35,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto version = GenTestVersionList(); auto base_gen = [&]() { knowhere::Json json; @@ -113,7 +114,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -155,7 +156,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -191,7 +192,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -230,7 +231,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -239,14 +240,14 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); } SECTION("Test IVFPQ with invalid params") { - auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ); + auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); uint32_t nb = 1000; uint32_t dim = 128; auto ivf_pq_gen = [&]() { @@ -272,6 +273,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD); + auto version = GenTestVersionList(); auto base_gen = [&]() { knowhere::Json json; json[knowhere::meta::DIM] = dim; @@ -313,7 +315,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -333,7 +335,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -355,7 +357,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -364,7 +366,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); @@ -378,6 +380,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { const int64_t topk = 5; auto dim = GENERATE(as{}, 8, 16, 32, 64, 128, 256, 512, 160); + auto version = GenTestVersionList(); auto metric = GENERATE(as{}, knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE); auto base_gen = [&]() { @@ -419,7 +422,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_simd.cc b/tests/ut/test_simd.cc index 38d31e501..fba56c32c 100644 --- a/tests/ut/test_simd.cc +++ b/tests/ut/test_simd.cc @@ -69,6 +69,7 @@ TEST_CASE("Test PQ Search SIMD", "[pq]") { const int64_t k = 5; auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto version = GenTestVersionList(); const auto train_ds = GenDataSet(nb, dim); const auto query_ds = CopyDataSet(train_ds, nq); @@ -96,7 +97,7 @@ TEST_CASE("Test PQ Search SIMD", "[pq]") { } } - auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ); + auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); REQUIRE(idx.Build(*train_ds, conf) == knowhere::Status::success); auto res = idx.Search(*query_ds, conf, nullptr); REQUIRE(res.has_value()); diff --git a/tests/ut/test_utils.cc b/tests/ut/test_utils.cc index 329cce234..c7da4e60b 100644 --- a/tests/ut/test_utils.cc +++ b/tests/ut/test_utils.cc @@ -16,6 +16,7 @@ #include "knowhere/comp/time_recorder.h" #include "knowhere/heap.h" #include "knowhere/utils.h" +#include "knowhere/version.h" #include "utils.h" namespace { @@ -108,3 +109,20 @@ TEST_CASE("Test Time Recorder") { auto span = tr.ElapseFromBegin("done"); REQUIRE(span > 0); } + +TEST_CASE("Test Version") { + REQUIRE(knowhere::Version("knowhere-v1").Valid()); + REQUIRE(!knowhere::Version("knowhere-V1").Valid()); + REQUIRE(!knowhere::Version("knowhere-1.2.2").Valid()); + REQUIRE(!knowhere::Version("knowhere-vx.2.2-hotfix").Valid()); + REQUIRE(!knowhere::Version("knowhere-v222.22").Valid()); + REQUIRE(knowhere::Version::VersionSupport(knowhere::Version::GetMinimalSupport())); + REQUIRE(knowhere::Version::VersionSupport(knowhere::Version::GetCurrentVersion())); +} + +TEST_CASE("Test DiskLoad") { + REQUIRE(knowhere::UseDiskLoad(knowhere::IndexEnum::INDEX_DISKANN, + knowhere::Version::GetCurrentVersion().VersionCode())); + REQUIRE( + !knowhere::UseDiskLoad(knowhere::IndexEnum::INDEX_HNSW, knowhere::Version::GetCurrentVersion().VersionCode())); +} diff --git a/tests/ut/utils.h b/tests/ut/utils.h index c6409ecbb..f04596c50 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -10,7 +10,9 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include +#include #include +#include #include #include #include @@ -18,9 +20,11 @@ #include #include +#include "catch2/generators/catch_generators.hpp" #include "common/range_util.h" #include "knowhere/binaryset.h" #include "knowhere/dataset.h" +#include "knowhere/version.h" constexpr int64_t kSeed = 42; using IdDisPair = std::pair; @@ -223,3 +227,8 @@ GenerateRandomDistanceIdPair(size_t n) { } return res; } + +inline auto +GenTestVersionList() { + return GENERATE(as{}, knowhere::Version::GetCurrentVersion().VersionCode()); +} From ea12387701df9662494fbd2f75c7ce526a428376 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Wed, 20 Sep 2023 18:15:18 +0800 Subject: [PATCH 5/6] Support faiss FLAT index get vector with metric COSINE (#80) Signed-off-by: Yudong Cai --- src/common/comp/brute_force.cc | 8 ++--- src/index/flat/flat.cc | 15 +++++---- src/index/ivf/ivf.cc | 6 ++-- thirdparty/faiss/faiss/IndexFlat.cpp | 37 ++++++++++++++++++--- thirdparty/faiss/faiss/IndexFlat.h | 12 ++++++- thirdparty/faiss/faiss/IndexFlatCodes.h | 2 ++ thirdparty/faiss/faiss/IndexIVFFlat.cpp | 10 +++--- thirdparty/faiss/faiss/IndexIVFFlat.h | 8 ++--- thirdparty/faiss/faiss/impl/index_read.cpp | 3 ++ thirdparty/faiss/faiss/impl/index_write.cpp | 3 ++ thirdparty/faiss/faiss/utils/distances.cpp | 36 +++++++++++--------- thirdparty/faiss/faiss/utils/distances.h | 2 ++ 12 files changed, 97 insertions(+), 45 deletions(-) diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index 930b04152..ebd452020 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -79,7 +79,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, dim, 1, nb, &buf, bitset); + faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset); } else { faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); } @@ -173,7 +173,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, dim, 1, nb, &buf, bitset); + faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset); } else { faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); } @@ -274,8 +274,8 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da auto cur_query = (const float*)xq + dim * index; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::range_search_cosine(copied_query.get(), (const float*)xb, dim, 1, nb, radius, &res, - bitset); + faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius, + &res, bitset); } else { faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset); diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index 5aaa2aeaf..0a4f44705 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -36,17 +36,18 @@ class FlatIndexNode : public IndexNode { Train(const DataSet& dataset, const Config& cfg) override { const FlatConfig& f_cfg = static_cast(cfg); - // do normalize for COSINE metric type - if (IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE)) { - Normalize(dataset); - } - auto metric = Str2FaissMetricType(f_cfg.metric_type.value()); if (!metric.has_value()) { LOG_KNOWHERE_WARNING_ << "please check metric type: " << f_cfg.metric_type.value(); return metric.error(); } - index_ = std::make_unique(dataset.GetDim(), metric.value()); + if constexpr (std::is_same::value) { + index_ = std::make_unique(dataset.GetDim(), metric.value()); + } + if constexpr (std::is_same::value) { + bool is_cosine = IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE); + index_ = std::make_unique(dataset.GetDim(), metric.value(), is_cosine); + } return Status::success; } @@ -236,7 +237,7 @@ class FlatIndexNode : public IndexNode { bool HasRawData(const std::string& metric_type) const override { if constexpr (std::is_same::value) { - return !IsMetricType(metric_type, metric::COSINE); + return true; } if constexpr (std::is_same::value) { return true; diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index f176b0dae..f0070b0bb 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -276,15 +276,15 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { const IvfFlatConfig& ivf_flat_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cfg.nlist.value()); qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); - index = std::make_unique(qzr, dim, nlist, is_cosine, metric.value()); + index = std::make_unique(qzr, dim, nlist, metric.value(), is_cosine); index->train(rows, (const float*)data); } if constexpr (std::is_same::value) { const IvfFlatCcConfig& ivf_flat_cc_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cc_cfg.nlist.value()); qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); - index = std::make_unique(qzr, dim, nlist, ivf_flat_cc_cfg.ssize.value(), is_cosine, - metric.value()); + index = std::make_unique(qzr, dim, nlist, ivf_flat_cc_cfg.ssize.value(), + metric.value(), is_cosine); index->train(rows, (const float*)data); } if constexpr (std::is_same::value) { diff --git a/thirdparty/faiss/faiss/IndexFlat.cpp b/thirdparty/faiss/faiss/IndexFlat.cpp index 1c987ee18..fae29e9f7 100644 --- a/thirdparty/faiss/faiss/IndexFlat.cpp +++ b/thirdparty/faiss/faiss/IndexFlat.cpp @@ -18,10 +18,28 @@ #include #include +#include "knowhere/utils.h" + namespace faiss { -IndexFlat::IndexFlat(idx_t d, MetricType metric) - : IndexFlatCodes(sizeof(float) * d, d, metric) {} +IndexFlat::IndexFlat(idx_t d, MetricType metric, bool is_cosine) + : IndexFlatCodes(sizeof(float) * d, d, metric) { + this->is_cosine = is_cosine; +} + +void IndexFlat::add(idx_t n, const float* x) { + FAISS_THROW_IF_NOT(is_trained); + codes.resize((ntotal + n) * code_size); + sa_encode(n, x, &codes[ntotal * code_size]); + if (is_cosine) { + auto x_normalized = std::make_unique(n * d); + std::memcpy(x_normalized.get(), x, n * d * sizeof(float)); + auto norms = knowhere::NormalizeVecs(x_normalized.get(), n, d); + code_norms.resize(ntotal + n); + std::memcpy(&code_norms[ntotal], norms.data(), sizeof(float) * n); + } + ntotal += n; +} void IndexFlat::search( idx_t n, @@ -36,7 +54,11 @@ void IndexFlat::search( if (metric_type == METRIC_INNER_PRODUCT) { float_minheap_array_t res = {size_t(n), size_t(k), labels, distances}; - knn_inner_product(x, get_xb(), d, n, ntotal, &res, bitset); + if (is_cosine) { + knn_cosine(x, get_xb(), get_norms(), d, n, ntotal, &res, bitset); + } else { + knn_inner_product(x, get_xb(), d, n, ntotal, &res, bitset); + } } else if (metric_type == METRIC_L2) { float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances}; knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, bitset); @@ -90,8 +112,13 @@ void IndexFlat::range_search( const BitsetView bitset) const { switch (metric_type) { case METRIC_INNER_PRODUCT: - range_search_inner_product( - x, get_xb(), d, n, ntotal, radius, result, bitset); + if (is_cosine) { + range_search_cosine(x, get_xb(), get_norms(), d, n, ntotal, + radius, result, bitset); + } else { + range_search_inner_product( + x, get_xb(), d, n, ntotal, radius, result, bitset); + } break; case METRIC_L2: range_search_L2sqr( diff --git a/thirdparty/faiss/faiss/IndexFlat.h b/thirdparty/faiss/faiss/IndexFlat.h index 155cb0b27..d220db2e6 100644 --- a/thirdparty/faiss/faiss/IndexFlat.h +++ b/thirdparty/faiss/faiss/IndexFlat.h @@ -22,7 +22,10 @@ struct IndexFlat : IndexFlatCodes { /// database vectors, size ntotal * d std::vector xb; - explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2); + explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2, + bool is_cosine = false); + + void add(idx_t n, const float* x) override; void search( idx_t n, @@ -67,6 +70,13 @@ struct IndexFlat : IndexFlatCodes { return (const float*)codes.data(); } + float* get_norms() { + return (float*)code_norms.data(); + } + const float* get_norms() const { + return (const float*)code_norms.data(); + } + IndexFlat() {} DistanceComputer* get_distance_computer() const override; diff --git a/thirdparty/faiss/faiss/IndexFlatCodes.h b/thirdparty/faiss/faiss/IndexFlatCodes.h index 0ea79e296..174ba46a3 100644 --- a/thirdparty/faiss/faiss/IndexFlatCodes.h +++ b/thirdparty/faiss/faiss/IndexFlatCodes.h @@ -22,6 +22,8 @@ struct IndexFlatCodes : Index { /// encoded dataset, size ntotal * code_size std::vector codes; + std::vector code_norms; + IndexFlatCodes(); IndexFlatCodes(size_t code_size, idx_t d, MetricType metric = METRIC_L2); diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.cpp b/thirdparty/faiss/faiss/IndexIVFFlat.cpp index 755db3247..d8bff5cc3 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.cpp +++ b/thirdparty/faiss/faiss/IndexIVFFlat.cpp @@ -35,8 +35,8 @@ IndexIVFFlat::IndexIVFFlat( Index* quantizer, size_t d, size_t nlist, - bool is_cosine, - MetricType metric) + MetricType metric, + bool is_cosine) : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) { this->is_cosine = is_cosine; code_size = sizeof(float) * d; @@ -268,9 +268,9 @@ IndexIVFFlatCC::IndexIVFFlatCC( size_t d, size_t nlist, size_t ssize, - bool is_cosine, - MetricType metric) - : IndexIVFFlat(quantizer, d, nlist, is_cosine, metric) { + MetricType metric, + bool is_cosine) + : IndexIVFFlat(quantizer, d, nlist, metric, is_cosine) { replace_invlists(new ConcurrentArrayInvertedLists(nlist, code_size, ssize, is_cosine), true); } diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.h b/thirdparty/faiss/faiss/IndexIVFFlat.h index 7e81f4e45..7106a2ff6 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.h +++ b/thirdparty/faiss/faiss/IndexIVFFlat.h @@ -26,8 +26,8 @@ struct IndexIVFFlat : IndexIVF { Index* quantizer, size_t d, size_t nlist_, - bool is_cosine, - MetricType = METRIC_L2); + MetricType = METRIC_L2, + bool is_cosine = false); void restore_codes(const uint8_t* raw_data, const size_t raw_size); @@ -66,8 +66,8 @@ struct IndexIVFFlatCC : IndexIVFFlat { size_t d, size_t nlist, size_t ssize, - bool iscosine, - MetricType = METRIC_L2); + MetricType = METRIC_L2, + bool is_cosine = false); IndexIVFFlatCC() {} }; diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index e2291509a..6ff36885c 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -634,6 +634,9 @@ Index* read_index(IOReader* f, int io_flags) { read_index_header(idxf, f); idxf->code_size = idxf->d * sizeof(float); READXBVECTOR(idxf->codes); + if (idxf->is_cosine) { + READVECTOR(idxf->code_norms); + } FAISS_THROW_IF_NOT( idxf->codes.size() == idxf->ntotal * idxf->code_size); // leak! diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index e0c886af3..efe674dd2 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -461,6 +461,9 @@ void write_index(const Index* idx, IOWriter* f) { WRITE1(h); write_index_header(idx, f); WRITEXBVECTOR(idxf->codes); + if (idx->is_cosine) { + WRITEVECTOR(idxf->code_norms); + } } else if (const IndexLSH* idxl = dynamic_cast(idx)) { uint32_t h = fourcc("IxHe"); WRITE1(h); diff --git a/thirdparty/faiss/faiss/utils/distances.cpp b/thirdparty/faiss/faiss/utils/distances.cpp index 6f15d51d1..789f07db4 100644 --- a/thirdparty/faiss/faiss/utils/distances.cpp +++ b/thirdparty/faiss/faiss/utils/distances.cpp @@ -163,16 +163,11 @@ void exhaustive_L2sqr_seq( } } -namespace { -float fvec_cosine(const float* x, const float* y, size_t d) { - return fvec_inner_product(x, y, d) / sqrtf(fvec_norm_L2sqr(y, d)); -} -} // namespace - template void exhaustive_cosine_seq( const float* x, const float* y, + const float* y_norms, size_t d, size_t nx, size_t ny, @@ -191,7 +186,10 @@ void exhaustive_cosine_seq( resi.begin(i); for (size_t j = 0; j < ny; j++) { if (bitset.empty() || !bitset.test(j)) { - float disij = fvec_cosine(x_i, y_j, d); + float norm = + (y_norms != nullptr) ? y_norms[j] + : sqrtf(fvec_norm_L2sqr(y_j, d)); + float disij = fvec_inner_product(x_i, y_j, d) / norm; resi.add_result(disij, j); } y_j += d; @@ -347,6 +345,7 @@ template void exhaustive_cosine_blas( const float* x, const float* y, + const float* y_norms_in, size_t d, size_t nx, size_t ny, @@ -361,10 +360,12 @@ void exhaustive_cosine_blas( const size_t bs_y = distance_compute_blas_database_bs; // const size_t bs_x = 16, bs_y = 16; std::unique_ptr ip_block(new float[bs_x * bs_y]); - std::unique_ptr y_norms(new float[nx]); + std::unique_ptr y_norms(new float[ny]); std::unique_ptr del2; - fvec_norms_L2(y_norms.get(), x, d, nx); + if (y_norms_in == nullptr) { + fvec_norms_L2(y_norms.get(), y, d, ny); + } for (size_t i0 = 0; i0 < nx; i0 += bs_x) { size_t i1 = i0 + bs_x; @@ -401,7 +402,8 @@ void exhaustive_cosine_blas( for (size_t j = j0; j < j1; j++) { float ip = *ip_line; - float dis = ip / y_norms[j]; + float dis = (y_norms_in != nullptr) ? ip / y_norms_in[j] + : ip / y_norms[j]; *ip_line = dis; ip_line++; } @@ -565,6 +567,7 @@ void knn_L2sqr( void knn_cosine( const float* x, const float* y, + const float* y_norms, size_t d, size_t nx, size_t ny, @@ -574,17 +577,17 @@ void knn_cosine( HeapResultHandler> res( ha->nh, ha->val, ha->ids, ha->k); if (nx < distance_compute_blas_threshold) { - exhaustive_cosine_seq(x, y, d, nx, ny, res, bitset); + exhaustive_cosine_seq(x, y, y_norms, d, nx, ny, res, bitset); } else { - exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset); + exhaustive_cosine_blas(x, y, y_norms, d, nx, ny, res, bitset); } } else { ReservoirResultHandler> res( ha->nh, ha->val, ha->ids, ha->k); if (nx < distance_compute_blas_threshold) { - exhaustive_cosine_seq(x, y, d, nx, ny, res, bitset); + exhaustive_cosine_seq(x, y, y_norms, d, nx, ny, res, bitset); } else { - exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset); + exhaustive_cosine_blas(x, y, y_norms, d, nx, ny, res, bitset); } } } @@ -655,6 +658,7 @@ void range_search_inner_product( void range_search_cosine( const float* x, const float* y, + const float* y_norms, size_t d, size_t nx, size_t ny, @@ -663,9 +667,9 @@ void range_search_cosine( const BitsetView bitset) { RangeSearchResultHandler> resh(res, radius); if (nx < distance_compute_blas_threshold) { - exhaustive_cosine_seq(x, y, d, nx, ny, resh, bitset); + exhaustive_cosine_seq(x, y, y_norms, d, nx, ny, resh, bitset); } else { - exhaustive_cosine_blas(x, y, d, nx, ny, resh, bitset); + exhaustive_cosine_blas(x, y, y_norms, d, nx, ny, resh, bitset); } } diff --git a/thirdparty/faiss/faiss/utils/distances.h b/thirdparty/faiss/faiss/utils/distances.h index cd590270b..1e06d51ac 100644 --- a/thirdparty/faiss/faiss/utils/distances.h +++ b/thirdparty/faiss/faiss/utils/distances.h @@ -199,6 +199,7 @@ void knn_L2sqr( void knn_cosine( const float* x, const float* y, + const float* y_norms, size_t d, size_t nx, size_t ny, @@ -274,6 +275,7 @@ void range_search_inner_product( void range_search_cosine( const float* x, const float* y, + const float* y_norms, size_t d, size_t nx, size_t ny, From e0a44d85084e5be06bb63f0a8f4dd93ed218fb4f Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Wed, 20 Sep 2023 14:15:43 -0400 Subject: [PATCH 6/6] merge with baseline --- thirdparty/faiss/faiss/IndexIVFFlat.cpp | 3 +-- thirdparty/faiss/faiss/IndexIVFFlat.h | 1 - thirdparty/faiss/faiss/index_factory.cpp | 6 ++---- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.cpp b/thirdparty/faiss/faiss/IndexIVFFlat.cpp index 7292881b3..38cbf72a3 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.cpp +++ b/thirdparty/faiss/faiss/IndexIVFFlat.cpp @@ -314,9 +314,8 @@ IndexIVFFlatDedup::IndexIVFFlatDedup( Index* quantizer, size_t d, size_t nlist_, - bool is_cosine, MetricType metric_type) - : IndexIVFFlat(quantizer, d, nlist_, is_cosine, metric_type) {} + : IndexIVFFlat(quantizer, d, nlist_, metric_type) {} void IndexIVFFlatDedup::train(idx_t n, const float* x) { std::unordered_map map; diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.h b/thirdparty/faiss/faiss/IndexIVFFlat.h index cd5654825..88b19681b 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.h +++ b/thirdparty/faiss/faiss/IndexIVFFlat.h @@ -83,7 +83,6 @@ struct IndexIVFFlatDedup : IndexIVFFlat { Index* quantizer, size_t d, size_t nlist_, - bool is_cosine, MetricType = METRIC_L2); /// also dedups the training set diff --git a/thirdparty/faiss/faiss/index_factory.cpp b/thirdparty/faiss/faiss/index_factory.cpp index 1b819ee7e..9f24217a4 100644 --- a/thirdparty/faiss/faiss/index_factory.cpp +++ b/thirdparty/faiss/faiss/index_factory.cpp @@ -303,12 +303,10 @@ IndexIVF* parse_IndexIVF( int d = quantizer->d; if (match("Flat")) { - // todo aguzhva: added 'false' as 'is_cosine' - return new IndexIVFFlat(get_q(), d, nlist, false, mt); + return new IndexIVFFlat(get_q(), d, nlist, mt); } if (match("FlatDedup")) { - // todo aguzhva: added 'false' as 'is_cosine' - return new IndexIVFFlatDedup(get_q(), d, nlist, false, mt); + return new IndexIVFFlatDedup(get_q(), d, nlist, mt); } if (match(sq_pattern)) { return new IndexIVFScalarQuantizer(