-
Notifications
You must be signed in to change notification settings - Fork 83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
hnsw support fp16/bf16 #494
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,9 +34,6 @@ using hnswlib::QuantType; | |
|
||
template <typename DataType, QuantType quant_type = QuantType::None> | ||
class HnswIndexNode : public IndexNode { | ||
static_assert(std::is_same_v<DataType, fp32> || std::is_same_v<DataType, bin1>, | ||
"HnswIndexNode only support float/bianry"); | ||
|
||
public: | ||
using DistType = float; | ||
HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) { | ||
|
@@ -49,22 +46,33 @@ class HnswIndexNode : public IndexNode { | |
auto dim = dataset.GetDim(); | ||
auto hnsw_cfg = static_cast<const HnswConfig&>(cfg); | ||
hnswlib::SpaceInterface<DistType>* space = nullptr; | ||
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) { | ||
space = new (std::nothrow) hnswlib::L2Space(dim); | ||
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) { | ||
space = new (std::nothrow) hnswlib::InnerProductSpace(dim); | ||
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) { | ||
space = new (std::nothrow) hnswlib::CosineSpace(dim); | ||
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) { | ||
space = new (std::nothrow) hnswlib::HammingSpace(dim); | ||
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) { | ||
space = new (std::nothrow) hnswlib::JaccardSpace(dim); | ||
if constexpr (KnowhereFloatTypeCheck<DataType>::value) { | ||
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) { | ||
space = new (std::nothrow) hnswlib::L2Space<DataType, DistType>(dim); | ||
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) { | ||
space = new (std::nothrow) hnswlib::InnerProductSpace<DataType, DistType>(dim); | ||
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) { | ||
space = new (std::nothrow) hnswlib::CosineSpace<DataType, DistType>(dim); | ||
} else { | ||
LOG_KNOWHERE_WARNING_ | ||
<< "metric type and data type(float32, float16 and bfloat16) are not match in hnsw: " | ||
<< hnsw_cfg.metric_type.value(); | ||
return Status::invalid_metric_type; | ||
} | ||
} else { | ||
LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value(); | ||
return Status::invalid_metric_type; | ||
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) { | ||
space = new (std::nothrow) hnswlib::HammingSpace(dim); | ||
} else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) { | ||
space = new (std::nothrow) hnswlib::JaccardSpace(dim); | ||
} else { | ||
LOG_KNOWHERE_WARNING_ << "metric type and data type(binary) are not match in hnsw: " | ||
<< hnsw_cfg.metric_type.value(); | ||
return Status::invalid_metric_type; | ||
} | ||
} | ||
auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space, rows, hnsw_cfg.M.value(), | ||
hnsw_cfg.efConstruction.value()); | ||
|
||
auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should use "quantType" instead of "quant_type" to sync up the coding style |
||
space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value()); | ||
if (index == nullptr) { | ||
LOG_KNOWHERE_WARNING_ << "memory malloc error."; | ||
return Status::malloc_error; | ||
|
@@ -75,7 +83,7 @@ class HnswIndexNode : public IndexNode { | |
} | ||
this->index_ = index; | ||
if constexpr (quant_type != QuantType::None) { | ||
this->index_->trainSQuant((const float*)dataset.GetTensor(), rows); | ||
this->index_->trainSQuant((const DataType*)dataset.GetTensor(), rows); | ||
} | ||
return Status::success; | ||
} | ||
|
@@ -225,11 +233,11 @@ class HnswIndexNode : public IndexNode { | |
private: | ||
class iterator : public IndexIterator { | ||
public: | ||
iterator(const hnswlib::HierarchicalNSW<DistType, quant_type>* index, const char* query, const bool transform, | ||
const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf, | ||
const float refine_ratio = 0.5f) | ||
: IndexIterator(transform, (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled && | ||
hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data) | ||
iterator(const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index, const char* query, | ||
const bool transform, const BitsetView& bitset, const bool for_tuning = false, | ||
const size_t seed_ef = kIteratorSeedEf, const float refine_ratio = 0.5f) | ||
: IndexIterator(transform, (hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::sq_enabled && | ||
hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::has_raw_data) | ||
? refine_ratio | ||
: 0.0f), | ||
index_(index), | ||
|
@@ -251,15 +259,15 @@ class HnswIndexNode : public IndexNode { | |
} | ||
float | ||
raw_distance(int64_t id) override { | ||
if constexpr (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled && | ||
hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data) { | ||
if constexpr (hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::sq_enabled && | ||
hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::has_raw_data) { | ||
return (transform_ ? -1 : 1) * index_->calcRefineDistance(workspace_->raw_query_data.get(), id); | ||
} | ||
throw std::runtime_error("raw_distance not supported: index does not have raw data or sq is not enabled"); | ||
} | ||
|
||
private: | ||
const hnswlib::HierarchicalNSW<DistType, quant_type>* index_; | ||
const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_; | ||
const bool transform_; | ||
std::unique_ptr<hnswlib::IteratorWorkspace> workspace_; | ||
}; | ||
|
@@ -466,8 +474,8 @@ class HnswIndexNode : public IndexNode { | |
|
||
MemoryIOReader reader(binary->data.get(), binary->size); | ||
|
||
hnswlib::SpaceInterface<float>* space = nullptr; | ||
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space); | ||
hnswlib::SpaceInterface<DistType>* space = nullptr; | ||
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space); | ||
index_->loadIndex(reader); | ||
LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_ | ||
<< " #max level:" << index_->maxlevel_ | ||
|
@@ -486,8 +494,8 @@ class HnswIndexNode : public IndexNode { | |
delete index_; | ||
} | ||
try { | ||
hnswlib::SpaceInterface<float>* space = nullptr; | ||
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space); | ||
hnswlib::SpaceInterface<DistType>* space = nullptr; | ||
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space); | ||
index_->loadIndex(filename, config); | ||
} catch (std::exception& e) { | ||
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); | ||
|
@@ -581,7 +589,7 @@ class HnswIndexNode : public IndexNode { | |
} | ||
|
||
private: | ||
hnswlib::HierarchicalNSW<DistType, quant_type>* index_; | ||
hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_; | ||
std::shared_ptr<ThreadPool> search_pool_; | ||
}; | ||
|
||
|
@@ -592,14 +600,14 @@ KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, fp16); | |
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, bf16); | ||
#else | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32); | ||
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16); | ||
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16); | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16); | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16); | ||
#endif | ||
|
||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8); | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine); | ||
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8); | ||
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine); | ||
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8); | ||
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine); | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8); | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine); | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8); | ||
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine); | ||
} // namespace knowhere |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -323,7 +323,7 @@ TEST_CASE("Test Iterator IVFFlatCC With Newly Insert Vectors", "[float metrics] | |
} | ||
} | ||
|
||
TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any comments / TODO about the reason for disabling this test? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Metric type and data type not match in this case. I change the data type to binary, and it work now. |
||
TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[binary metrics]") { | ||
using Catch::Approx; | ||
|
||
const int64_t nb = 1000, nq = 10; | ||
|
@@ -348,21 +348,21 @@ TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") { | |
json[knowhere::indexparam::SEED_EF] = 64; | ||
return json; | ||
}; | ||
const auto train_ds = GenDataSet(nb, dim); | ||
const auto query_ds = GenDataSet(nq, dim); | ||
const auto train_ds = GenBinDataSet(nb, dim); | ||
const auto query_ds = GenBinDataSet(nq, dim); | ||
|
||
const knowhere::Json conf = { | ||
{knowhere::meta::METRIC_TYPE, metric}, | ||
{knowhere::meta::TOPK, topk}, | ||
}; | ||
|
||
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, conf, nullptr); | ||
auto gt = knowhere::BruteForce::Search<knowhere::bin1>(train_ds, query_ds, conf, nullptr); | ||
SECTION("Test Search using iterator") { | ||
using std::make_tuple; | ||
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({ | ||
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), | ||
})); | ||
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value(); | ||
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value(); | ||
auto cfg_json = gen().dump(); | ||
CAPTURE(name, cfg_json); | ||
knowhere::Json json = knowhere::Json::parse(cfg_json); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just in case for the future: should
FloatAccuracy
remain the same for bf16 and fp16?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think float16 can use the smallest positive value(6.1 x 10^(-5)), and bfloat16 can use the same as float32.