Skip to content

Commit

Permalink
[Fix](mluOpCarafeForward,mluOpCarafeBackward): fix gen_case bug and a… (
Browse files Browse the repository at this point in the history
#1117)

Co-authored-by: chenyongjie <[email protected]>
  • Loading branch information
chen4231 and chenyongjie authored Oct 22, 2024
1 parent 0f2bdbf commit 22df809
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions kernels/carafe/carafe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,19 +783,24 @@ mluOpStatus_t MLUOP_WIN_API mluOpCarafeForward(
&grid_dimG, &grid_dimC, &job_num),
"Error occured in generating policy.");

{
LARGE_TENSOR_CHECK("[mluOpCarafeForward]", input_desc);
LARGE_TENSOR_CHECK("[mluOpCarafeForward]", mask_desc);
LARGE_TENSOR_CHECK("[mluOpCarafeForward]", output_desc);
}
// GEN_CASE
if (MLUOP_GEN_CASE_ON_NEW) {
GEN_CASE_START("carafe_forward", "CARAFE_FORWARD");
GEN_CASE_HANDLE(handle);
GEN_CASE_DATA(true, "input", input, input_desc, 5.1, -5.3);
GEN_CASE_DATA(true, "mask", mask, mask_desc, 0.0, 1.0);
GEN_CASE_DATA(false, "output", output, output_desc, 1.7, -1.8);
GEN_CASE_OP_PARAM_SINGLE(0, "carafe_forward", "dimnb", carafe_desc->dimNb);
GEN_CASE_OP_PARAM_SINGLE(1, "carafe_forward", "kernel_size",
GEN_CASE_OP_PARAM_SINGLE(0, "carafe", "dimnb", carafe_desc->dimNb);
GEN_CASE_OP_PARAM_SINGLE(1, "carafe", "kernel_size",
carafe_desc->kernel_size);
GEN_CASE_OP_PARAM_SINGLE(1, "carafe_forward", "group_size",
GEN_CASE_OP_PARAM_SINGLE(1, "carafe", "group_size",
carafe_desc->group_size);
GEN_CASE_OP_PARAM_SINGLE(2, "carafe_forward", "scale_factor",
GEN_CASE_OP_PARAM_SINGLE(2, "carafe", "scale_factor",
carafe_desc->scale_factor);
GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0);
}
Expand Down Expand Up @@ -840,6 +845,14 @@ mluOpStatus_t MLUOP_WIN_API mluOpCarafeBackward(
return param_check_status;
}

{
LARGE_TENSOR_CHECK("[mluOpCarafeBackward]", input_desc);
LARGE_TENSOR_CHECK("[mluOpCarafeBackward]", mask_desc);
LARGE_TENSOR_CHECK("[mluOpCarafeBackward]", grad_output_desc);
LARGE_TENSOR_CHECK("[mluOpCarafeBackward]", grad_input_desc);
LARGE_TENSOR_CHECK("[mluOpCarafeBackward]", grad_mask_desc);
}

if (MLUOP_GEN_CASE_ON_NEW) {
GEN_CASE_START("carafe_backward", "CARAFE_BACKWARD");
GEN_CASE_HANDLE(handle);
Expand All @@ -849,12 +862,12 @@ mluOpStatus_t MLUOP_WIN_API mluOpCarafeBackward(
-1.8);
GEN_CASE_DATA(false, "grad_input", grad_input, grad_input_desc, 0, 0);
GEN_CASE_DATA(false, "grad_mask", grad_mask, grad_mask_desc, 0, 0);
GEN_CASE_OP_PARAM_SINGLE(0, "carafe_backward", "dimnb", carafe_desc->dimNb);
GEN_CASE_OP_PARAM_SINGLE(1, "carafe_backward", "kernel_size",
GEN_CASE_OP_PARAM_SINGLE(0, "carafe", "dimnb", carafe_desc->dimNb);
GEN_CASE_OP_PARAM_SINGLE(1, "carafe", "kernel_size",
carafe_desc->kernel_size);
GEN_CASE_OP_PARAM_SINGLE(1, "carafe_backward", "group_size",
GEN_CASE_OP_PARAM_SINGLE(1, "carafe", "group_size",
carafe_desc->group_size);
GEN_CASE_OP_PARAM_SINGLE(2, "carafe_backward", "scale_factor",
GEN_CASE_OP_PARAM_SINGLE(2, "carafe", "scale_factor",
carafe_desc->scale_factor);
GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0);
}
Expand Down

0 comments on commit 22df809

Please sign in to comment.