Skip to content

Commit

Permalink
Transfer from Vc to xsimd
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Aug 5, 2023
1 parent b460e43 commit 4ef621c
Show file tree
Hide file tree
Showing 11 changed files with 317 additions and 90 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@
[submodule "librapid/vendor/CLBlast"]
path = librapid/vendor/CLBlast
url = https://github.com/CNugteren/CLBlast.git
[submodule "librapid/vendor/xsimd"]
path = librapid/vendor/xsimd
url = https://github.com/xtensor-stack/xsimd.git
34 changes: 16 additions & 18 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -415,19 +415,14 @@ endif ()

# Add dependencies
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/fmt")
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/Vc")

if (NOT MINGW)
# scnlib does not support MinGW, since it does not implement std::from_chars, which is required by the library
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/scnlib")
else ()
message(WARNING "[ LIBRAPID ] scnlib cannot be built by MinGW, so it will not be enabled")
target_compile_definitions(${module_name} PUBLIC LIBRAPID_MINGW)
endif ()
# add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/Vc")
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/xsimd")
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/scnlib")

target_compile_definitions(fmt PUBLIC FMT_HEADER_ONLY)
target_compile_definitions(Vc PRIVATE Vc_HACK_OSTREAM_FOR_TTY)
target_link_libraries(${module_name} PUBLIC fmt scn Vc)
# target_compile_definitions(Vc PRIVATE Vc_HACK_OSTREAM_FOR_TTY)
# target_link_libraries(${module_name} PUBLIC fmt scn Vc xsimd)
target_link_libraries(${module_name} PUBLIC fmt scn xsimd)

if (${LIBRAPID_USE_MULTIPREC})
# Load MPIR
Expand Down Expand Up @@ -484,15 +479,18 @@ if (LIBRAPID_FAST_MATH)
target_compile_definitions(${module_name} PUBLIC LIBRAPID_FAST_MATH)
endif ()

set(LIBRAPID_ARCH_FLAGS)
if (LIBRAPID_NATIVE_ARCH)
message(STATUS "[ LIBRAPID ] Compiling for native architecture")
OptimizeForArchitecture()
target_compile_options(${module_name} PUBLIC ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS})
target_compile_definitions(${module_name} PUBLIC LIBRAPID_NATIVE_ARCH)
set(LIBRAPID_ARCH_FLAGS ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS})
message(STATUS "[ LIBRAPID ] Additional Definitions: ${Vc_DEFINITIONS}")
message(STATUS "[ LIBRAPID ] Supported flags: ${Vc_ARCHITECTURE_FLAGS}")

include(ArchDetect2)
target_compile_options(${module_name} PUBLIC ${LIBRAPID_ARCH_FLAGS})

# OptimizeForArchitecture()
# target_compile_options(${module_name} PUBLIC ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS})
# target_compile_definitions(${module_name} PUBLIC LIBRAPID_NATIVE_ARCH)
# set(LIBRAPID_ARCH_FLAGS ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS})
# message(STATUS "[ LIBRAPID ] Additional Definitions: ${Vc_DEFINITIONS}")
# message(STATUS "[ LIBRAPID ] Supported flags: ${Vc_ARCHITECTURE_FLAGS}")
endif ()

# Add defines for CUDA vector widths
Expand Down
243 changes: 243 additions & 0 deletions cmake/ArchDetect2.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
INCLUDE(CheckCXXSourceRuns)

# Classify the active C++ compiler into one of four families. The family decides
# which flag spelling check_simd_capability() tests and records: the GNU-style
# flag for GNU, Intel and Clang, or the MSVC-style flag for MSVC.
set(COMPILER_GNU false)
set(COMPILER_INTEL false)
set(COMPILER_CLANG false)
set(COMPILER_MSVC false)

if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
    set(COMPILER_GNU true)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
    set(COMPILER_INTEL true)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
    # Apple's Clang reports the distinct ID "AppleClang" but accepts the same
    # GNU-style flags; without this it would be treated as an unknown compiler
    # and no architecture flag would ever be tested or recorded.
    set(COMPILER_CLANG true)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
    set(COMPILER_MSVC true)
else ()
    # Unknown compiler: every COMPILER_* variable stays false, so the SIMD
    # probes run without any extra flag and record none.
endif ()

# Results collected by the probes below: the list of accepted flags, and a
# marker that is set once at least one probe has succeeded.
set(LIBRAPID_ARCH_FLAGS)
set(LIBRAPID_ARCH_FOUND)

# Function to test a given SIMD capability
# Probe one SIMD capability by compiling and running a small C++ program.
#
# Arguments:
#   FLAG_GNU    - flag used when the compiler is GNU, Intel or Clang
#   FLAG_MSVC   - flag used when the compiler is MSVC (may be "")
#   NAME        - human-readable capability name used in status messages
#   TEST_SOURCE - C++ program passed to check_cxx_source_runs()
#   VAR         - name of the variable that receives the check result
#
# On success the flag for the active compiler family is appended to
# LIBRAPID_ARCH_FLAGS in the caller's scope and LIBRAPID_ARCH_FOUND is set to
# TRUE there. Empty flags and flags already in the list are not appended, so a
# flag shared by several probes (e.g. "/arch:AVX512") is recorded only once.
function(check_simd_capability FLAG_GNU FLAG_MSVC NAME TEST_SOURCE VAR)
    # Select the flag once so that the flag that is tested and the flag that is
    # recorded can never disagree.
    set(simd_flag "")
    set(simd_family_known FALSE)
    if (COMPILER_GNU OR COMPILER_INTEL OR COMPILER_CLANG)
        set(simd_flag "${FLAG_GNU}")
        set(simd_family_known TRUE)
    elseif (COMPILER_MSVC) # reserve for WINDOWS
        set(simd_flag "${FLAG_MSVC}")
        set(simd_family_known TRUE)
    endif ()

    # Function-local: only affects the check below
    set(CMAKE_REQUIRED_FLAGS "${simd_flag}")
    check_cxx_source_runs("${TEST_SOURCE}" ${VAR})

    # ${VAR} expands to the result variable's name, which if() then evaluates;
    # an empty (failed) result is simply false.
    if (${VAR})
        if (simd_family_known)
            if (NOT "${simd_flag}" STREQUAL "")
                list(FIND LIBRAPID_ARCH_FLAGS "${simd_flag}" simd_flag_index)
                if (simd_flag_index EQUAL -1)
                    list(APPEND LIBRAPID_ARCH_FLAGS "${simd_flag}")
                    set(LIBRAPID_ARCH_FLAGS "${LIBRAPID_ARCH_FLAGS}" PARENT_SCOPE)
                endif ()
            endif ()

            message(STATUS "[ LIBRAPID ] ${NAME} found: ${simd_flag}")
        endif ()
        set(LIBRAPID_ARCH_FOUND TRUE PARENT_SCOPE)
    else ()
        message(STATUS "[ LIBRAPID ] ${NAME} not found")
    endif ()
endfunction()

# Check SSE2 (not a valid flag for MSVC)
# The probe uses _mm_add_epi32, which <emmintrin.h> declares. The previous
# _mm_abs_epi32 is an SSSE3 intrinsic that this header does not provide, so the
# probe could not compile and SSE2 was always reported as missing.
check_simd_capability("-msse2" "" "SSE2" "
#include <emmintrin.h>
int main() {
__m128i a = _mm_set_epi32 (-1, 2, -3, 4);
__m128i result = _mm_add_epi32 (a, a);
return 0;
}" SIMD_SSE2)

# SSE3 probe. GNU-style compilers are tested with -msse3; MSVC gets no flag.
# The test program calls _mm_addsub_ps from <pmmintrin.h>.
check_simd_capability(
    "-msse3"
    ""
    "SSE3"
    "
#include <pmmintrin.h>
int main() {
__m128 a = _mm_set_ps (-1.0f, 2.0f, -3.0f, 4.0f);
__m128 b = _mm_set_ps (1.0f, 2.0f, 3.0f, 4.0f);
__m128 result = _mm_addsub_ps (a, b);
return 0;
}"
    SIMD_SSE3
)

# SSSE3 probe. GNU-style compilers are tested with -mssse3; MSVC gets no flag.
# The test program calls _mm_abs_epi8 from <tmmintrin.h>.
check_simd_capability(
    "-mssse3"
    ""
    "SSSE3"
    "
#include <tmmintrin.h>
int main() {
__m128i a = _mm_set_epi8(-1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4);
__m128i result = _mm_abs_epi8(a);
return 0;
}"
    SIMD_SSSE3
)

# Check SSE4.1 (not a valid flag for MSVC)
# The probe uses _mm_mullo_epi32, an SSE4.1 instruction. The previous
# _mm_abs_epi32 is SSSE3, so the program also ran on CPUs without SSE4.1 and
# the capability could be reported where it is not available.
check_simd_capability("-msse4.1" "" "SSE4.1" "
#include <smmintrin.h>
int main() {
__m128i a = _mm_set_epi32(-1, 2, -3, 4);
__m128i result = _mm_mullo_epi32(a, a);
return 0;
}" SIMD_SSE4_1)

# Check SSE4.2 (not a valid flag for MSVC)
# The probe uses _mm_cmpgt_epi64, an SSE4.2 instruction. The previous
# _mm_abs_epi32 is SSSE3, so the program also ran on CPUs without SSE4.2 and
# the capability could be reported where it is not available.
check_simd_capability("-msse4.2" "" "SSE4.2" "
#include <nmmintrin.h>
int main() {
__m128i a = _mm_set_epi32(-1, 2, -3, 4);
__m128i result = _mm_cmpgt_epi64(a, a);
return 0;
}" SIMD_SSE4_2)

# Check AVX
# The probe uses _mm256_add_ps. The previous _mm256_abs_ps is not an existing
# intrinsic, so the program could not compile and AVX was always reported as
# missing.
check_simd_capability("-mavx" "/arch:AVX" "AVX" "
#include <immintrin.h>
int main() {
__m256 a = _mm256_set_ps(-1.0f, 2.0f, -3.0f, 4.0f, -1.0f, 2.0f, -3.0f, 4.0f);
__m256 result = _mm256_add_ps(a, a);
return 0;
}" SIMD_AVX)

# AVX2 probe: -mavx2 for GNU-style compilers, /arch:AVX2 for MSVC.
# The test program calls _mm256_abs_epi32 from <immintrin.h>.
check_simd_capability(
    "-mavx2"
    "/arch:AVX2"
    "AVX2"
    "
#include <immintrin.h>
int main() {
__m256i a = _mm256_set_epi32(-1, 2, -3, 4, -1, 2, -3, 4);
__m256i result = _mm256_abs_epi32(a);
return 0;
}"
    SIMD_AVX2
)

# AVX512F probe: -mavx512f for GNU-style compilers, /arch:AVX512 for MSVC.
# The test program calls _mm512_abs_epi32 from <immintrin.h>.
check_simd_capability(
    "-mavx512f"
    "/arch:AVX512"
    "AVX512F"
    "
#include <immintrin.h>
int main() {
__m512i a = _mm512_set_epi32(-1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4);
__m512i result = _mm512_abs_epi32(a);
return 0;
}"
    SIMD_AVX512F
)

# AVX512BW probe: -mavx512bw for GNU-style compilers, /arch:AVX512 for MSVC.
# The test program calls _mm512_abs_epi8 from <immintrin.h>.
check_simd_capability(
    "-mavx512bw"
    "/arch:AVX512"
    "AVX512BW"
    "
#include <immintrin.h>
int main() {
__m512i a = _mm512_set_epi64(-1, 2, -3, 4, -1, 2, -3, 4);
__m512i result = _mm512_abs_epi8(a);
return 0;
}"
    SIMD_AVX512BW
)

# AVX512CD probe: -mavx512cd for GNU-style compilers, /arch:AVX512 for MSVC.
# The test program calls _mm512_conflict_epi64 from <immintrin.h>.
check_simd_capability(
    "-mavx512cd"
    "/arch:AVX512"
    "AVX512CD"
    "
#include <immintrin.h>
int main() {
__m512i a = _mm512_set_epi64(-1, 2, -3, 4, -1, 2, -3, 4);
__m512i result = _mm512_conflict_epi64(a);
return 0;
}"
    SIMD_AVX512CD
)

# Check AVX512DQ
# The probe uses _mm512_and_pd, a 512-bit AVX512DQ operation. The previous
# _mm512_abs_pd is an AVX512F intrinsic, so the program also ran on CPUs that
# have AVX512F but not AVX512DQ.
check_simd_capability("-mavx512dq" "/arch:AVX512" "AVX512DQ" "
#include <immintrin.h>
int main() {
__m512d a = _mm512_set_pd(-1.0, 2.0, -3.0, 4.0, -1.0, 2.0, -3.0, 4.0);
__m512d result = _mm512_and_pd(a, a);
return 0;
}" SIMD_AVX512DQ)

# Check AVX512ER
# The probe uses _mm512_exp2a23_pd, an AVX512ER instruction. The previous
# _mm512_exp_pd is an SVML math-library function rather than an ER intrinsic,
# and GCC/Clang do not declare it, so the program could not compile there.
# NOTE(review): compilers that no longer accept -mavx512er simply report
# "not found" here.
check_simd_capability("-mavx512er" "/arch:AVX512" "AVX512ER" "
#include <immintrin.h>
int main() {
__m512d a = _mm512_set_pd(-1.0, 2.0, -3.0, 4.0, -1.0, 2.0, -3.0, 4.0);
__m512d result = _mm512_exp2a23_pd(a);
return 0;
}" SIMD_AVX512ER)

# Check AVX512PF
# The probe issues an AVX512PF gather prefetch. The previous program passed 8
# values to _mm512_set_ps (which takes 16) and called the SVML function
# _mm512_exp_ps, which is not a PF intrinsic, so it could not detect AVX512PF.
# NOTE(review): compilers that no longer accept -mavx512pf simply report
# "not found" here.
check_simd_capability("-mavx512pf" "/arch:AVX512" "AVX512PF" "
#include <immintrin.h>
int main() {
float data[16] = {0.0f};
__m512i index = _mm512_setzero_si512();
_mm512_prefetch_i32gather_ps(index, data, 4, _MM_HINT_T0);
return 0;
}" SIMD_AVX512PF)

# ARM
# Each probe passes a -march=<level> flag for GNU-style compilers and no flag
# for MSVC. The test program only duplicates and adds 32-bit lanes through
# <arm_neon.h> (vdupq_n_s32 / vaddq_s32).
# NOTE(review): the ARMv7 and ARMv8 probes run the same program, so they differ
# only in whether the compiler accepts the flag and the binary runs - confirm
# that this is enough to tell the two levels apart.
check_simd_capability("-march=armv7-a" "" "ARMv7" "
#include <arm_neon.h>
int main() {
int32x4_t a = vdupq_n_s32(1);
int32x4_t b = vdupq_n_s32(2);
int32x4_t result = vaddq_s32(a, b);
return 0;
}" SIMD_ARMv7)

check_simd_capability("-march=armv8-a" "" "ARMv8" "
#include <arm_neon.h>
int main() {
int32x4_t a = vdupq_n_s32(1);
int32x4_t b = vdupq_n_s32(2);
int32x4_t result = vaddq_s32(a, b);
return 0;
}" SIMD_ARMv8)

# ARM64
# Probes for -march=armv8.1-a through -march=armv8.5-a (GNU-style compilers
# only; the MSVC flag is empty). Every probe compiles and runs the same
# vdupq_n_s32 / vaddq_s32 program.
# NOTE(review): the program uses nothing specific to a given architecture
# level, so a probe presumably passes whenever the compiler accepts the flag,
# whatever the host CPU supports. Each passing probe also records its own
# -march flag, so several -march values can end up in LIBRAPID_ARCH_FLAGS.
# Confirm this is intended, or give each level its own test program.
check_simd_capability("-march=armv8.1-a" "" "ARMv8.1" "
#include <arm_neon.h>
int main() {
int32x4_t a = vdupq_n_s32(1);
int32x4_t b = vdupq_n_s32(2);
int32x4_t result = vaddq_s32(a, b);
return 0;
}" SIMD_ARMv8_1)

check_simd_capability("-march=armv8.2-a" "" "ARMv8.2" "
#include <arm_neon.h>
int main() {
int32x4_t a = vdupq_n_s32(1);
int32x4_t b = vdupq_n_s32(2);
int32x4_t result = vaddq_s32(a, b);
return 0;
}" SIMD_ARMv8_2)

check_simd_capability("-march=armv8.3-a" "" "ARMv8.3" "
#include <arm_neon.h>
int main() {
int32x4_t a = vdupq_n_s32(1);
int32x4_t b = vdupq_n_s32(2);
int32x4_t result = vaddq_s32(a, b);
return 0;
}" SIMD_ARMv8_3)

check_simd_capability("-march=armv8.4-a" "" "ARMv8.4" "
#include <arm_neon.h>
int main() {
int32x4_t a = vdupq_n_s32(1);
int32x4_t b = vdupq_n_s32(2);
int32x4_t result = vaddq_s32(a, b);
return 0;
}" SIMD_ARMv8_4)

check_simd_capability("-march=armv8.5-a" "" "ARMv8.5" "
#include <arm_neon.h>
int main() {
int32x4_t a = vdupq_n_s32(1);
int32x4_t b = vdupq_n_s32(2);
int32x4_t result = vaddq_s32(a, b);
return 0;
}" SIMD_ARMv8_5)

# Report the outcome of the probes: the collected flags when at least one
# capability was detected, otherwise a notice that none was.
if (NOT LIBRAPID_ARCH_FOUND)
    message(STATUS "[ LIBRAPID ] Architecture Flags Not Found")
else ()
    message(STATUS "[ LIBRAPID ] Architecture Flags: ${LIBRAPID_ARCH_FLAGS}")
endif ()
6 changes: 2 additions & 4 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,7 @@ namespace librapid {

template<typename ShapeType_, typename StorageType_>
auto ArrayContainer<ShapeType_, StorageType_>::packet(size_t index) const -> Packet {
Packet res;
res.load(m_storage.begin() + index);
return res;
return xsimd::load_aligned(m_storage.begin() + index);
}

template<typename ShapeType_, typename StorageType_>
Expand All @@ -681,7 +679,7 @@ namespace librapid {
template<typename ShapeType_, typename StorageType_>
void ArrayContainer<ShapeType_, StorageType_>::writePacket(size_t index,
const Packet &value) {
value.store(m_storage.begin() + index);
value.store_aligned(m_storage.begin() + index);
}

template<typename ShapeType_, typename StorageType_>
Expand Down
13 changes: 8 additions & 5 deletions librapid/include/librapid/array/storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,14 @@ namespace librapid {
/// \param newSize New size of the Storage object
LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize);

#if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE)
alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr;
#else
Pointer m_begin = nullptr; // Pointer to the beginning of the data
#endif
//#if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE)
// alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr;
//#else
// Pointer m_begin = nullptr; // Pointer to the beginning of the data
//#endif

Pointer m_begin = nullptr;

SizeType m_size = 0; // Number of elements in the Storage object
bool m_ownsData = true; // Whether this Storage object owns the data it points to
};
Expand Down
Loading

0 comments on commit 4ef621c

Please sign in to comment.