Skip to content

Commit

Permalink
Optimize AVX-512BW implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kalcutter committed Jan 22, 2024
1 parent bc27d95 commit 7be5b20
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions src/bitshuffle.c
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,17 @@ static ALWAYS_INLINE __mmask64 MM_CVTSI128_MASK64(__m128i a) {
#endif
}

#if defined(__GNUC__) && !defined(__INTEL_COMPILER)
// https://github.com/llvm/llvm-project/issues/65205
ATTRIBUTE_TARGET("avx512bw,avx512vl")
static ALWAYS_INLINE __m512i MM512_MASK_ADD_EPI8(__m512i src, __mmask64 k, __m512i a, __m512i b) {
__asm__("vpaddb\t{%3, %2, %0 %{%1%}|%0 %{%1%}, %2, %3}" : "+v"(src) : "Yk"(k), "v"(a), "v"(b));
return src;
}
#else
#define MM512_MASK_ADD_EPI8(SRC, K, A, B) _mm512_mask_add_epi8(SRC, K, A, B)
#endif

NO_INLINE
ATTRIBUTE_TARGET("avx512bw,avx512vl")
static void bitshuf_untrans_bit_avx512bw(char* restrict out, const char* restrict in, size_t size) {
Expand All @@ -1044,14 +1055,22 @@ static void bitshuf_untrans_bit_avx512bw(char* restrict out, const char* restric
const __m512i C7 = _mm512_set1_epi8(-128);
size_t i = 0;
for (; i + 8 <= size; i += 8) {
__m512i u = _mm512_maskz_mov_epi8(LOAD_MASK64(&in[0 * size + i]), C0);
u = _mm512_mask_add_epi8(X(u), LOAD_MASK64(&in[1 * size + i]), u, C1);
u = _mm512_mask_add_epi8(X(u), LOAD_MASK64(&in[2 * size + i]), u, C2);
u = _mm512_mask_add_epi8(X(u), LOAD_MASK64(&in[3 * size + i]), u, C3);
u = _mm512_mask_add_epi8(X(u), LOAD_MASK64(&in[4 * size + i]), u, C4);
u = _mm512_mask_add_epi8(X(u), LOAD_MASK64(&in[5 * size + i]), u, C5);
u = _mm512_mask_add_epi8(X(u), LOAD_MASK64(&in[6 * size + i]), u, C6);
u = _mm512_mask_add_epi8(X(u), LOAD_MASK64(&in[7 * size + i]), u, C7);
const __mmask64 a0 = LOAD_MASK64(&in[0 * size + i]);
__m512i u = _mm512_maskz_mov_epi8(a0, C0);
const __mmask64 a1 = LOAD_MASK64(&in[1 * size + i]);
const __mmask64 a2 = LOAD_MASK64(&in[2 * size + i]);
const __mmask64 a3 = LOAD_MASK64(&in[3 * size + i]);
const __mmask64 a4 = LOAD_MASK64(&in[4 * size + i]);
const __mmask64 a5 = LOAD_MASK64(&in[5 * size + i]);
const __mmask64 a6 = LOAD_MASK64(&in[6 * size + i]);
const __mmask64 a7 = LOAD_MASK64(&in[7 * size + i]);
u = MM512_MASK_ADD_EPI8(u, a1, u, C1);
u = MM512_MASK_ADD_EPI8(u, a2, u, C2);
u = MM512_MASK_ADD_EPI8(u, a3, u, C3);
u = MM512_MASK_ADD_EPI8(u, a4, u, C4);
u = MM512_MASK_ADD_EPI8(u, a5, u, C5);
u = MM512_MASK_ADD_EPI8(u, a6, u, C6);
u = MM512_MASK_ADD_EPI8(u, a7, u, C7);
_mm512_storeu_si512(&out[i * 8], u);
}
if (i < size) {
Expand All @@ -1065,13 +1084,13 @@ static void bitshuf_untrans_bit_avx512bw(char* restrict out, const char* restric
const __mmask64 a6 = MM_CVTSI128_MASK64(_mm_maskz_loadu_epi8(k, &in[6 * size + i]));
const __mmask64 a7 = MM_CVTSI128_MASK64(_mm_maskz_loadu_epi8(k, &in[7 * size + i]));
__m512i u = _mm512_maskz_mov_epi8(a0, C0);
u = _mm512_mask_add_epi8(X(u), a1, u, C1);
u = _mm512_mask_add_epi8(X(u), a2, u, C2);
u = _mm512_mask_add_epi8(X(u), a3, u, C3);
u = _mm512_mask_add_epi8(X(u), a4, u, C4);
u = _mm512_mask_add_epi8(X(u), a5, u, C5);
u = _mm512_mask_add_epi8(X(u), a6, u, C6);
u = _mm512_mask_add_epi8(X(u), a7, u, C7);
u = MM512_MASK_ADD_EPI8(u, a1, u, C1);
u = MM512_MASK_ADD_EPI8(u, a2, u, C2);
u = MM512_MASK_ADD_EPI8(u, a3, u, C3);
u = MM512_MASK_ADD_EPI8(u, a4, u, C4);
u = MM512_MASK_ADD_EPI8(u, a5, u, C5);
u = MM512_MASK_ADD_EPI8(u, a6, u, C6);
u = MM512_MASK_ADD_EPI8(u, a7, u, C7);
_mm512_mask_storeu_epi64(&out[i * 8], k, u);
}
}
Expand Down

0 comments on commit 7be5b20

Please sign in to comment.