Skip to content

Commit

Permalink
fix: use u32 for rabitq quantized sum for DIM > 4369
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Sep 2, 2024
1 parent 4bac484 commit 3bea028
Show file tree
Hide file tree
Showing 9 changed files with 443 additions and 26 deletions.
10 changes: 10 additions & 0 deletions crates/base/src/scalar/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pub fn sum_of_and(lhs: &[u64], rhs: &[u64]) -> u32 {
}

mod sum_of_and {
// FIXME: add manually-implemented SIMD version for AVX512 and AVX2

#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4_avx512vpopcntdq")]
Expand Down Expand Up @@ -68,6 +70,8 @@ pub fn sum_of_or(lhs: &[u64], rhs: &[u64]) -> u32 {
}

mod sum_of_or {
// FIXME: add manually-implemented SIMD version for AVX512 and AVX2

#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4_avx512vpopcntdq")]
Expand Down Expand Up @@ -132,6 +136,8 @@ pub fn sum_of_xor(lhs: &[u64], rhs: &[u64]) -> u32 {
}

mod sum_of_xor {
// FIXME: add manually-implemented SIMD version for AVX512 and AVX2

#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4_avx512vpopcntdq")]
Expand Down Expand Up @@ -196,6 +202,8 @@ pub fn sum_of_and_or(lhs: &[u64], rhs: &[u64]) -> (u32, u32) {
}

mod sum_of_and_or {
// FIXME: add manually-implemented SIMD version for AVX512 and AVX2

#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4_avx512vpopcntdq")]
Expand Down Expand Up @@ -268,6 +276,8 @@ pub fn sum_of_x(this: &[u64]) -> u32 {
}

mod sum_of_x {
// FIXME: add manually-implemented SIMD version for AVX512 and AVX2

#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4_avx512vpopcntdq")]
Expand Down
38 changes: 38 additions & 0 deletions crates/base/src/scalar/emulate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,44 @@ pub unsafe fn emulate_mm256_reduce_add_ps(mut x: std::arch::x86_64::__m256) -> f
}
}

/// Horizontally sums all 32 `i16` lanes of a 512-bit vector.
///
/// There is no `_mm512_reduce_add_epi16` intrinsic, so this emulates it by
/// splitting the vector into its two 256-bit halves and reducing each half.
/// Addition wraps on overflow like any `i16` arithmetic.
///
/// # Safety
/// The caller must ensure the running CPU supports the `v4` feature set
/// (AVX-512) required by the `target_cpu` attribute.
#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4")]
pub unsafe fn emulate_mm512_reduce_add_epi16(x: std::arch::x86_64::__m512i) -> i16 {
    unsafe {
        use std::arch::x86_64::*;
        // Reduce the low 256 bits and the high 256 bits separately, then combine.
        _mm256_reduce_add_epi16(_mm512_castsi512_si256(x))
            + _mm256_reduce_add_epi16(_mm512_extracti32x8_epi32(x, 1))
    }
}

/// Horizontally sums all 16 `i16` lanes of a 256-bit vector.
///
/// Emulates the missing `_mm256_reduce_add_epi16` intrinsic; addition wraps
/// on overflow like any `i16` arithmetic.
///
/// # Safety
/// The caller must ensure the running CPU supports the `v3` feature set
/// (AVX2) required by the `target_cpu` attribute.
#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
pub unsafe fn emulate_mm256_reduce_add_epi16(mut x: std::arch::x86_64::__m256i) -> i16 {
    unsafe {
        use std::arch::x86_64::*;
        // Fold the upper 128-bit lane onto the lower one.
        x = _mm256_add_epi16(x, _mm256_permute2f128_si256(x, x, 1));
        // hadd works within 128-bit lanes, so two passes leave the remaining
        // partial sums in the low 32 bits as two packed i16 values.
        x = _mm256_hadd_epi16(x, x);
        x = _mm256_hadd_epi16(x, x);
        let x = _mm256_cvtsi256_si32(x);
        // Combine the two final i16 partial sums.
        (x as i16) + ((x >> 16) as i16)
    }
}

/// Horizontally sums all 8 `i32` lanes of a 256-bit vector.
///
/// Emulates the missing `_mm256_reduce_add_epi32` intrinsic; addition wraps
/// on overflow like any `i32` arithmetic.
///
/// # Safety
/// The caller must ensure the running CPU supports the `v3` feature set
/// (AVX2) required by the `target_cpu` attribute.
#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
pub unsafe fn emulate_mm256_reduce_add_epi32(mut x: std::arch::x86_64::__m256i) -> i32 {
    unsafe {
        use std::arch::x86_64::*;
        // Fold the upper 128-bit lane onto the lower one.
        x = _mm256_add_epi32(x, _mm256_permute2f128_si256(x, x, 1));
        // Two in-lane horizontal adds leave the total in lane 0.
        x = _mm256_hadd_epi32(x, x);
        x = _mm256_hadd_epi32(x, x);
        _mm256_cvtsi256_si32(x)
    }
}

#[inline]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
Expand Down
11 changes: 11 additions & 0 deletions crates/base/src/scalar/f16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ impl ScalarLike for f16 {
x
}

// FIXME: add manually-implemented SIMD version
#[inline(always)]
fn reduce_sum_of_abs_x(this: &[f16]) -> f32 {
let n = this.len();
let mut x = 0.0f32;
for i in 0..n {
x += this[i].to_f32().abs();
}
x
}

// FIXME: add manually-implemented SIMD version
#[inline(always)]
fn reduce_sum_of_x2(this: &[f16]) -> f32 {
Expand Down
161 changes: 148 additions & 13 deletions crates/base/src/scalar/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ impl ScalarLike for f32 {
reduce_sum_of_x::reduce_sum_of_x(this)
}

/// Sum of `|x|` over the slice; delegates to the multiversioned module
/// function, which dispatches to a SIMD implementation at runtime when the
/// CPU supports it.
#[inline(always)]
fn reduce_sum_of_abs_x(this: &[f32]) -> f32 {
    reduce_sum_of_abs_x::reduce_sum_of_abs_x(this)
}

#[inline(always)]
fn reduce_sum_of_x2(this: &[f32]) -> f32 {
reduce_sum_of_x2::reduce_sum_of_x2(this)
Expand Down Expand Up @@ -192,19 +197,19 @@ mod reduce_sum_of_x {
use std::arch::x86_64::*;
let mut n = this.len();
let mut a = this.as_ptr();
let mut sum_of_x = _mm512_setzero_ps();
let mut sum = _mm512_setzero_ps();
while n >= 16 {
let x = _mm512_loadu_ps(a);
a = a.add(16);
n -= 16;
sum_of_x = _mm512_add_ps(x, sum_of_x);
sum = _mm512_add_ps(x, sum);
}
if n > 0 {
let mask = _bzhi_u32(0xffff, n as u32) as u16;
let x = _mm512_maskz_loadu_ps(mask, a);
sum_of_x = _mm512_add_ps(x, sum_of_x);
sum = _mm512_add_ps(x, sum);
}
_mm512_reduce_add_ps(sum_of_x)
_mm512_reduce_add_ps(sum)
}
}

Expand Down Expand Up @@ -244,28 +249,28 @@ mod reduce_sum_of_x {
use std::arch::x86_64::*;
let mut n = this.len();
let mut a = this.as_ptr();
let mut sum_of_x = _mm256_setzero_ps();
let mut sum = _mm256_setzero_ps();
while n >= 8 {
let x = _mm256_loadu_ps(a);
a = a.add(8);
n -= 8;
sum_of_x = _mm256_add_ps(x, sum_of_x);
sum = _mm256_add_ps(x, sum);
}
if n >= 4 {
let x = _mm256_zextps128_ps256(_mm_loadu_ps(a));
a = a.add(4);
n -= 4;
sum_of_x = _mm256_add_ps(x, sum_of_x);
sum = _mm256_add_ps(x, sum);
}
let mut sum_of_x = emulate_mm256_reduce_add_ps(sum_of_x);
let mut sum = emulate_mm256_reduce_add_ps(sum);
// this hint is used to disable loop unrolling
while std::hint::black_box(n) > 0 {
let x = a.read();
a = a.add(1);
n -= 1;
sum_of_x += x;
sum += x;
}
sum_of_x
sum
}
}

Expand Down Expand Up @@ -300,11 +305,141 @@ mod reduce_sum_of_x {
#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
pub fn reduce_sum_of_x(this: &[f32]) -> f32 {
let n = this.len();
let mut sum_of_x = 0.0f32;
let mut sum = 0.0f32;
for i in 0..n {
sum += this[i];
}
sum
}
}

mod reduce_sum_of_abs_x {
/// AVX-512 sum of `|x|`: processes 16 floats per iteration, handling the
/// tail with a masked load.
///
/// # Safety
/// The caller must ensure the running CPU supports the `v4` feature set
/// (AVX-512) required by the `target_cpu` attribute.
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4")]
unsafe fn reduce_sum_of_abs_x_v4(this: &[f32]) -> f32 {
    unsafe {
        use std::arch::x86_64::*;
        let mut n = this.len();
        let mut a = this.as_ptr();
        let mut sum = _mm512_setzero_ps();
        // Main loop: 16 lanes of |x| accumulated per iteration.
        while n >= 16 {
            let x = _mm512_loadu_ps(a);
            let abs_x = _mm512_abs_ps(x);
            a = a.add(16);
            n -= 16;
            sum = _mm512_add_ps(abs_x, sum);
        }
        // Tail (< 16 elements): _bzhi_u32 builds a mask of the low n bits,
        // and the masked load zeroes the excluded lanes, so they contribute
        // nothing to the sum.
        if n > 0 {
            let mask = _bzhi_u32(0xffff, n as u32) as u16;
            let x = _mm512_maskz_loadu_ps(mask, a);
            let abs_x = _mm512_abs_ps(x);
            sum = _mm512_add_ps(abs_x, sum);
        }
        _mm512_reduce_add_ps(sum)
    }
}

/// Compares the AVX-512 specialization against the scalar fallback on random
/// inputs, sweeping slice lengths around the vector width to cover every
/// possible tail size. Skips when the host CPU lacks `v4`.
///
/// Fix: dropped the needless `&this` borrows (`this` is already `&[f32]`,
/// so `&this` was `&&[f32]` relying on deref coercion), matching the v3 test.
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn reduce_sum_of_abs_x_v4_test() {
    use rand::Rng;
    // Tolerance for f32 rounding differences between the SIMD and scalar
    // summation orders.
    const EPSILON: f32 = 0.008;
    detect::init();
    if !detect::v4::detect() {
        println!("test {} ... skipped (v4)", module_path!());
        return;
    }
    let mut rng = rand::thread_rng();
    for _ in 0..256 {
        let n = 4016;
        let this = (0..n)
            .map(|_| rng.gen_range(-1.0..=1.0))
            .collect::<Vec<_>>();
        // Sweep lengths 3984..4016 so every tail size 0..16 is exercised.
        for z in 3984..4016 {
            let this = &this[..z];
            let specialized = unsafe { reduce_sum_of_abs_x_v4(this) };
            let fallback = unsafe { reduce_sum_of_abs_x_fallback(this) };
            assert!(
                (specialized - fallback).abs() < EPSILON,
                "specialized = {specialized}, fallback = {fallback}."
            );
        }
    }
}

/// AVX2 sum of `|x|`: processes 8 floats per iteration, then 4, then a
/// scalar tail.
///
/// # Safety
/// The caller must ensure the running CPU supports the `v3` feature set
/// (AVX2) required by the `target_cpu` attribute.
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
unsafe fn reduce_sum_of_abs_x_v3(this: &[f32]) -> f32 {
    use crate::scalar::emulate::emulate_mm256_reduce_add_ps;
    unsafe {
        use std::arch::x86_64::*;
        // 0x7FFF_FFFF in every lane: ANDing clears the sign bit, which is
        // |x| for IEEE-754 floats (AVX2 has no _mm256_abs_ps).
        let abs = _mm256_castsi256_ps(_mm256_srli_epi32(_mm256_set1_epi32(-1), 1));
        let mut n = this.len();
        let mut a = this.as_ptr();
        let mut sum = _mm256_setzero_ps();
        // Main loop: 8 lanes of |x| per iteration.
        while n >= 8 {
            let x = _mm256_loadu_ps(a);
            let abs_x = _mm256_and_ps(abs, x);
            a = a.add(8);
            n -= 8;
            sum = _mm256_add_ps(abs_x, sum);
        }
        // One 128-bit step if at least 4 elements remain; the zero-extended
        // upper lane contributes nothing.
        if n >= 4 {
            let x = _mm256_zextps128_ps256(_mm_loadu_ps(a));
            let abs_x = _mm256_and_ps(abs, x);
            a = a.add(4);
            n -= 4;
            sum = _mm256_add_ps(abs_x, sum);
        }
        let mut sum = emulate_mm256_reduce_add_ps(sum);
        // this hint is used to disable loop unrolling
        while std::hint::black_box(n) > 0 {
            let x = a.read();
            let abs_x = x.abs();
            a = a.add(1);
            n -= 1;
            sum += abs_x;
        }
        sum
    }
}

/// Cross-checks the AVX2 specialization against the scalar fallback on
/// random inputs, sweeping slice lengths around the vector width so every
/// tail size is exercised. Skips when the host CPU lacks `v3`.
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn reduce_sum_of_abs_x_v3_test() {
    use rand::Rng;
    // Tolerance for f32 rounding differences between summation orders.
    const EPSILON: f32 = 0.008;
    detect::init();
    if !detect::v3::detect() {
        println!("test {} ... skipped (v3)", module_path!());
        return;
    }
    let mut rng = rand::thread_rng();
    for _ in 0..256 {
        let full: Vec<f32> = (0..4016).map(|_| rng.gen_range(-1.0..=1.0)).collect();
        for len in 3984..4016 {
            let this = &full[..len];
            let specialized = unsafe { reduce_sum_of_abs_x_v3(this) };
            let fallback = unsafe { reduce_sum_of_abs_x_fallback(this) };
            assert!(
                (specialized - fallback).abs() < EPSILON,
                "specialized = {specialized}, fallback = {fallback}."
            );
        }
    }
}

#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
pub fn reduce_sum_of_abs_x(this: &[f32]) -> f32 {
let n = this.len();
let mut sum = 0.0f32;
for i in 0..n {
sum_of_x += this[i];
sum += this[i].abs();
}
sum_of_x
sum
}
}

Expand Down
4 changes: 4 additions & 0 deletions crates/base/src/scalar/impossible.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ impl ScalarLike for Impossible {
unimplemented!()
}

// NOTE(review): `Impossible` appears to be an uninhabited placeholder type,
// so this method should be unreachable — confirm against the type's
// definition. The other trait methods in this impl follow the same pattern.
fn reduce_sum_of_abs_x(_lhs: &[Self]) -> f32 {
    unimplemented!()
}

fn reduce_sum_of_x2(_this: &[Self]) -> f32 {
unimplemented!()
}
Expand Down
1 change: 1 addition & 0 deletions crates/base/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub trait ScalarLike:
fn to_f32(self) -> f32;

fn reduce_sum_of_x(lhs: &[Self]) -> f32;
fn reduce_sum_of_abs_x(lhs: &[Self]) -> f32;
fn reduce_sum_of_x2(this: &[Self]) -> f32;
fn reduce_min_max_of_x(this: &[Self]) -> (f32, f32);

Expand Down
Loading

0 comments on commit 3bea028

Please sign in to comment.