hnsw support fp16/bf16 #494
@cqy123456 🔍 Important: PR Classification Needed! For efficient project management and a seamless review process, it's essential to classify your PR correctly. Here's how:
For any PR outside the kind/improvement category, ensure you link to the associated issue using the format: “issue: #”. Thanks for your efforts and contribution to the community!
for (auto i = 0; i < d; i++) {
norm_l2_sqr += (float)x[i] * (float)x[i];
}
if (norm_l2_sqr > 0 && std::abs(1.0f - norm_l2_sqr) > FloatAccuracy) {
just in case for the future: should FloatAccuracy
remain the same for bf16 and fp16?
I think float16 can use its smallest positive normal value (about 6.1 x 10^-5), and bfloat16 can use the same value as float32.
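A minimal sketch of the per-type tolerance suggested above. The tag types, the helper names, and the float32 tolerance of 1e-5 are assumptions for illustration; the actual value of FloatAccuracy is not shown in this thread, and real code would use the library's own fp16/bf16 types.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical stand-ins for the element types (illustration only).
struct Fp32Tag {};
struct Fp16Tag {};
struct Bf16Tag {};

// Assumed float32 tolerance; the real FloatAccuracy value may differ.
constexpr float kAssumedFloatAccuracy = 1e-5f;

// Default: float32 and bfloat16 share the same tolerance, as proposed in the reply.
template <typename Tag>
constexpr float NormTolerance() {
    return kAssumedFloatAccuracy;
}

// float16: smallest positive normal value, about 6.1e-5 (2^-14).
template <>
constexpr float NormTolerance<Fp16Tag>() {
    return 6.1e-5f;
}

// Returns true when the squared L2 norm is non-zero and differs from 1
// by more than the tolerance chosen for the element type.
template <typename Tag>
bool NeedsNormalization(const float* x, std::size_t d) {
    float norm_l2_sqr = 0.0f;
    for (std::size_t i = 0; i < d; i++) {
        norm_l2_sqr += x[i] * x[i];
    }
    return norm_l2_sqr > 0 && std::abs(1.0f - norm_l2_sqr) > NormTolerance<Tag>();
}
```

With this shape the check itself stays unchanged and only the threshold varies by type, so a vector that is slightly off unit length can be accepted for float16 while still being renormalized for float32.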
@@ -317,65 +317,65 @@ TEST_CASE("Test Iterator IVFFlatCC With Newly Insert Vectors", "[float metrics]
}
}

TEST_CASE("Test Iterator Mem Index With Binary Metrics", "[float metrics]") {
any comments / TODO about the reason for disabling this test?
The metric type and data type did not match in this case. I changed the data type to binary, and it works now.
float res = 0;
for (unsigned i = 0; i < qty; i++) {
res += ((float*)pVect1)[i] * ((float*)pVect2)[i];
if constexpr (!std::is_same<DataType, float>::value) {
is_same_v
updated
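For reference, a small sketch of the suggested std::is_same_v form (C++17) applied to an inner-product loop. FakeHalf is a hypothetical element type used only so the example compiles on its own; it is not the fp16/bf16 type from this PR, and the function is not the PR's actual implementation.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical narrow element type; it stores a float internally to keep the sketch short.
struct FakeHalf {
    float v;
    explicit operator float() const { return v; }
};

template <typename DataType>
float InnerProduct(const void* pVect1, const void* pVect2, std::size_t qty) {
    const auto* a = static_cast<const DataType*>(pVect1);
    const auto* b = static_cast<const DataType*>(pVect2);
    float res = 0.0f;
    for (std::size_t i = 0; i < qty; i++) {
        // std::is_same_v<A, B> is shorthand for std::is_same<A, B>::value.
        if constexpr (!std::is_same_v<DataType, float>) {
            // Non-float inputs are widened to float before multiplying.
            res += static_cast<float>(a[i]) * static_cast<float>(b[i]);
        } else {
            res += a[i] * b[i];
        }
    }
    return res;
}
```

Because the branch is resolved at compile time, the float instantiation contains no conversion code at all.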
template <typename DataType, typename DistanceType>
static DistanceType
NormSqr(const void* pVect1v, const void* qty_ptr) {
if constexpr (!std::is_same<DataType, float>::value) {
is_same_v
updated
template <typename DataType, typename DistanceType>
static DistanceType
L2Sqr(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
if constexpr (!std::is_same<DataType, float>::value) {
is_same_v
updated
5a468bb to 3221ccf
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@           Coverage Diff            @@
##           main    #494      +/-   ##
=========================================
+ Coverage      0   71.11%   +71.11%
=========================================
  Files         0       67      +67
  Lines         0     4393    +4393
=========================================
+ Hits          0     3124    +3124
- Misses        0     1269    +1269
When the metric type is cosine, the fp16 query vector is normalized into a new fp16 vector. Compared with float32, bf16 has far fewer fraction bits (7 versus 23), so the normalized vector loses precision. As a result, the e2e test fails the fp16 recall check (recall 0.9817, required >= 0.99).
/lgtm
auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space, rows, hnsw_cfg.M.value(),
hnsw_cfg.efConstruction.value());

auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(
should use "quantType" instead of "quant_type" to sync up the coding style
thirdparty/hnswlib/hnswlib/hnswlib.h
Outdated
@@ -183,7 +183,7 @@ struct IteratorWorkspace {
// normalized_query_data(if any). Thus storing the normalized_query_data
// separately in a unique_ptr so it can be freed when finished.
IteratorWorkspace(const void* query_data, const size_t num_elements, const size_t seed_ef, const bool for_tuning,
std::unique_ptr<float[]> normalized_query_data, const knowhere::BitsetView& bitset,
std::unique_ptr<char[]> normalized_query_data, const knowhere::BitsetView& bitset,
why change float[] to char[] ?
To support float16 and bfloat16. hnswlib also uses char[] to store data.
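A sketch of why a byte buffer fits here: one std::unique_ptr<char[]> can own the normalized query for any element width, and the typed view is recovered only where arithmetic happens. CopyAndNormalize is a hypothetical helper for illustration, assuming DataType is trivially copyable and convertible to and from float; it is not the code from this PR.

```cpp
#include <cmath>
#include <cstddef>
#include <cstring>
#include <memory>

// Copy a query into a type-erased byte buffer, then L2-normalize it in place.
template <typename DataType>
std::unique_ptr<char[]> CopyAndNormalize(const void* query_data, std::size_t dim) {
    const std::size_t data_size = dim * sizeof(DataType);
    auto buf = std::make_unique<char[]>(data_size);
    std::memcpy(buf.get(), query_data, data_size);

    // Typed view over the byte buffer, used only for the arithmetic below.
    auto* v = reinterpret_cast<DataType*>(buf.get());
    float norm_sqr = 0.0f;
    for (std::size_t i = 0; i < dim; i++) {
        const float f = static_cast<float>(v[i]);
        norm_sqr += f * f;
    }
    if (norm_sqr > 0.0f) {
        const float inv_norm = 1.0f / std::sqrt(norm_sqr);
        for (std::size_t i = 0; i < dim; i++) {
            v[i] = static_cast<DataType>(static_cast<float>(v[i]) * inv_norm);
        }
    }
    return buf;
}
```

The caller's signature then no longer depends on the element type, which matches the change from std::unique_ptr<float[]> to std::unique_ptr<char[]> in the constructor above.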
thirdparty/hnswlib/hnswlib/hnswalg.h
Outdated
if constexpr (knowhere::KnowhereFloatTypeCheck<data_t>::value) {
if (metric_type_ == Metric::COSINE) {
auto normalized_query = std::make_unique<char[]>(space_->get_data_size());
std::memcpy(normalized_query.get(), query_data, space_->get_data_size());
memcpy_s maybe safer?
The signature is memcpy_s(void *dest, size_t numberOfElements, const void *src, size_t count);
In this code numberOfElements and count would be the same value, so the extra check adds nothing, and memcpy_s is not as portable as memcpy.
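If a bounds check were ever wanted without depending on memcpy_s, a thin wrapper over std::memcpy is one portable option. CheckedCopy below is a hypothetical helper, not part of knowhere or hnswlib, and it only pays off when the destination capacity and the byte count come from different sources.

```cpp
#include <cstddef>
#include <cstring>

// Hypothetical helper: copies count bytes only if they fit in the destination.
// Returns false and leaves dest untouched on a null pointer or an oversized copy.
inline bool CheckedCopy(void* dest, std::size_t dest_capacity, const void* src, std::size_t count) {
    if (dest == nullptr || src == nullptr || count > dest_capacity) {
        return false;
    }
    std::memcpy(dest, src, count);
    return true;
}
```

In the snippet under review both sizes are space_->get_data_size(), so the check would always pass, which is the point made in the reply.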
/lgtm
Signed-off-by: cqy123456 <[email protected]>
/lgtm
Signed-off-by: cqy123456 <[email protected]>
/lgtm
/approve
/approve
[APPROVALNOTIFIER] This PR is APPROVED
This pull-request has been approved by: cqy123456, Presburger
/lgtm
issue: #287