Skip to content

Commit

Permalink
Avx2: Introduce LutAvx2
Browse files Browse the repository at this point in the history
This avoids reloading the lookup table on every iteration of the inner
loop.
  • Loading branch information
AndersTrier committed Oct 10, 2024
1 parent a39dbea commit 38c5eb3
Showing 1 changed file with 92 additions and 46 deletions.
138 changes: 92 additions & 46 deletions src/engine/engine_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,73 +93,114 @@ impl Default for Avx2 {
//
//

impl Avx2 {
#[target_feature(enable = "avx2")]
unsafe fn mul_avx2(&self, x: &mut [[u8; 64]], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];

for chunk in x.iter_mut() {
let x_ptr = chunk.as_mut_ptr() as *mut __m256i;
unsafe {
let x_lo = _mm256_loadu_si256(x_ptr);
let x_hi = _mm256_loadu_si256(x_ptr.add(1));
let (prod_lo, prod_hi) = Self::mul_256(x_lo, x_hi, lut);
_mm256_storeu_si256(x_ptr, prod_lo);
_mm256_storeu_si256(x_ptr.add(1), prod_hi);
}
}
}
#[derive(Copy, Clone)]
struct LutAvx2 {
t0_lo: __m256i,
t1_lo: __m256i,
t2_lo: __m256i,
t3_lo: __m256i,
t0_hi: __m256i,
t1_hi: __m256i,
t2_hi: __m256i,
t3_hi: __m256i,
}

// Impelemntation of LEO_MUL_256
impl From<&Multiply128lutT> for LutAvx2 {
#[inline(always)]
fn mul_256(value_lo: __m256i, value_hi: __m256i, lut: &Multiply128lutT) -> (__m256i, __m256i) {
let mut prod_lo: __m256i;
let mut prod_hi: __m256i;
fn from(lut: &Multiply128lutT) -> Self {
let t0_lo: __m256i;
let t1_lo: __m256i;
let t2_lo: __m256i;
let t3_lo: __m256i;
let t0_hi: __m256i;
let t1_hi: __m256i;
let t2_hi: __m256i;
let t3_hi: __m256i;

unsafe {
let t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[0] as *const u128 as *const __m128i,
));
let t1_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t1_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[1] as *const u128 as *const __m128i,
));
let t2_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t2_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[2] as *const u128 as *const __m128i,
));
let t3_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t3_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[3] as *const u128 as *const __m128i,
));

let t0_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t0_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[0] as *const u128 as *const __m128i,
));
let t1_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t1_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[1] as *const u128 as *const __m128i,
));
let t2_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t2_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[2] as *const u128 as *const __m128i,
));
let t3_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
t3_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[3] as *const u128 as *const __m128i,
));
}

LutAvx2 {
t0_lo,
t1_lo,
t2_lo,
t3_lo,
t0_hi,
t1_hi,
t2_hi,
t3_hi,
}
}
}

impl Avx2 {
#[target_feature(enable = "avx2")]
unsafe fn mul_avx2(&self, x: &mut [[u8; 64]], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];

let lut_avx2 = LutAvx2::from(lut);

for chunk in x.iter_mut() {
let x_ptr = chunk.as_mut_ptr() as *mut __m256i;
unsafe {
let x_lo = _mm256_loadu_si256(x_ptr);
let x_hi = _mm256_loadu_si256(x_ptr.add(1));
let (prod_lo, prod_hi) = Self::mul_256(x_lo, x_hi, lut_avx2);
_mm256_storeu_si256(x_ptr, prod_lo);
_mm256_storeu_si256(x_ptr.add(1), prod_hi);
}
}
}

// Impelemntation of LEO_MUL_256
#[inline(always)]
fn mul_256(value_lo: __m256i, value_hi: __m256i, lut_avx2: LutAvx2) -> (__m256i, __m256i) {
let mut prod_lo: __m256i;
let mut prod_hi: __m256i;

unsafe {
let clr_mask = _mm256_set1_epi8(0x0f);

let data_0 = _mm256_and_si256(value_lo, clr_mask);
prod_lo = _mm256_shuffle_epi8(t0_lo, data_0);
prod_hi = _mm256_shuffle_epi8(t0_hi, data_0);
prod_lo = _mm256_shuffle_epi8(lut_avx2.t0_lo, data_0);
prod_hi = _mm256_shuffle_epi8(lut_avx2.t0_hi, data_0);

let data_1 = _mm256_and_si256(_mm256_srli_epi64(value_lo, 4), clr_mask);
prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(t1_lo, data_1));
prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(t1_hi, data_1));
prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t1_lo, data_1));
prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t1_hi, data_1));

let data_0 = _mm256_and_si256(value_hi, clr_mask);
prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(t2_lo, data_0));
prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(t2_hi, data_0));
prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t2_lo, data_0));
prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t2_hi, data_0));

let data_1 = _mm256_and_si256(_mm256_srli_epi64(value_hi, 4), clr_mask);
prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(t3_lo, data_1));
prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(t3_hi, data_1));
prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t3_lo, data_1));
prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t3_hi, data_1));
}

(prod_lo, prod_hi)
Expand All @@ -173,9 +214,9 @@ impl Avx2 {
mut x_hi: __m256i,
y_lo: __m256i,
y_hi: __m256i,
lut: &Multiply128lutT,
lut_avx2: LutAvx2,
) -> (__m256i, __m256i) {
let (prod_lo, prod_hi) = Self::mul_256(y_lo, y_hi, lut);
let (prod_lo, prod_hi) = Self::mul_256(y_lo, y_hi, lut_avx2);
unsafe {
x_lo = _mm256_xor_si256(x_lo, prod_lo);
x_hi = _mm256_xor_si256(x_hi, prod_hi);
Expand All @@ -190,18 +231,18 @@ impl Avx2 {
impl Avx2 {
// Implementation of LEO_FFTB_256
#[inline(always)]
fn fftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];
fn fftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], lut_avx2: LutAvx2) {
let x_ptr = x.as_mut_ptr() as *mut __m256i;
let y_ptr = y.as_mut_ptr() as *mut __m256i;

unsafe {
let mut x_lo = _mm256_loadu_si256(x_ptr);
let mut x_hi = _mm256_loadu_si256(x_ptr.add(1));

let mut y_lo = _mm256_loadu_si256(y_ptr);
let mut y_hi = _mm256_loadu_si256(y_ptr.add(1));

(x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut);
(x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut_avx2);

_mm256_storeu_si256(x_ptr, x_lo);
_mm256_storeu_si256(x_ptr.add(1), x_hi);
Expand All @@ -217,8 +258,11 @@ impl Avx2 {
// Partial butterfly, caller must do `GF_MODULUS` check with `xor`.
#[inline(always)]
fn fft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];
let lut_avx2 = LutAvx2::from(lut);

for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
self.fftb_256(x_chunk, y_chunk, log_m);
self.fftb_256(x_chunk, y_chunk, lut_avx2);
}
}

Expand Down Expand Up @@ -331,8 +375,7 @@ impl Avx2 {
impl Avx2 {
// Implementation of LEO_IFFTB_256
#[inline(always)]
fn ifftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];
fn ifftb_256(&self, x: &mut [u8; 64], y: &mut [u8; 64], lut_avx2: LutAvx2) {
let x_ptr = x.as_mut_ptr() as *mut __m256i;
let y_ptr = y.as_mut_ptr() as *mut __m256i;

Expand All @@ -349,7 +392,7 @@ impl Avx2 {
_mm256_storeu_si256(y_ptr, y_lo);
_mm256_storeu_si256(y_ptr.add(1), y_hi);

(x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut);
(x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut_avx2);

_mm256_storeu_si256(x_ptr, x_lo);
_mm256_storeu_si256(x_ptr.add(1), x_hi);
Expand All @@ -358,8 +401,11 @@ impl Avx2 {

#[inline(always)]
fn ifft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];
let lut_avx2 = LutAvx2::from(lut);

for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
self.ifftb_256(x_chunk, y_chunk, log_m);
self.ifftb_256(x_chunk, y_chunk, lut_avx2);
}
}

Expand Down

0 comments on commit 38c5eb3

Please sign in to comment.