Skip to content

Commit

Permalink
Merge branch 'faiss_174_upgrade_with_my_files' into faiss_174_upgrade…
Browse files Browse the repository at this point in the history
…_candidate
  • Loading branch information
alexanderguzhva committed Sep 20, 2023
2 parents 00a0a09 + 8382969 commit 2e4a50c
Show file tree
Hide file tree
Showing 394 changed files with 46,363 additions and 12,120 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ knowhere_option(WITH_BENCHMARK "Build with benchmark" OFF)
knowhere_option(WITH_COVERAGE "Build with coverage" OFF)
knowhere_option(WITH_CCACHE "Build with ccache" ON)
knowhere_option(WITH_PROFILER "Build with profiler" OFF)
knowhere_option(WITH_FAISS_TESTS "Build with Faiss unit tests" OFF)

if(KNOWHERE_VERSION)
message(STATUS "Building KNOWHERE version: ${KNOWHERE_VERSION}")
Expand Down Expand Up @@ -147,6 +148,10 @@ if(WITH_BENCHMARK)
add_subdirectory(benchmark)
endif()

if(WITH_FAISS_TESTS)
add_subdirectory(tests/faiss)
endif()

install(TARGETS knowhere
DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/knowhere"
Expand Down
2 changes: 1 addition & 1 deletion benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ endif()
macro(benchmark_test target file)
set(FILE_SRCS ${file})
add_executable(${target} ${FILE_SRCS})
target_link_libraries(${target} ${depend_libs} ${unittest_libs})
target_link_libraries(${target} ${depend_libs} ${unittest_libs} atomic)
install(TARGETS ${target} DESTINATION unittest)
endmacro()

Expand Down
4 changes: 4 additions & 0 deletions cmake/libs/libfaiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ knowhere_file_glob(GLOB FAISS_AVX512_SRCS

list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX512_SRCS})

# disable RHNSW
knowhere_file_glob(GLOB FAISS_RHNSW_SRCS thirdparty/faiss/faiss/impl/RHNSW.cpp)
list(REMOVE_ITEM FAISS_SRCS ${FAISS_RHNSW_SRCS})

if(__X86_64)
set(UTILS_SRC src/simd/distances_ref.cc src/simd/hook.cc)
set(UTILS_SSE_SRC src/simd/distances_sse.cc)
Expand Down
5 changes: 5 additions & 0 deletions conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class KnowhereConan(ConanFile):
"with_ut": [True, False],
"with_benchmark": [True, False],
"with_coverage": [True, False],
"with_faiss_tests": [True, False],
}
default_options = {
"shared": True,
Expand All @@ -47,6 +48,7 @@ class KnowhereConan(ConanFile):
"with_coverage": False,
"boost:without_test": True,
"fmt:header_only": True,
"with_faiss_tests": False,
}

exports_sources = (
Expand Down Expand Up @@ -96,6 +98,8 @@ def requirements(self):
if self.options.with_benchmark:
self.requires("gtest/1.13.0")
self.requires("hdf5/1.14.0")
if self.options.with_faiss_tests:
self.requires("gtest/1.13.0")

@property
def _required_boost_components(self):
Expand Down Expand Up @@ -156,6 +160,7 @@ def generate(self):
tc.variables["WITH_UT"] = self.options.with_ut
tc.variables["WITH_BENCHMARK"] = self.options.with_benchmark
tc.variables["WITH_COVERAGE"] = self.options.with_coverage
tc.variables["WITH_FAISS_TESTS"] = self.options.with_faiss_tests
tc.generate()
deps = CMakeDeps(self)
deps.generate()
Expand Down
93 changes: 75 additions & 18 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "knowhere/log.h"
#include "knowhere/utils.h"

#include "index/bitsetview_idselector.h"

namespace knowhere {

/* knowhere wrapper API to call faiss brute force search for all metric types */
Expand Down Expand Up @@ -67,36 +69,51 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
// // todo aguzhva: bitset was here
// faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset);
// // todo aguzhva: bitset was here
// faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
// // todo aguzhva: bitset was here
// faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);
// // todo aguzhva: bitset was here
// binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
// // todo aguzhva: bitset was here
// binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
// dim / 8, bitset);
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, bitset);
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
Expand All @@ -106,8 +123,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
// // todo aguzhva: bitset was here
// binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
// cur_labels, bitset);
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
cur_labels, bitset);
cur_labels, id_selector);
break;
}
default: {
Expand Down Expand Up @@ -161,36 +181,51 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
// // todo aguzhva: bitset was here
// faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset);
// // todo aguzhva: bitset was here
// faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset);
faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector);
} else {
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
// // todo aguzhva: bitset was here
// faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);
// // todo aguzhva: bitset was here
// binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
// // todo aguzhva: bitset was here
// binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
// dim / 8, bitset);
binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb,
dim / 8, bitset);
dim / 8, id_selector);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
Expand All @@ -200,8 +235,11 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
// // todo aguzhva: bitset was here
// binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
// cur_labels, bitset);
binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances,
cur_labels, bitset);
cur_labels, id_selector);
break;
}
default: {
Expand Down Expand Up @@ -263,36 +301,55 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
futs.emplace_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
faiss::RangeSearchResult res(1);

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
// // todo aguzhva: bitset was here
// faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, id_selector);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (const float*)xq + dim * index;
if (is_cosine) {
auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius,
&res, bitset);
// // todo aguzhva: bitset was here
// faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius, &res,
// bitset);
faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius, &res,
id_selector);
} else {
// // todo aguzhva: bitset was here
// faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
// bitset);
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
bitset);
id_selector);
}
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
// // todo aguzhva: bitset was here
// faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
// faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset);
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset);
faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, id_selector);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
// // todo aguzhva: bitset was here
// faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming, cur_query,
// (const uint8_t*)xb, 1, nb, (int)radius,
// dim / 8, &res, bitset);
faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming, cur_query,
(const uint8_t*)xb, 1, nb, (int)radius,
dim / 8, &res, bitset);
dim / 8, &res, id_selector);
break;
}
default: {
Expand Down
2 changes: 1 addition & 1 deletion src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const s
auto remains = binary->size - reader.tellg() - sizeof(uint32_t) - sizeof(ivfl->invlists->nlist) -
sizeof(ivfl->invlists->code_size);
auto invlist_size = sizeof(uint32_t) + sizeof(size_t) + ivfl->nlist * sizeof(size_t);
auto ids_size = ivfl->ntotal * sizeof(faiss::Index::idx_t);
auto ids_size = ivfl->ntotal * sizeof(faiss::idx_t);
// auto codes_size = ivfl->d * ivfl->ntotal * sizeof(float);

// IVF_FLAT_NM format, need convert to new format
Expand Down
31 changes: 31 additions & 0 deletions src/index/bitsetview_idselector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#pragma once

#include "knowhere/bitsetview.h"

#include <faiss/impl/IDSelector.h>

namespace knowhere {

struct BitsetViewIDSelector : faiss::IDSelector {
BitsetView bitset_view;

inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} {}

inline bool is_member(faiss::idx_t id) const override final {
// it is by design that bitset_view.empty() is not tested here
return (!bitset_view.test(id));
}
};

}
Loading

0 comments on commit 2e4a50c

Please sign in to comment.