Skip to content

Commit

Permalink
update index and data type check function
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Jun 14, 2024
1 parent 8ad9fca commit 646d314
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 41 deletions.
5 changes: 5 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ constexpr const char* INDEX_RAFT_IVFFLAT = "GPU_RAFT_IVF_FLAT";
constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ";
constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA";

constexpr const char* INDEX_GPU_BRUTEFORCE = "GPU_BRUTE_FORCE";
constexpr const char* INDEX_GPU_IVFFLAT = "GPU_IVF_FLAT";
constexpr const char* INDEX_GPU_IVFPQ = "GPU_IVF_PQ";
constexpr const char* INDEX_GPU_CAGRA = "GPU_CAGRA";

constexpr const char* INDEX_HNSW = "HNSW";
constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8";
constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE";
Expand Down
27 changes: 7 additions & 20 deletions include/knowhere/comp/knowhere_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,14 @@
#include "knowhere/index/index_factory.h"
namespace knowhere {
namespace KnowhereCheck {
bool
static bool
IndexTypeAndDataTypeCheck(const std::string& index_name, VecType data_type) {
auto& index_factory = IndexFactory::Instance();
switch (data_type) {
case VecType::VECTOR_BINARY:
return index_factory.HasIndex<bin1>(index_name);
case VecType::VECTOR_FLOAT:
return index_factory.HasIndex<fp32>(index_name);
case VecType::VECTOR_BFLOAT16:
return index_factory.HasIndex<bf16>(index_name);
case VecType::VECTOR_FLOAT16:
return index_factory.HasIndex<fp16>(index_name);
case VecType::VECTOR_SPARSE_FLOAT:
if (index_name != IndexEnum::INDEX_SPARSE_INVERTED_INDEX && index_name != IndexEnum::INDEX_SPARSE_WAND &&
index_name != IndexEnum::INDEX_HNSW) {
return false;
} else {
return index_factory.HasIndex<fp32>(index_name);
}
default:
return false;
auto& static_index_table = IndexFactory::StaticIndexTableInstance();
auto key = std::pair<std::string, VecType>(index_name, data_type);
if (static_index_table.find(key) != static_index_table.end()) {
return true;
} else {
return false;
}
}
} // namespace KnowhereCheck
Expand Down
13 changes: 10 additions & 3 deletions include/knowhere/index/index_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#define INDEX_FACTORY_H

#include <functional>
#include <set>
#include <string>
#include <unordered_map>

Expand All @@ -28,11 +29,11 @@ class IndexFactory {
template <typename DataType>
const IndexFactory&
Register(const std::string& name, std::function<Index<IndexNode>(const int32_t&, const Object&)> func);
template <typename DataType>
bool
HasIndex(const std::string& name);
static IndexFactory&
Instance();
typedef std::set<std::pair<std::string, VecType>> GlobalIndexTable;
static GlobalIndexTable&
StaticIndexTableInstance();

private:
struct FunMapValueBase {
Expand Down Expand Up @@ -76,6 +77,12 @@ class IndexFactory {
std::make_unique<index_node<MockData<data_type>::type>>(version, object), thread_size)); \
}, \
data_type)
#define KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(name, index_table) \
static int name = []() -> int { \
auto& static_index_table = IndexFactory::StaticIndexTableInstance(); \
static_index_table.insert(index_table.begin(), index_table.end()); \
return 0; \
}();
} // namespace knowhere

#endif /* INDEX_FACTORY_H */
79 changes: 79 additions & 0 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#ifndef INDEX_TABLE_H
#define INDEX_TABLE_H
#include <set>
#include <string>

#include "knowhere/comp/index_param.h"
#include "knowhere/index/index_factory.h"
namespace knowhere {
static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
// binary ivf
{IndexEnum::INDEX_FAISS_BIN_IDMAP, VecType::VECTOR_BINARY},
{IndexEnum::INDEX_FAISS_BIN_IVFFLAT, VecType::VECTOR_BINARY},
// ivf
{IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16},
// gpu index
{IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_FLOAT},
// hnsw
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16},
// diskann
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_BFLOAT16},
// sparse index
{IndexEnum::INDEX_SPARSE_INVERTED_INDEX, VecType::VECTOR_SPARSE_FLOAT},
{IndexEnum::INDEX_SPARSE_WAND, VecType::VECTOR_SPARSE_FLOAT},
};
KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(KNOWHERE_STATIC_INDEX, legal_knowhere_index)
} // namespace knowhere
#endif /* INDEX_TABLE_H */
4 changes: 2 additions & 2 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ IsMetricType(const std::string& str, const knowhere::MetricType& metric_type) {

inline bool
IsFlatIndex(const knowhere::IndexType& index_type) {
static std::vector<knowhere::IndexType> flat_index_list = {IndexEnum::INDEX_FAISS_IDMAP,
IndexEnum::INDEX_FAISS_GPU_IDMAP};
static std::vector<knowhere::IndexType> flat_index_list = {
IndexEnum::INDEX_FAISS_IDMAP, IndexEnum::INDEX_FAISS_GPU_IDMAP, IndexEnum::INDEX_GPU_BRUTEFORCE};
return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end();
}

Expand Down
23 changes: 7 additions & 16 deletions src/index/index_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#include "knowhere/index/index_factory.h"

#include "knowhere/index/index_table.h"

#ifdef KNOWHERE_WITH_RAFT
#include <cuda_runtime_api.h>
#endif
Expand Down Expand Up @@ -72,14 +74,6 @@ IndexFactory::Register(const std::string& name, std::function<Index<IndexNode>(c
return *this;
}

template <typename DataType>
bool
IndexFactory::HasIndex(const std::string& name) {
auto& func_mapping_ = MapInstance();
auto key = GetKey<DataType>(name);
return (func_mapping_.find(key) != func_mapping_.end());
}

IndexFactory&
IndexFactory::Instance() {
static IndexFactory factory;
Expand All @@ -93,6 +87,11 @@ IndexFactory::MapInstance() {
static FuncMap func_map;
return func_map;
}
IndexFactory::GlobalIndexTable&
IndexFactory::StaticIndexTableInstance() {
static GlobalIndexTable static_index_table;
return static_index_table;
}

} // namespace knowhere
//
Expand All @@ -116,11 +115,3 @@ knowhere::IndexFactory::Register<knowhere::fp16>(
template const knowhere::IndexFactory&
knowhere::IndexFactory::Register<knowhere::bf16>(
const std::string&, std::function<knowhere::Index<knowhere::IndexNode>(const int32_t&, const Object&)>);
template bool
knowhere::IndexFactory::HasIndex<knowhere::fp32>(const std::string&);
template bool
knowhere::IndexFactory::HasIndex<knowhere::bin1>(const std::string&);
template bool
knowhere::IndexFactory::HasIndex<knowhere::fp16>(const std::string&);
template bool
knowhere::IndexFactory::HasIndex<knowhere::bf16>(const std::string&);
10 changes: 10 additions & 0 deletions tests/ut/test_index_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,18 @@ TEST_CASE("Test index and data type check", "[IndexCheckTest]") {
knowhere::VecType::VECTOR_FLOAT) == true);
REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW,
knowhere::VecType::VECTOR_BFLOAT16) == true);

#ifndef KNOWHERE_WITH_CARDINAL
REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW,
knowhere::VecType::VECTOR_BINARY) == false);
REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW,
knowhere::VecType::VECTOR_SPARSE_FLOAT) == false);
#else
REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW,
knowhere::VecType::VECTOR_BINARY) == true);
REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW,
knowhere::VecType::VECTOR_SPARSE_FLOAT) == true);
#endif
REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_DISKANN,
knowhere::VecType::VECTOR_FLOAT) == true);
REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_DISKANN,
Expand Down

0 comments on commit 646d314

Please sign in to comment.