diff --git a/bangc_helper_dtype.h b/bangc_helper_dtype.h new file mode 100644 index 000000000..2e005110b --- /dev/null +++ b/bangc_helper_dtype.h @@ -0,0 +1,113 @@ +/************************************************************************* + * Copyright (C) [2020] by Cambricon, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#pragma once + +/** + * Provides `BANG_WRAP_T(ptr_arg)` for .cc and `BANG_UNWRAP_T(ptr_arg)` for .mlu + * to bridge Eigen:: type and BANGC type + */ + +#include + +struct bang_half_t; +struct bang_bfloat16_t; + +namespace detail { +/* + * `bang_wrap_data` and `bang_unwrap_data` could be the same thing, + * but should be used in different scope + * + * handle 'const DType', 'Dtype *', + * could be implemented by SFINAE or just specialization + */ +template class Impl, + typename RawType = DType> +struct bang_trans_impl_ { + static_assert(std::is_same_v); + typedef DType type; +}; + +template class Impl, typename RawType> +struct bang_trans_impl_ { + typedef typename Impl::type* type; +}; + +template class Impl, typename RawType> +struct bang_trans_impl_ { + typedef const typename Impl::type type; +}; + +} // namespace detail + +#define BANG_TRANS_TYPE_FROM_TO(TOKEN, From, To) \ + template <> \ + struct TOKEN { \ + typedef To type; \ + } + +/* For .cc/.cpp trans unknown type to wrapped type */ +#if !defined(__BANG__) + +namespace Eigen { +struct half; +struct bfloat16; +} // namespace Eigen + +template +struct bang_wrap_data { + using type = typename detail::bang_trans_impl_::type; +}; + +#define BANG_WRAP_TYPE_FROM_TO(From, To) \ + BANG_TRANS_TYPE_FROM_TO(bang_wrap_data, From, To) + +BANG_WRAP_TYPE_FROM_TO(Eigen::half, bang_half_t); +BANG_WRAP_TYPE_FROM_TO(Eigen::bfloat16, bang_bfloat16_t); + +template +using bang_wrap_data_t = typename bang_wrap_data::type; + +#define BANG_WRAP_T(a) reinterpret_cast>(a) + +#endif // !defined(__BANG__) + +/* For .mlu trans intermediate type to mlu's underlying type */ + +#if __BANG__ +template +struct bang_unwrap_data { + using type = typename detail::bang_trans_impl_::type; +}; + +#define BANG_UNWRAP_TYPE_FROM_TO(From, To) \ + BANG_TRANS_TYPE_FROM_TO(bang_unwrap_data, From, To) + +BANG_UNWRAP_TYPE_FROM_TO(bang_half_t, half); +BANG_UNWRAP_TYPE_FROM_TO(bang_bfloat16_t, bfloat16_t); + +template +using bang_unwrap_data_t = typename bang_unwrap_data::type; + +#define BANG_UNWRAP_T(a) reinterpret_cast>(a) + +#endif // __BANG__ diff --git a/bangc_kernels.h b/bangc_kernels.h new file mode 100644 index 000000000..5ffbcde31 --- /dev/null +++ b/bangc_kernels.h @@ -0,0 +1,95 @@ +/************************************************************************* + * Copyright (C) [2024] by Cambricon, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef BANGC_KERNELS_H_ +#define BANGC_KERNELS_H_ + +#ifndef NAMESPACE_BANGC_KERNELS_GEGIN +#define NAMESPACE_BANGC_KERNELS_GEGIN namespace bangc_kernels { +#endif + +NAMESPACE_BANGC_KERNELS_GEGIN + +#ifndef BANGC_KERNELS_WIN_API +#ifdef _WIN32 +#define BANGC_KERNELS_WIN_API __stdcall +#else +#define BANGC_KERNELS_WIN_API +#endif +#endif + +typedef enum { + BANGC_KERNELS_STATUS_SUCCESS = + 0, /*!< The operation is successfully completed. */ + BANGC_KERNELS_STATUS_ALLOC_FAILED = 1, + /*!< This error occurs when the resource allocation fails, which is usually + caused by failing to call cnMallocHost due to exceeded memory usage. Make + sure that the memory allocated previously is deallocated as much as + possible. */ + BANGC_KERNELS_STATUS_BAD_PARAM = 2, + /*!< Invalid value or parameters are passed to the function, including data + type, layout, dimensions, etc. */ + BANGC_KERNELS_STATUS_INTERNAL_ERROR = 3, + /*!< An error occurs inside of the function, which may indicate an internal + error or bug in the library. This error is usually caused by failing to + call cnrtMemcpyAsync. Check whether the memory passed to the function is + deallocated before the completion of the routine. */ + BANGC_KERNELS_STATUS_ARCH_MISMATCH = 4, + /*!< Invalid MLU device which is not supported by current function. */ + BANGC_KERNELS_STATUS_EXECUTION_FAILED = 5, + /*!< An error occurs when the function fails to be executed on MLU device due + to multiple reasons. You can check whether the hardware environment, driver + version and other prerequisite libraries are correctly installed. */ + BANGC_KERNELS_STATUS_NOT_SUPPORTED = 6, + /*!< An error occurs when the requested functionality is not supported in this + version but would be supported in the future. */ + BANGC_KERNELS_STATUS_NUMERICAL_OVERFLOW = 7, + /*!< A numerical overflow occurs when executing the function, which is usually + due to large scale or inappropriate range of value of input tensor. */ +} bangcKernelsStatus_t; + +template +bangcKernelsStatus_t BANGC_KERNELS_WIN_API +mluApplyAdamW(const cnrtQueue_t queue, + const float lr, + const float beta1, + const float beta2, + const float bias1, + const float bias2, + const float epsilon, + const float weight_decay, + const float scale, + const bool use_nesterov, + const size_t size, + T *param_h, + T *grad, + void *param, + void *momentum, + void *velocity); + +#ifndef NAMESPACE_BANGC_KERNELS_END +#define NAMESPACE_BANGC_KERNELS_END } +#endif + +NAMESPACE_BANGC_KERNELS_END + +#endif // BANGC_KERNELS_H_ diff --git a/independent_build.sh b/independent_build.sh index a601c5577..6a8ce17c7 100755 --- a/independent_build.sh +++ b/independent_build.sh @@ -8,6 +8,7 @@ MLUOP_TARGET_CPU_ARCH=`uname -m` GEN_SYMBOL_VIS_FILE_PY="./scripts/gen_symbol_visibility_map.py" MLUOP_SYMBOL_VIS_FILE="symbol_visibility.map" TARGET_SYMBOL_FILE="mlu_op.h" +TARGET_SYMBOL_FILE_LITE="bangc_kernels.h" PACKAGE_EXTRACT_DIR="dep_libs_extract" PROG_NAME=$(basename $0) # current script filename, DO NOT EDIT @@ -421,7 +422,7 @@ if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then fi fi -if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" < "5" ]]; then +if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then prog_log_note "we do not support g++<5, try to activate devtoolset-8 env" source /opt/rh/devtoolset-8/enable && prog_log_warn "devtoolset-8 activated" \ || ( prog_log_warn "source devtoolset-8 failed, ignore this info if you have set env TOOLCHAIN_ROOT, TARGET_C_COMPILER, TARGET_CXX_COMPILER properly (see more details in README.md)" && sleep 4 ) # I hope user will see it @@ -459,8 +460,8 @@ export PATH=${NEUWARE_HOME}/bin:$PATH export LD_LIBRARY_PATH=${NEUWARE_HOME}/lib64:$LD_LIBRARY_PATH prog_log_info "generate ${MLUOP_SYMBOL_VIS_FILE} file." -prog_log_info "python3 ${GEN_SYMBOL_VIS_FILE_PY} ${BUILD_PATH}/${MLUOP_SYMBOL_VIS_FILE} ${TARGET_SYMBOL_FILE}" -python3 ${GEN_SYMBOL_VIS_FILE_PY} ${BUILD_PATH}/${MLUOP_SYMBOL_VIS_FILE} ${TARGET_SYMBOL_FILE} +prog_log_info "python3 ${GEN_SYMBOL_VIS_FILE_PY} ${BUILD_PATH}/${MLUOP_SYMBOL_VIS_FILE} ${TARGET_SYMBOL_FILE} ${TARGET_SYMBOL_FILE_LITE}" +python3 ${GEN_SYMBOL_VIS_FILE_PY} ${BUILD_PATH}/${MLUOP_SYMBOL_VIS_FILE} ${TARGET_SYMBOL_FILE} ${TARGET_SYMBOL_FILE_LITE} pushd ${BUILD_PATH} > /dev/null prog_log_info "Rmove cmake cache ${PWD}" diff --git a/kernels/adam_w/adam_w_union1.mlu b/kernels/adam_w/adam_w_union1.mlu index db9cc672a..90512ce6c 100644 --- a/kernels/adam_w/adam_w_union1.mlu +++ b/kernels/adam_w/adam_w_union1.mlu @@ -21,12 +21,15 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ +#include "adam_w.h" + #include -#include -#include -#include "core/logging.h" -#include "kernels/adam_w/adam_w.h" -#include "kernels/utils/common.h" + +#include "bangc_helper_dtype.h" +#include "bangc_kernels.h" + + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define SIZE_NRAM_PER_REGION PAD_DOWN((MAX_NRAM_SIZE / 12), NFU_ALIGN_SIZE) #define HIGH_PRECISION_MODE 1 @@ -193,7 +196,7 @@ __mlu_global__ void unionApplyAdamW(T *param_h, T *grad, float *param, ddr_param - 2 * num_x * param_flag, ddr_momentum - 2 * num_x, ddr_velocity - 2 * num_x, nbuf_paramh, nbuf_param, nbuf_momentum, nbuf_velocity, - std::min(num_x, (int)(num_task - (i - 2) * num_x)), + MIN(num_x, (int)(num_task - (i - 2) * num_x)), (i - 2) % 2 * pong); } // load data @@ -201,15 +204,14 @@ __mlu_global__ void unionApplyAdamW(T *param_h, T *grad, float *param, loadData(nbuf_paramh, (T *)(nbuf_grad + pong / 2), nbuf_param, nbuf_momentum, nbuf_velocity, ddr_paramh, ddr_grad, ddr_param, ddr_momentum, ddr_velocity, - std::min(num_x, (int)(num_task - i * num_x)), i % 2 * pong); + MIN(num_x, (int)(num_task - i * num_x)), i % 2 * pong); } // compute if (i >= 1 && i <= num_iter) { computeAdamW(nbuf_paramh, (T *)(nbuf_grad + pong / 2), nbuf_param, nbuf_grad, nbuf_momentum, nbuf_velocity, temp_1, temp_2, lr, beta1, beta2, bias1, bias2, epsilon, weight_decay, scale, - use_nesterov, - std::min(num_x, (int)(num_task - (i - 1) * num_x)), + use_nesterov, MIN(num_x, (int)(num_task - (i - 1) * num_x)), (i - 1) % 2 * pong, param_flag); } ddr_paramh += num_x * paramh_flag; @@ -228,16 +230,48 @@ mluOpStatus_t MLUOP_WIN_API KernelApplyAdamW( void *momentum, void *velocity, float lr, float beta1, float beta2, float bias1, float bias2, float epsilon, float weight_decay, float scale, bool use_nesterov, size_t size, mluOpDataType_t k_data_type) { - switch (k_data_type) { - default: { - LOG(ERROR) << "Not Implemented."; - } - case MLUOP_DTYPE_BFLOAT16: { - KERNEL_CHECK(unionApplyAdamW<<>>( - (bfloat16_t *)param_h, (bfloat16_t *)grad, (float *)param, - (float *)momentum, (float *)velocity, lr, beta1, beta2, bias1, bias2, - epsilon, weight_decay, scale, use_nesterov, size)); - }; break; - } + // launch kernel + unionApplyAdamW<<>>( + (bfloat16_t *)param_h, (bfloat16_t *)grad, (float *)param, + (float *)momentum, (float *)velocity, lr, beta1, beta2, bias1, bias2, + epsilon, weight_decay, scale, use_nesterov, size); return MLUOP_STATUS_SUCCESS; } + +NAMESPACE_BANGC_KERNELS_GEGIN + +template +bangcKernelsStatus_t BANGC_KERNELS_WIN_API +mluApplyAdamW(const cnrtQueue_t queue, const float lr, const float beta1, + const float beta2, const float bias1, const float bias2, + const float epsilon, const float weight_decay, const float scale, + const bool use_nesterov, size_t size, T *param_h, T *grad, + void *param, void *momentum, void *velocity) { + // set job type + int ordinal = -1; + int cluster_num; + int core_dim; + cnrtGetDevice(&ordinal); + cnrtDeviceGetAttribute(&core_dim, cnrtAttrMcorePerCluster, ordinal); + cnrtDeviceGetAttribute(&cluster_num, cnrtAttrMaxClusterPerUnionLimitTask, + ordinal); + cnrtFunctionType_t k_type = cnrtFuncTypeUnion1; + cnrtDim3_t k_dim{.x = (uint32_t)core_dim, .y = (uint32_t)cluster_num, .z = 1}; + + // launch kernel + unionApplyAdamW<<>>( + BANG_UNWRAP_T(param_h), BANG_UNWRAP_T(grad), (float *)param, + (float *)momentum, (float *)velocity, lr, beta1, beta2, bias1, bias2, + epsilon, weight_decay, scale, use_nesterov, size); + return BANGC_KERNELS_STATUS_SUCCESS; +} + +#define IMPL_MLU_APPLY_ADAMW_KERNEL(DType) \ + template bangcKernelsStatus_t BANGC_KERNELS_WIN_API mluApplyAdamW( \ + const cnrtQueue_t, const float, const float, const float, const float, \ + const float, const float, const float, const float, const bool, \ + const size_t, DType *, DType *, void *, void *, void *) + +IMPL_MLU_APPLY_ADAMW_KERNEL(bang_bfloat16_t); + +NAMESPACE_BANGC_KERNELS_END diff --git a/mlu_op.h b/mlu_op.h index 0dae7b1a4..6aeb83633 100644 --- a/mlu_op.h +++ b/mlu_op.h @@ -20,8 +20,8 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#ifndef MLUOP_EXAMPLE_H_ -#define MLUOP_EXAMPLE_H_ +#ifndef MLUOP_H_ +#define MLUOP_H_ /****************************************************************************** * MLU-OPS: Cambricon Open Source operator library for Network @@ -14526,4 +14526,4 @@ mluOpLgamma(mluOpHandle_t handle, } #endif -#endif // MLUOP_EXAMPLE_H_ +#endif // MLUOP_H_ diff --git a/samples/mlu-ops/abs_sample/build.sh b/samples/mlu-ops/abs_sample/build.sh index 459b22600..d6f6fec67 100755 --- a/samples/mlu-ops/abs_sample/build.sh +++ b/samples/mlu-ops/abs_sample/build.sh @@ -9,7 +9,7 @@ if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then fi fi -if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" < "5" ]]; then +if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then echo "we do not support g++<5, try to activate devtoolset-7 env" source /opt/rh/devtoolset-7/enable && echo "devtoolset-7 activated" \ || ( echo "source devtoolset-7 failed, ignore this info if you have set env TOOLCHAIN_ROOT, TARGET_C_COMPILER, TARGET_CXX_COMPILER properly (see more details in README.md)" && sleep 4 ) # I hope user will see it diff --git a/samples/mlu-ops/build.sh b/samples/mlu-ops/build.sh index 8b0987113..0189ec7d8 100755 --- a/samples/mlu-ops/build.sh +++ b/samples/mlu-ops/build.sh @@ -9,7 +9,7 @@ if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then fi fi -if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" < "5" ]]; then +if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then echo "we do not support g++<5, try to activate devtoolset-8 env" source /opt/rh/devtoolset-8/enable && echo "devtoolset-8 activated" \ || ( echo "source devtoolset-8 failed, ignore this info if you have set env TOOLCHAIN_ROOT, TARGET_C_COMPILER, TARGET_CXX_COMPILER properly (see more details in README.md)" && sleep 4 ) # I hope user will see it diff --git a/samples/mlu-ops/fault_sample/build.sh b/samples/mlu-ops/fault_sample/build.sh index 787249233..95e635c8e 100755 --- a/samples/mlu-ops/fault_sample/build.sh +++ b/samples/mlu-ops/fault_sample/build.sh @@ -9,7 +9,7 @@ if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then fi fi -if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" < "5" ]]; then +if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then echo "we do not support g++<5, try to activate devtoolset-7 env" source /opt/rh/devtoolset-7/enable && echo "devtoolset-7 activated" \ || ( echo "source devtoolset-7 failed, ignore this info if you have set env TOOLCHAIN_ROOT, TARGET_C_COMPILER, TARGET_CXX_COMPILER properly (see more details in README.md)" && sleep 4 ) # I hope user will see it diff --git a/samples/mlu-ops/poly_nms_sample/build.sh b/samples/mlu-ops/poly_nms_sample/build.sh index 787249233..95e635c8e 100755 --- a/samples/mlu-ops/poly_nms_sample/build.sh +++ b/samples/mlu-ops/poly_nms_sample/build.sh @@ -9,7 +9,7 @@ if [ "$OS_RELEASE_ID" = "centos" -a "$OS_RELEASE_VERSION_ID" = "7" ]; then fi fi -if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" < "5" ]]; then +if [[ "$(g++ --version | head -n1 | awk '{ print $3 }' | cut -d '.' -f1)" -lt "5" ]]; then echo "we do not support g++<5, try to activate devtoolset-7 env" source /opt/rh/devtoolset-7/enable && echo "devtoolset-7 activated" \ || ( echo "source devtoolset-7 failed, ignore this info if you have set env TOOLCHAIN_ROOT, TARGET_C_COMPILER, TARGET_CXX_COMPILER properly (see more details in README.md)" && sleep 4 ) # I hope user will see it diff --git a/scripts/gen_symbol_visibility_map.py b/scripts/gen_symbol_visibility_map.py index 22eb24d98..8e26114f4 100644 --- a/scripts/gen_symbol_visibility_map.py +++ b/scripts/gen_symbol_visibility_map.py @@ -6,12 +6,18 @@ def get_mluops(input_file): ops_str="" pattern = re.compile(r'(?PmluOp\w+) *\(') + pattern_lite = re.compile(r'(?PmluApply\w+) *\(') with open(input_file,'r', encoding='utf8') as f: for line in f: match = pattern.search(line) + lite_match = pattern_lite.search(line) if match: op = match.groupdict()['api'] + ';' ops_str += op + + if lite_match: + op = lite_match.groupdict()['api'] + '*;' + ops_str += '*' + op return ops_str def create_map_file(map_file,ops_str): diff --git a/test/mlu_op_gtest/CMakeLists.txt b/test/mlu_op_gtest/CMakeLists.txt index 1ffb3c3d3..225f625be 100644 --- a/test/mlu_op_gtest/CMakeLists.txt +++ b/test/mlu_op_gtest/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.5) cmake_policy(SET CMP0048 NEW) # Use project(... VERSION ...) project(mlu_op_test) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) # check @@ -155,7 +155,7 @@ add_library(gtest_shared STATIC ${SRC_DIR}) # for runtime convenience # #target_link_libraries(gen_half2float_table cnrt) add_custom_command( OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/gen_half2float_table - COMMAND ${BANG_CNCC_EXECUTABLE} -mavx2 -mf16c -std=c++14 -I ${CMAKE_CURRENT_SOURCE_DIR}/include -I ${NEUWARE_HOME}/include ${CMAKE_CURRENT_SOURCE_DIR}/tools/gen_half2float_table.cpp -o ${CMAKE_CURRENT_BINARY_DIR}/gen_half2float_table -L ${NEUWARE_HOME}/lib64 -lcnrt -lm -lstdc++ -Wl,-rpath=${NEUWARE_HOME}/lib64 + COMMAND ${BANG_CNCC_EXECUTABLE} -mavx2 -mf16c -std=c++17 -I ${CMAKE_CURRENT_SOURCE_DIR}/include -I ${NEUWARE_HOME}/include ${CMAKE_CURRENT_SOURCE_DIR}/tools/gen_half2float_table.cpp -o ${CMAKE_CURRENT_BINARY_DIR}/gen_half2float_table -L ${NEUWARE_HOME}/lib64 -lcnrt -lm -lstdc++ -Wl,-rpath=${NEUWARE_HOME}/lib64 DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/tools/gen_half2float_table.cpp ${CMAKE_CURRENT_SOURCE_DIR}/include/math_half.h ) add_custom_command( diff --git a/test/mlu_op_gtest/pb_gtest/include/executor.h b/test/mlu_op_gtest/pb_gtest/include/executor.h index 9adb5416c..6bfa52deb 100644 --- a/test/mlu_op_gtest/pb_gtest/include/executor.h +++ b/test/mlu_op_gtest/pb_gtest/include/executor.h @@ -39,6 +39,7 @@ #include #include "gtest/gtest.h" #include "mlu_op.h" +#include "bangc_kernels.h" #include "core/tensor.h" #include "core/tool.h" #include "core/type.h" @@ -138,6 +139,7 @@ struct ExecuteConfig { bool random_mlu_address = false; bool enable_const_dram = false; bool auto_tuning = false; + bool enable_lite_interface = getEnv("MLUOP_GTEST_INTERFACE_MODE", 0) == 1; // #if GTEST_ENABLE_GPERFTOOLS // // TODO(None) move into global_var // bool gtest_internal_cpu_profile = diff --git a/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.cpp b/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.cpp index b0f6f229f..7dba934dc 100644 --- a/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.cpp +++ b/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.cpp @@ -24,6 +24,9 @@ #include "adam_w.h" #include "cn_api.h" +#include "bangc_helper_dtype.h" +#include "bangc_kernels.h" + namespace mluoptest { void AdamWExecutor::paramCheck() { @@ -38,7 +41,7 @@ void AdamWExecutor::paramCheck() { } void AdamWExecutor::compute() { - VLOG(4) << "AdamWExecutor compute "; + VLOG(4) << "AdamWExecutor compute(). "; auto desc_param = tensor_desc_[0].tensor; auto desc_paramh = tensor_desc_[1].tensor; auto desc_momentum = tensor_desc_[2].tensor; @@ -62,25 +65,37 @@ void AdamWExecutor::compute() { const float fp32_scale = parser_->getProtoNode()->adamw_param().scale(); bool use_nesterov = parser_->getProtoNode()->adamw_param().use_nesterov(); - mluOpAdamWDescriptor_t adamw_desc; - MLUOP_CHECK(mluOpCreateAdamWDescriptor(&adamw_desc)); - MLUOP_CHECK(mluOpSetAdamWDescAttr(adamw_desc, MLUOP_ADAMW_WEIGHT_DECAY, - &fp32_weight_decay, - sizeof(fp32_weight_decay))); - MLUOP_CHECK(mluOpSetAdamWDescAttr(adamw_desc, MLUOP_ADAMW_GRAD_SCALE, - &fp32_scale, sizeof(fp32_scale))); - MLUOP_CHECK(mluOpSetAdamWDescAttr(adamw_desc, MLUOP_ADAMW_USE_NESTEROV, - &use_nesterov, sizeof(use_nesterov))); - - VLOG(4) << "call mluOpAdamw()"; - interface_timer_.start(); - MLUOP_CHECK(mluOpAdamW(handle_, adamw_desc, desc_param, dev_param, - desc_paramh, dev_paramh, desc_momentum, dev_momentum, - desc_velocity, dev_velocity, desc_grad, dev_grad, - fp32_lr, fp32_beta1, fp32_beta2, fp32_bias1, - fp32_bias2, fp32_epsilon)); - interface_timer_.stop(); - MLUOP_CHECK(mluOpDestroyAdamWDescriptor(adamw_desc)); + if (!exe_config_->enable_lite_interface) { + VLOG(4) << "call mluOpAdamw(). "; + mluOpAdamWDescriptor_t adamw_desc; + MLUOP_CHECK(mluOpCreateAdamWDescriptor(&adamw_desc)); + MLUOP_CHECK(mluOpSetAdamWDescAttr(adamw_desc, MLUOP_ADAMW_WEIGHT_DECAY, + &fp32_weight_decay, + sizeof(fp32_weight_decay))); + MLUOP_CHECK(mluOpSetAdamWDescAttr(adamw_desc, MLUOP_ADAMW_GRAD_SCALE, + &fp32_scale, sizeof(fp32_scale))); + MLUOP_CHECK(mluOpSetAdamWDescAttr(adamw_desc, MLUOP_ADAMW_USE_NESTEROV, + &use_nesterov, sizeof(use_nesterov))); + interface_timer_.start(); + MLUOP_CHECK(mluOpAdamW(handle_, adamw_desc, desc_param, dev_param, + desc_paramh, dev_paramh, desc_momentum, dev_momentum, + desc_velocity, dev_velocity, desc_grad, dev_grad, + fp32_lr, fp32_beta1, fp32_beta2, fp32_bias1, + fp32_bias2, fp32_epsilon)); + interface_timer_.stop(); + MLUOP_CHECK(mluOpDestroyAdamWDescriptor(adamw_desc)); + } else { + VLOG(4) << "call mluApplyAdamW(). "; + const int size = mluOpGetTensorElementNum(desc_momentum) * sizeof(float); + interface_timer_.start(); + const auto adamw_status = bangc_kernels::mluApplyAdamW( + handle_->queue, fp32_lr, fp32_beta1, fp32_beta2, fp32_bias1, fp32_bias2, + fp32_epsilon, fp32_weight_decay, fp32_scale, use_nesterov, size, + BANG_WRAP_T((Eigen::bfloat16 *)dev_paramh), + BANG_WRAP_T((Eigen::bfloat16 *)dev_grad), dev_param, dev_momentum, + dev_velocity); + interface_timer_.stop(); + } } void AdamWExecutor::setMiscellaneousParam() { diff --git a/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.h b/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.h index 79cb82bf9..66388bbfb 100644 --- a/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.h +++ b/test/mlu_op_gtest/pb_gtest/src/zoo/adam_w/adam_w.h @@ -22,6 +22,7 @@ *************************************************************************/ #ifndef TEST_MLU_OP_GTEST_SRC_ZOO_ADAMW_ADAMW_H_ #define TEST_MLU_OP_GTEST_SRC_ZOO_ADAMW_ADAMW_H_ + #include "executor.h" namespace mluoptest {