Skip to content

Commit

Permalink
Disable XDL kernels on unsupported HW Add ck::is_xdl_supported (#768)
Browse files Browse the repository at this point in the history
* Disable XDL kernels on unsupported HW; Add ck::is_xdl_supported function (#765)

* Do not throw an error when GEMM problem is not supported.

---------

Co-authored-by: Bartlomiej Wroblewski <[email protected]>
Co-authored-by: Adam Osewski <[email protected]>
Co-authored-by: Illia Silin <[email protected]>
  • Loading branch information
4 people authored Jul 26, 2023
1 parent 016bd42 commit ac6d68b
Show file tree
Hide file tree
Showing 37 changed files with 121 additions and 51 deletions.
6 changes: 3 additions & 3 deletions example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ int main(int argc, char* argv[])

if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

return 0;
}

float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
Expand Down
7 changes: 7 additions & 0 deletions include/ck/host_utility/device_prop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,11 @@ inline std::string get_device_name()
return name;
}

inline bool is_xdl_supported()
{
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
}

} // namespace ck
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
ck::Tuple<>{},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,9 +580,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -809,9 +809,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.Print();
#endif

if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,

static bool IsSupportedArgument(const Problem& problem)
{
if(!ck::is_xdl_supported())
{
return false;
}

if(problem.K % K1 != 0)
{
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(arg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

// vector load A/B matrix from global memory
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,11 @@ struct

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,11 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,9 +855,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,11 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,

static bool IsSupportedArgument(const Argument& karg)
{
if(!ck::is_xdl_supported())
{
return false;
}

return GridwiseGemm::CheckValidity(karg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}

const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle

static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
if(!ck::is_xdl_supported())
{
return false;
}
Expand Down
Loading

0 comments on commit ac6d68b

Please sign in to comment.