hnsw support fp16/bf16 #494

Closed · wants to merge 1 commit
22 changes: 18 additions & 4 deletions include/knowhere/utils.h
@@ -36,17 +36,21 @@ IsFlatIndex(const knowhere::IndexType& index_type) {
 return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end();
 }

+template <typename DataType>
 extern float
-NormalizeVec(float* x, int32_t d);
+NormalizeVec(DataType* x, int32_t d);

+template <typename DataType>
 extern std::vector<float>
-NormalizeVecs(float* x, size_t rows, int32_t dim);
+NormalizeVecs(DataType* x, size_t rows, int32_t dim);

+template <typename DataType = knowhere::fp32>
 extern void
 Normalize(const DataSet& dataset);

-extern std::unique_ptr<float[]>
-CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim);
+template <typename DataType>
+extern std::unique_ptr<DataType[]>
+CopyAndNormalizeVecs(const DataType* x, size_t rows, int32_t dim);

 constexpr inline uint64_t seed = 0xc70f6907UL;

@@ -78,6 +82,16 @@ hash_binary_vec(const uint8_t* x, size_t d) {
     return h;
 }

+inline uint64_t
+hash_half_precision_float(const void* x, size_t d) {
+    uint64_t h = seed;
+    auto u16_x = (uint16_t*)(x);
+    for (size_t i = 0; i < d; ++i) {
+        h = h * 13331 + u16_x[i];
+    }
+    return h;
+}
+
 template <typename DataType>
 inline std::string
 GetIndexKey(const std::string& name) {
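A note on the helper added above: fp16 and bf16 values are both 16 bits wide, so hashing the raw 16-bit words covers either type with a single function. A minimal caller sketch, not part of this diff — the wrapper name and the raw-word vector are invented for illustration:

// Illustrative only: hashing an fp16/bf16 vector via its 16-bit storage.
#include <cstdint>
#include <vector>

uint64_t
HashHalfVector(const std::vector<uint16_t>& raw_words) {
    // Works for fp16 and bf16 alike, since both are stored as 16-bit words.
    return knowhere::hash_half_precision_float(raw_words.data(), raw_words.size());
}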
59 changes: 54 additions & 5 deletions src/common/utils.cc
@@ -27,6 +27,27 @@ namespace knowhere {
 const float FloatAccuracy = 0.00001;

 // normalize one vector and return its norm
+// todo(cqy123456): Template specialization for fp16/bf16;
+// float16 uses the smallest representable positive float16 value(6.1 x 10^(-5)) as FloatAccuracy;
+// bfloat16 uses the same FloatAccuracy as float32;
+template <typename DataType>
+float
+NormalizeVec(DataType* x, int32_t d) {
+    float norm_l2_sqr = 0.0;
+    for (auto i = 0; i < d; i++) {
+        norm_l2_sqr += (float)x[i] * (float)x[i];
+    }
+    if (norm_l2_sqr > 0 && std::abs(1.0f - norm_l2_sqr) > FloatAccuracy) {
    [Review thread on this line]
    Collaborator: just in case for the future: should FloatAccuracy remain the same for bf16 and fp16?
    Collaborator (Author): I think float16 can use the smallest positive value (6.1 x 10^(-5)), and bfloat16 can use the same as float32.
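For illustration, the specialization sketched in the author's reply might look like the following — a hypothetical follow-up, not part of this PR, with an invented function name. 6.1 x 10^(-5) is the smallest positive normal fp16 value (2^-14), while bf16 shares fp32's 8-bit exponent and can keep the fp32 tolerance.

// Hypothetical sketch of the todo above; not part of this diff.
template <typename DataType>
constexpr float
GetFloatAccuracy() {
    if constexpr (std::is_same_v<DataType, fp16>) {
        return 6.1e-5f;  // smallest positive normal fp16 value (2^-14)
    } else {
        return 0.00001f;  // fp32 and bf16 keep the existing FloatAccuracy
    }
}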

+        float norm_l2 = std::sqrt(norm_l2_sqr);
+        for (int32_t i = 0; i < d; i++) {
+            x[i] = (DataType)((float)x[i] / norm_l2);
+        }
+        return norm_l2;
+    }
+    return 1.0f;
+}
+
+template <>
 float
 NormalizeVec(float* x, int32_t d) {
     float norm_l2_sqr = faiss::fvec_norm_L2sqr(x, d);
@@ -41,20 +62,22 @@ NormalizeVec(float* x, int32_t d) {
 }

 // normalize all vectors and return their norms
+template <typename DataType>
 std::vector<float>
-NormalizeVecs(float* x, size_t rows, int32_t dim) {
+NormalizeVecs(DataType* x, size_t rows, int32_t dim) {
     std::vector<float> norms(rows);
     for (size_t i = 0; i < rows; i++) {
         norms[i] = NormalizeVec(x + i * dim, dim);
    [Presburger marked a review thread on this line as resolved]
     }
     return norms;
 }

+template <typename DataType>
 void
 Normalize(const DataSet& dataset) {
     auto rows = dataset.GetRows();
     auto dim = dataset.GetDim();
-    float* data = (float*)dataset.GetTensor();
+    auto data = (DataType*)dataset.GetTensor();

     LOG_KNOWHERE_DEBUG_ << "vector normalize, rows " << rows << ", dim " << dim;

@@ -64,9 +87,10 @@ Normalize(const DataSet& dataset) {
 }

 // copy and return normalized vectors
-std::unique_ptr<float[]>
-CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim) {
-    auto x_normalized = std::make_unique<float[]>(rows * dim);
+template <typename DataType>
+std::unique_ptr<DataType[]>
+CopyAndNormalizeVecs(const DataType* x, size_t rows, int32_t dim) {
+    auto x_normalized = std::make_unique<DataType[]>(rows * dim);
     std::copy_n(x, rows * dim, x_normalized.get());
     NormalizeVecs(x_normalized.get(), rows, dim);
     return x_normalized;
@@ -120,4 +144,29 @@ UseDiskLoad(const std::string& index_type, const int32_t& version) {
 #endif
 }

+template float
+NormalizeVec<fp16>(fp16* x, int32_t d);
+template float
+NormalizeVec<bf16>(bf16* x, int32_t d);
+
+template std::vector<float>
+NormalizeVecs<fp32>(fp32* x, size_t rows, int32_t dim);
+template std::vector<float>
+NormalizeVecs<fp16>(fp16* x, size_t rows, int32_t dim);
+template std::vector<float>
+NormalizeVecs<bf16>(bf16* x, size_t rows, int32_t dim);
+
+template void
+Normalize<fp32>(const DataSet& dataset);
+template void
+Normalize<fp16>(const DataSet& dataset);
+template void
+Normalize<bf16>(const DataSet& dataset);
+
+template std::unique_ptr<fp32[]>
+CopyAndNormalizeVecs(const fp32* x, size_t rows, int32_t dim);
+template std::unique_ptr<fp16[]>
+CopyAndNormalizeVecs(const fp16* x, size_t rows, int32_t dim);
+template std::unique_ptr<bf16[]>
+CopyAndNormalizeVecs(const bf16* x, size_t rows, int32_t dim);
 } // namespace knowhere
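Because the template definitions stay in utils.cc rather than the header, these explicit instantiations are what make the fp16/bf16 symbols visible to other translation units; a new element type would need its own instantiation lines here. A minimal caller sketch — illustrative only, with an invented function name:

// Illustrative only: normalizing fp16 vectors in place via the new template.
#include <vector>
#include "knowhere/utils.h"

void
NormalizeInPlace(knowhere::fp16* vectors, size_t rows, int32_t dim) {
    // Vectors are normalized in place; their original L2 norms are returned.
    std::vector<float> norms = knowhere::NormalizeVecs<knowhere::fp16>(vectors, rows, dim);
    (void)norms;
}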
82 changes: 45 additions & 37 deletions src/index/hnsw/hnsw.cc
@@ -34,9 +34,6 @@ using hnswlib::QuantType;

 template <typename DataType, QuantType quant_type = QuantType::None>
 class HnswIndexNode : public IndexNode {
-    static_assert(std::is_same_v<DataType, fp32> || std::is_same_v<DataType, bin1>,
-                  "HnswIndexNode only support float/bianry");
-
  public:
     using DistType = float;
     HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) {
@@ -49,22 +46,33 @@ class HnswIndexNode : public IndexNode {
         auto dim = dataset.GetDim();
         auto hnsw_cfg = static_cast<const HnswConfig&>(cfg);
         hnswlib::SpaceInterface<DistType>* space = nullptr;
-        if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) {
-            space = new (std::nothrow) hnswlib::L2Space(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) {
-            space = new (std::nothrow) hnswlib::InnerProductSpace(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) {
-            space = new (std::nothrow) hnswlib::CosineSpace(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) {
-            space = new (std::nothrow) hnswlib::HammingSpace(dim);
-        } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) {
-            space = new (std::nothrow) hnswlib::JaccardSpace(dim);
-        } else {
-            LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value();
-            return Status::invalid_metric_type;
-        }
-        auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space, rows, hnsw_cfg.M.value(),
-                                                                                       hnsw_cfg.efConstruction.value());
+        if constexpr (KnowhereFloatTypeCheck<DataType>::value) {
+            if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) {
+                space = new (std::nothrow) hnswlib::L2Space<DataType, DistType>(dim);
+            } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) {
+                space = new (std::nothrow) hnswlib::InnerProductSpace<DataType, DistType>(dim);
+            } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) {
+                space = new (std::nothrow) hnswlib::CosineSpace<DataType, DistType>(dim);
+            } else {
+                LOG_KNOWHERE_WARNING_
+                    << "metric type and data type(float32, float16 and bfloat16) are not match in hnsw: "
+                    << hnsw_cfg.metric_type.value();
+                return Status::invalid_metric_type;
+            }
+        } else {
+            if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING)) {
+                space = new (std::nothrow) hnswlib::HammingSpace(dim);
+            } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) {
+                space = new (std::nothrow) hnswlib::JaccardSpace(dim);
+            } else {
+                LOG_KNOWHERE_WARNING_ << "metric type and data type(binary) are not match in hnsw: "
+                                      << hnsw_cfg.metric_type.value();
+                return Status::invalid_metric_type;
+            }
+        }
+
+        auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(
+            space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value());

    [Review thread on this line]
    Collaborator: should use "quantType" instead of "quant_type" to sync up the coding style

         if (index == nullptr) {
             LOG_KNOWHERE_WARNING_ << "memory malloc error.";
             return Status::malloc_error;
@@ -75,7 +83,7 @@ class HnswIndexNode : public IndexNode {
         }
         this->index_ = index;
         if constexpr (quant_type != QuantType::None) {
-            this->index_->trainSQuant((const float*)dataset.GetTensor(), rows);
+            this->index_->trainSQuant((const DataType*)dataset.GetTensor(), rows);
         }
         return Status::success;
     }
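With this split, the accepted metrics now depend on the element type: fp32/fp16/bf16 build L2, IP, or COSINE spaces, while bin1 builds HAMMING or JACCARD. A build config the float branch would accept might look like the following sketch; the values are illustrative and the indexparam constant names are taken from the test utilities, so treat them as assumptions:

// Sketch: a config the float branch above would accept for fp32/fp16/bf16 data.
knowhere::Json json;
json[knowhere::meta::METRIC_TYPE] = "COSINE";     // L2 / IP / COSINE only
json[knowhere::meta::DIM] = 128;
json[knowhere::indexparam::HNSW_M] = 16;
json[knowhere::indexparam::EFCONSTRUCTION] = 200;
// With DataType = bin1, only HAMMING or JACCARD passes this check.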
@@ -225,11 +233,11 @@
  private:
     class iterator : public IndexIterator {
      public:
-        iterator(const hnswlib::HierarchicalNSW<DistType, quant_type>* index, const char* query, const bool transform,
-                 const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf,
-                 const float refine_ratio = 0.5f)
-            : IndexIterator(transform, (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled &&
-                                        hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data)
+        iterator(const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index, const char* query,
+                 const bool transform, const BitsetView& bitset, const bool for_tuning = false,
+                 const size_t seed_ef = kIteratorSeedEf, const float refine_ratio = 0.5f)
+            : IndexIterator(transform, (hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::sq_enabled &&
+                                        hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::has_raw_data)
                                            ? refine_ratio
                                            : 0.0f),
               index_(index),
@@ -251,15 +259,15 @@
         }
         float
         raw_distance(int64_t id) override {
-            if constexpr (hnswlib::HierarchicalNSW<DistType, quant_type>::sq_enabled &&
-                          hnswlib::HierarchicalNSW<DistType, quant_type>::has_raw_data) {
+            if constexpr (hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::sq_enabled &&
+                          hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::has_raw_data) {
                 return (transform_ ? -1 : 1) * index_->calcRefineDistance(workspace_->raw_query_data.get(), id);
             }
             throw std::runtime_error("raw_distance not supported: index does not have raw data or sq is not enabled");
         }

      private:
-        const hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
+        const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_;
         const bool transform_;
         std::unique_ptr<hnswlib::IteratorWorkspace> workspace_;
     };
@@ -466,8 +474,8 @@

     MemoryIOReader reader(binary->data.get(), binary->size);

-    hnswlib::SpaceInterface<float>* space = nullptr;
-    index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
+    hnswlib::SpaceInterface<DistType>* space = nullptr;
+    index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
     index_->loadIndex(reader);
     LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_
                        << " #max level:" << index_->maxlevel_
@@ -486,8 +494,8 @@
         delete index_;
     }
     try {
-        hnswlib::SpaceInterface<float>* space = nullptr;
-        index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
+        hnswlib::SpaceInterface<DistType>* space = nullptr;
+        index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
         index_->loadIndex(filename, config);
     } catch (std::exception& e) {
         LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
@@ -581,7 +589,7 @@
     }

  private:
-    hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
+    hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index_;
     std::shared_ptr<ThreadPool> search_pool_;
 };

@@ -592,14 +600,14 @@ KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, fp16);
 KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_DEPRECATED, HnswIndexNode, bf16);
 #else
 KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
 #endif

 KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8);
 KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
-KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
 } // namespace knowhere
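This registration change is the user-visible core of the PR: fp16/bf16 HNSW (and the SQ8 variants) move from the MOCK macro — which, as I read the convention, adapts the data through an fp32 shim — to SIMPLE registration with a natively typed index. A minimal factory-usage sketch modeled on the test file below; acquiring `version` is elided:

// Sketch: creating a natively typed fp16 HNSW index through the factory.
auto idx = knowhere::IndexFactory::Instance()
               .Create<knowhere::fp16>(knowhere::IndexEnum::INDEX_HNSW, version)
               .value();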
10 changes: 5 additions & 5 deletions tests/ut/test_iterator.cc
@@ -323,7 +323,7 @@ TEST_CASE("Test Iterator IVFFlatCC With Newly Insert Vectors", "[float metrics]
     }
 }

-TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") {
+TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[binary metrics]") {

    [Review thread on this line]
    Collaborator: any comments / TODO about the reason for disabling this test?
    Collaborator (Author): Metric type and data type did not match in this case. I changed the data type to binary, and it works now.

     using Catch::Approx;

     const int64_t nb = 1000, nq = 10;
@@ -348,21 +348,21 @@ TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") {
         json[knowhere::indexparam::SEED_EF] = 64;
         return json;
     };
-    const auto train_ds = GenDataSet(nb, dim);
-    const auto query_ds = GenDataSet(nq, dim);
+    const auto train_ds = GenBinDataSet(nb, dim);
+    const auto query_ds = GenBinDataSet(nq, dim);

     const knowhere::Json conf = {
         {knowhere::meta::METRIC_TYPE, metric},
         {knowhere::meta::TOPK, topk},
     };

-    auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, conf, nullptr);
+    auto gt = knowhere::BruteForce::Search<knowhere::bin1>(train_ds, query_ds, conf, nullptr);
     SECTION("Test Search using iterator") {
         using std::make_tuple;
         auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
             make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
         }));
-        auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
+        auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value();
         auto cfg_json = gen().dump();
         CAPTURE(name, cfg_json);
         knowhere::Json json = knowhere::Json::parse(cfg_json);