Skip to content

Commit

Permalink
add patch for compute float32 as bf16
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy committed May 10, 2024
1 parent f29da89 commit c6465cb
Show file tree
Hide file tree
Showing 13 changed files with 495 additions and 0 deletions.
11 changes: 11 additions & 0 deletions include/knowhere/comp/knowhere_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ class KnowhereConfig {
static std::string
SetSimdType(const SimdType simd_type);

/**
 * Enable the "compute fp32 as bf16" patch.
 *
 * Some sealed indexes default to bf16 as their base data in order to achieve higher capacity. To keep
 * computation consistent between growing and sealed, growing calculations must use the same precision
 * as sealed ones.
 */
static void
EnablePatchForComputeFP32AsBF16();

/**
 * Disable the "compute fp32 as bf16" patch; counterpart of EnablePatchForComputeFP32AsBF16().
 */
static void
DisablePatchForComputeFP32AsBF16();

/**
* Set openblas threshold
* if nq < use_blas_threshold, calculated by omp
Expand Down
7 changes: 7 additions & 0 deletions include/knowhere/operands.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ union fp32_bits {
float as_value;
};

__attribute__((always_inline)) inline float
bf16_float(float f) {
auto u32 = fp32_bits{.as_value = f}.as_bits;
// Round off
return fp32_bits{.as_bits = (u32 + 0x8000) & 0xFFFF0000}.as_value;
}

inline float
fp32_from_bits(const uint32_t& w) {
return fp32_bits{.as_bits = w}.as_value;
Expand Down
12 changes: 12 additions & 0 deletions src/common/comp/knowhere_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) {
return simd_str;
}

// Enables the "compute fp32 as bf16" patch: logs the change at info level and
// forwards to faiss::enable_patch_for_fp32_bf16().
// NOTE(review): the faiss hook is not defined in this file; presumably it switches
// the fp32 distance kernels to their *_bf16_patch variants — confirm.
void
KnowhereConfig::EnablePatchForComputeFP32AsBF16() {
LOG_KNOWHERE_INFO_ << "Enable patch for compute fp32 as bf16";
faiss::enable_patch_for_fp32_bf16();
}

// Disables the "compute fp32 as bf16" patch: logs the change at info level and
// forwards to faiss::disable_patch_for_fp32_bf16().
// NOTE(review): the faiss hook is not defined in this file; presumably it restores
// the regular fp32 distance kernels — confirm.
void
KnowhereConfig::DisablePatchForComputeFP32AsBF16() {
LOG_KNOWHERE_INFO_ << "Disable patch for compute fp32 as bf16";
faiss::disable_patch_for_fp32_bf16();
}

void
KnowhereConfig::SetBlasThreshold(const int64_t use_blas_threshold) {
LOG_KNOWHERE_INFO_ << "Set faiss::distance_compute_blas_threshold to " << use_blas_threshold;
Expand Down
84 changes: 84 additions & 0 deletions src/simd/distances_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cassert>

#include "faiss/impl/platform_macros.h"
#include "knowhere/operands.h"

namespace faiss {

Expand Down Expand Up @@ -54,6 +55,20 @@ fvec_inner_product_avx(const float* x, const float* y, size_t d) {
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Inner product of x and y over d components, where each y component is first
// rounded to bf16 precision with bf16_float(); x is used unmodified.
// Kept as a plain scalar loop: we trust the compiler to unroll this properly.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d) {
    float acc = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float y_rounded = bf16_float(y[k]);
        acc += x[k] * y_rounded;
    }
    return acc;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
Expand All @@ -69,6 +84,20 @@ fvec_L2sqr_avx(const float* x, const float* y, size_t d) {
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Squared L2 distance between x and y over d components, where each y component
// is first rounded to bf16 precision with bf16_float(); x is used unmodified.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d) {
    float acc = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float y_rounded = bf16_float(y[k]);
        const float diff = x[k] - y_rounded;
        acc += diff * diff;
    }
    return acc;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

float
fvec_L1_avx(const float* x, const float* y, size_t d) {
__m256 msum1 = _mm256_setzero_ps();
Expand Down Expand Up @@ -187,6 +216,32 @@ fvec_inner_product_batch_4_avx(const float* __restrict x, const float* __restric
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Computes four inner products at once: x against y0, y1, y2 and y3 (d components
// each). Every y component is rounded to bf16 precision with bf16_float() before
// the multiply; x is used unmodified. Results are written to dis0..dis3 only after
// the whole loop has finished.
// Kept as a plain scalar loop: we trust the compiler to unroll this properly.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_inner_product_batch_4_avx_bf16_patch(const float* __restrict x, const float* __restrict y0,
                                          const float* __restrict y1, const float* __restrict y2,
                                          const float* __restrict y3, const size_t d, float& dis0, float& dis1,
                                          float& dis2, float& dis3) {
    float acc0 = 0;
    float acc1 = 0;
    float acc2 = 0;
    float acc3 = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float xk = x[k];
        acc0 += xk * bf16_float(y0[k]);
        acc1 += xk * bf16_float(y1[k]);
        acc2 += xk * bf16_float(y2[k]);
        acc3 += xk * bf16_float(y3[k]);
    }

    dis0 = acc0;
    dis1 = acc1;
    dis2 = acc2;
    dis3 = acc3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
Expand Down Expand Up @@ -215,6 +270,35 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Computes four squared L2 distances at once: x against y0, y1, y2 and y3
// (d components each). Every y component is rounded to bf16 precision with
// bf16_float() before the subtraction; x is used unmodified. Results are written
// to dis0..dis3 only after the whole loop has finished.
// Kept as a plain scalar loop: we trust the compiler to unroll this properly.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
                                  const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
                                  float& dis3) {
    float acc0 = 0;
    float acc1 = 0;
    float acc2 = 0;
    float acc3 = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float xk = x[k];
        const float diff0 = xk - bf16_float(y0[k]);
        const float diff1 = xk - bf16_float(y1[k]);
        const float diff2 = xk - bf16_float(y2[k]);
        const float diff3 = xk - bf16_float(y3[k]);
        acc0 += diff0 * diff0;
        acc1 += diff1 * diff1;
        acc2 += diff2 * diff2;
        acc3 += diff3 * diff3;
    }

    dis0 = acc0;
    dis1 = acc1;
    dis2 = acc2;
    dis3 = acc3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) {
Expand Down
15 changes: 15 additions & 0 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@ namespace faiss {
float
fvec_L2sqr_avx(const float* x, const float* y, size_t d);

/// Squared L2 distance over d components, with each y component rounded to
/// bf16 precision before the subtraction (x is used unmodified).
float
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d);

/// inner product
float
fvec_inner_product_avx(const float* x, const float* y, size_t d);

/// Inner product over d components, with each y component rounded to
/// bf16 precision before the multiply (x is used unmodified).
float
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d);
/// L1 distance
float
fvec_L1_avx(const float* x, const float* y, size_t d);
Expand All @@ -40,10 +45,20 @@ void
fvec_inner_product_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

/// Inner products of x against y0..y3 (d components each), with every y component
/// rounded to bf16 precision before the multiply; results are stored in dis0..dis3.
void
fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

/// Squared L2 distances of x against y0..y3 (d components each), with every y component
/// rounded to bf16 precision before the subtraction; results are stored in dis0..dis3.
void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

Expand Down
83 changes: 83 additions & 0 deletions src/simd/distances_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <string>

#include "faiss/impl/platform_macros.h"
#include "knowhere/operands.h"

namespace faiss {

Expand Down Expand Up @@ -53,6 +54,19 @@ fvec_inner_product_avx512(const float* x, const float* y, size_t d) {
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Inner product of x and y over d components, where each y component is first
// rounded to bf16 precision with bf16_float(); x is used unmodified.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
fvec_inner_product_avx512_bf16_patch(const float* x, const float* y, size_t d) {
    float acc = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float y_rounded = bf16_float(y[k]);
        acc += x[k] * y_rounded;
    }
    return acc;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
Expand All @@ -68,6 +82,20 @@ fvec_L2sqr_avx512(const float* x, const float* y, size_t d) {
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Squared L2 distance between x and y over d components, where each y component
// is first rounded to bf16 precision with bf16_float(); x is used unmodified.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
fvec_L2sqr_avx512_bf16_patch(const float* x, const float* y, size_t d) {
    float acc = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float y_rounded = bf16_float(y[k]);
        const float diff = x[k] - y_rounded;
        acc += diff * diff;
    }
    return acc;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

float
fvec_L1_avx512(const float* x, const float* y, size_t d) {
__m512 msum0 = _mm512_setzero_ps();
Expand Down Expand Up @@ -214,6 +242,32 @@ fvec_inner_product_batch_4_avx512(const float* __restrict x, const float* __rest
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Computes four inner products at once: x against y0, y1, y2 and y3 (d components
// each). Every y component is rounded to bf16 precision with bf16_float() before
// the multiply; x is used unmodified. Results are written to dis0..dis3 only after
// the whole loop has finished.
// Kept as a plain scalar loop: we trust the compiler to unroll this properly.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_inner_product_batch_4_avx512_bf16_patch(const float* __restrict x, const float* __restrict y0,
                                             const float* __restrict y1, const float* __restrict y2,
                                             const float* __restrict y3, const size_t d, float& dis0, float& dis1,
                                             float& dis2, float& dis3) {
    float acc0 = 0;
    float acc1 = 0;
    float acc2 = 0;
    float acc3 = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float xk = x[k];
        acc0 += xk * bf16_float(y0[k]);
        acc1 += xk * bf16_float(y1[k]);
        acc2 += xk * bf16_float(y2[k]);
        acc3 += xk * bf16_float(y3[k]);
    }

    dis0 = acc0;
    dis1 = acc1;
    dis2 = acc2;
    dis3 = acc3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
Expand Down Expand Up @@ -242,6 +296,35 @@ fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, cons
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Computes four squared L2 distances at once: x against y0, y1, y2 and y3
// (d components each). Every y component is rounded to bf16 precision with
// bf16_float() before the subtraction; x is used unmodified. Results are written
// to dis0..dis3 only after the whole loop has finished.
// Kept as a plain scalar loop: we trust the compiler to unroll this properly.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_L2sqr_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
                                     const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
                                     float& dis3) {
    float acc0 = 0;
    float acc1 = 0;
    float acc2 = 0;
    float acc3 = 0;
    FAISS_PRAGMA_IMPRECISE_LOOP
    for (size_t k = 0; k < d; ++k) {
        const float xk = x[k];
        const float diff0 = xk - bf16_float(y0[k]);
        const float diff1 = xk - bf16_float(y1[k]);
        const float diff2 = xk - bf16_float(y2[k]);
        const float diff3 = xk - bf16_float(y3[k]);
        acc0 += diff0 * diff0;
        acc1 += diff1 * diff1;
        acc2 += diff2 * diff2;
        acc3 += diff3 * diff3;
    }

    dis0 = acc0;
    dis1 = acc1;
    dis2 = acc2;
    dis3 = acc3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) {
Expand Down
16 changes: 16 additions & 0 deletions src/simd/distances_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ namespace faiss {
float
fvec_L2sqr_avx512(const float* x, const float* y, size_t d);

/// Squared L2 distance over d components, with each y component rounded to
/// bf16 precision before the subtraction (x is used unmodified).
float
fvec_L2sqr_avx512_bf16_patch(const float* x, const float* y, size_t d);

/// inner product
float
fvec_inner_product_avx512(const float* x, const float* y, size_t d);

/// Inner product over d components, with each y component rounded to
/// bf16 precision before the multiply (x is used unmodified).
float
fvec_inner_product_avx512_bf16_patch(const float* x, const float* y, size_t d);

/// L1 distance
float
fvec_L1_avx512(const float* x, const float* y, size_t d);
Expand All @@ -39,10 +45,20 @@ void
fvec_inner_product_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

/// Inner products of x against y0..y3 (d components each), with every y component
/// rounded to bf16 precision before the multiply; results are stored in dis0..dis3.
void
fvec_inner_product_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

/// Squared L2 distances of x against y0..y3 (d components each), with every y component
/// rounded to bf16 precision before the subtraction; results are stored in dis0..dis3.
void
fvec_L2sqr_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d);

Expand Down
Loading

0 comments on commit c6465cb

Please sign in to comment.