Skip to content

Commit

Permalink
[Feature](mluOpSyncBatchNormBackwardReduce): Add new API and deprecate old ones (#937)
Browse files Browse the repository at this point in the history
  • Loading branch information
duzekunKTH authored Jan 24, 2024
1 parent 8994790 commit ac22f7f
Show file tree
Hide file tree
Showing 3 changed files with 509 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
*************************************************************************/
#include "kernels/utils/cnnl_helper.h"

mluOpStatus_t MLUOP_WIN_API mluOpGetSyncBatchnormBackwardReduceWorkspaceSize(
mluOpStatus_t MLUOP_WIN_API mluOpGetSyncBatchNormBackwardReduceWorkspaceSize(
mluOpHandle_t handle, const mluOpTensorDescriptor_t desc_x,
size_t *workspace_size) {
PARAM_CHECK("mluOpSyncBatchnormBackwardReduce_v2", handle != NULL);
PARAM_CHECK("mluOpSyncBatchnormBackwardReduce_v2", desc_x != NULL);
PARAM_CHECK("mluOpGetSyncBatchNormBackwardReduceWorkspaceSize",
handle != NULL);
PARAM_CHECK("mluOpGetSyncBatchNormBackwardReduceWorkspaceSize",
desc_x != NULL);

DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(desc_x, cnnl_desc_x);
Expand All @@ -35,16 +37,27 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetSyncBatchnormBackwardReduceWorkspaceSize(
cnnlGetSyncBatchnormBackwardReduceWorkspaceSize(cnnl_handle, cnnl_desc_x,
workspace_size),
CNNL_STATUS_SUCCESS,
"[mluOpSyncBatchnormBackwardReduce_v2] Internal error"
" accured in mluOpGetSyncBatchnormBackwardReduceWorkspaceSize.",
"[mluOpGetSyncBatchNormBackwardReduceWorkspaceSize] Internal error"
" accured in cnnlGetSyncBatchnormBackwardReduceWorkspaceSize.",
MLUOP_STATUS_INTERNAL_ERROR);

DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_desc_x);
DESTROY_CNNL_HANDLE(cnnl_handle);
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce(
/* Deprecated workspace-size query kept for backward compatibility.
 *
 * Logs a one-time deprecation warning (LOG_FIRST_N with count 1) and then
 * forwards all arguments unchanged to the renamed replacement,
 * mluOpGetSyncBatchNormBackwardReduceWorkspaceSize ("Norm" capitalized).
 * Parameter validation is performed by the forwarded-to function, not here.
 *
 * handle:         MLU-OPS context handle (forwarded as-is).
 * desc_x:         descriptor of the input tensor x (forwarded as-is).
 * workspace_size: out-parameter receiving the required workspace bytes.
 * Returns whatever status the replacement function returns.
 */
mluOpStatus_t MLUOP_WIN_API mluOpGetSyncBatchnormBackwardReduceWorkspaceSize(
    mluOpHandle_t handle, const mluOpTensorDescriptor_t desc_x,
    size_t *workspace_size) {
  // Emitted at most once per process so callers are not spammed.
  LOG_FIRST_N(WARNING, 1)
      << "[mluOpGetSyncBatchnormBackwardReduceWorkspaceSize] is deprecated and"
      << " will be removed in the future release, please use "
      << "[mluOpGetSyncBatchNormBackwardReduceWorkspaceSize] instead.";
  return mluOpGetSyncBatchNormBackwardReduceWorkspaceSize(
      handle, desc_x, workspace_size);
}

mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchNormBackwardReduce(
mluOpHandle_t handle, const mluOpTensorDescriptor_t desc_dz, const void *dz,
const mluOpTensorDescriptor_t desc_x, const void *x,
const mluOpTensorDescriptor_t desc_mean, const void *mean,
Expand All @@ -55,15 +68,15 @@ mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce(
const mluOpTensorDescriptor_t desc_sum_dy_xmu, void *sum_dy_xmu,
const bool needs_input_grad0, const bool needs_input_grad1,
const bool needs_input_grad2) {
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", handle != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_dz != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_x != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_mean != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_invstd != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", dz != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", x != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", mean != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", invstd != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", handle != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", desc_dz != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", desc_x != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", desc_mean != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", desc_invstd != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", dz != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", x != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", mean != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce]", invstd != NULL);

DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(desc_dz, cnnl_desc_dz);
Expand All @@ -83,8 +96,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce(
dbias, cnnl_desc_sum_dy, sum_dy, cnnl_desc_sum_dy_xmu, sum_dy_xmu,
needs_input_grad0, needs_input_grad1, needs_input_grad2),
CNNL_STATUS_SUCCESS,
"[mluOpSyncBatchnormBackwardReduce] Internal error"
" accured in mluOpSyncBatchnormBackwardReduce.",
"[mluOpSyncBatchNormBackwardReduce] Internal error"
" accured in cnnlSyncBatchnormBackwardReduce.",
MLUOP_STATUS_INTERNAL_ERROR);

DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_desc_dz);
Expand All @@ -99,7 +112,30 @@ mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce(
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce_v2(
/* Deprecated entry point kept for backward compatibility.
 *
 * Logs a one-time deprecation warning and forwards every argument, in the
 * same order, to the renamed replacement mluOpSyncBatchNormBackwardReduce
 * ("Norm" capitalized). No validation or other work happens in this shim;
 * parameter checks live in the forwarded-to function.
 *
 * The needs_input_grad0/1/2 flags select which of the output buffers the
 * replacement computes; their exact semantics are defined by the callee
 * (not visible here — see the mluOpSyncBatchNormBackwardReduce docs).
 * Returns whatever status the replacement function returns.
 */
mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce(
    mluOpHandle_t handle, const mluOpTensorDescriptor_t desc_dz, const void *dz,
    const mluOpTensorDescriptor_t desc_x, const void *x,
    const mluOpTensorDescriptor_t desc_mean, const void *mean,
    const mluOpTensorDescriptor_t desc_invstd, const void *invstd,
    const mluOpTensorDescriptor_t desc_dfilter, void *dfilter,
    const mluOpTensorDescriptor_t desc_dbias, void *dbias,
    const mluOpTensorDescriptor_t desc_sum_dy, void *sum_dy,
    const mluOpTensorDescriptor_t desc_sum_dy_xmu, void *sum_dy_xmu,
    const bool needs_input_grad0, const bool needs_input_grad1,
    const bool needs_input_grad2) {
  // Emitted at most once per process so callers are not spammed.
  LOG_FIRST_N(WARNING, 1)
      << "[mluOpSyncBatchnormBackwardReduce] is deprecated and"
      << " will be removed in the future release, please use "
      << "[mluOpSyncBatchNormBackwardReduce] instead.";
  return mluOpSyncBatchNormBackwardReduce(
      handle, desc_dz, dz, desc_x, x, desc_mean, mean,
      desc_invstd, invstd, desc_dfilter, dfilter,
      desc_dbias, dbias, desc_sum_dy, sum_dy,
      desc_sum_dy_xmu, sum_dy_xmu,
      needs_input_grad0, needs_input_grad1, needs_input_grad2);
}

mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchNormBackwardReduce_v2(
mluOpHandle_t handle, const mluOpTensorDescriptor_t desc_dz, const void *dz,
const mluOpTensorDescriptor_t desc_x, const void *x,
const mluOpTensorDescriptor_t desc_mean, const void *mean,
Expand All @@ -111,17 +147,17 @@ mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce_v2(
const mluOpTensorDescriptor_t desc_sum_dy_xmu, void *sum_dy_xmu,
const bool needs_input_grad0, const bool needs_input_grad1,
const bool needs_input_grad2) {
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", handle != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_dz != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_x != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_mean != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", desc_invstd != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", dz != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", x != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", mean != NULL);
PARAM_CHECK("[mluOpSyncBatchnormBackwardReduce]", invstd != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", handle != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", desc_dz != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", desc_x != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", desc_mean != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", desc_invstd != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", dz != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", x != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", mean != NULL);
PARAM_CHECK("[mluOpSyncBatchNormBackwardReduce_v2]", invstd != NULL);
if (workspace_size > 0) {
PARAM_CHECK("mluOpSyncBatchnormBackwardReduce_v2", workspace != NULL);
PARAM_CHECK("mluOpSyncBatchNormBackwardReduce_v2", workspace != NULL);
}

DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
Expand All @@ -143,8 +179,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce_v2(
sum_dy, cnnl_desc_sum_dy_xmu, sum_dy_xmu, needs_input_grad0,
needs_input_grad1, needs_input_grad2),
CNNL_STATUS_SUCCESS,
"[mluOpSyncBatchnormBackwardReduce] Internal error"
" accured in mluOpSyncBatchnormBackwardReduce_v2.",
"[mluOpSyncBatchNormBackwardReduce_v2] Internal error"
" accured in cnnlSyncBatchnormBackwardReduce_v2.",
MLUOP_STATUS_INTERNAL_ERROR);

DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_desc_dz);
Expand All @@ -158,3 +194,27 @@ mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce_v2(
DESTROY_CNNL_HANDLE(cnnl_handle);
return MLUOP_STATUS_SUCCESS;
}

/* Deprecated v2 entry point (workspace-taking variant) kept for backward
 * compatibility.
 *
 * Logs a one-time deprecation warning and forwards every argument, in the
 * same order, to the renamed replacement mluOpSyncBatchNormBackwardReduce_v2
 * ("Norm" capitalized). Unlike the non-_v2 wrapper above this one also
 * forwards the caller-provided workspace buffer and its size. No validation
 * happens in this shim; parameter checks live in the forwarded-to function.
 *
 * workspace / workspace_size: scratch buffer sized via
 *   mluOpGetSyncBatchNormBackwardReduceWorkspaceSize, passed through as-is.
 * Returns whatever status the replacement function returns.
 */
mluOpStatus_t MLUOP_WIN_API mluOpSyncBatchnormBackwardReduce_v2(
    mluOpHandle_t handle, const mluOpTensorDescriptor_t desc_dz, const void *dz,
    const mluOpTensorDescriptor_t desc_x, const void *x,
    const mluOpTensorDescriptor_t desc_mean, const void *mean,
    const mluOpTensorDescriptor_t desc_invstd, const void *invstd,
    void *workspace, size_t workspace_size,
    const mluOpTensorDescriptor_t desc_dfilter, void *dfilter,
    const mluOpTensorDescriptor_t desc_dbias, void *dbias,
    const mluOpTensorDescriptor_t desc_sum_dy, void *sum_dy,
    const mluOpTensorDescriptor_t desc_sum_dy_xmu, void *sum_dy_xmu,
    const bool needs_input_grad0, const bool needs_input_grad1,
    const bool needs_input_grad2) {
  // Emitted at most once per process so callers are not spammed.
  LOG_FIRST_N(WARNING, 1)
      << "[mluOpSyncBatchnormBackwardReduce_v2] is deprecated and"
      << " will be removed in the future release, please use "
      << "[mluOpSyncBatchNormBackwardReduce_v2] instead.";
  return mluOpSyncBatchNormBackwardReduce_v2(
      handle, desc_dz, dz, desc_x, x, desc_mean, mean,
      desc_invstd, invstd, workspace, workspace_size,
      desc_dfilter, dfilter, desc_dbias, dbias,
      desc_sum_dy, sum_dy, desc_sum_dy_xmu, sum_dy_xmu,
      needs_input_grad0, needs_input_grad1, needs_input_grad2);
}
Loading

0 comments on commit ac22f7f

Please sign in to comment.