[C] Normalization Refactor + Adding CUDNN backend #1315
base: main
Conversation
    dgamma_part.dtype());
dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), dbeta_part.shape(),
    dbeta_part.dtype());
if (!std::getenv("NVTE_BWD_LAYERNORM_USE_CUDNN")) {
We would have unexpected behavior if the user sets NVTE_BWD_LAYERNORM_USE_CUDNN=0:

if (!std::getenv("NVTE_BWD_LAYERNORM_USE_CUDNN") ||
    !std::atoi(std::getenv("NVTE_BWD_LAYERNORM_USE_CUDNN"))) {
Alternatively, we could use TE's getenv function:

T getenv(const char *variable);

However, this is delicate since it can run into issues if the core lib and framework lib are compiled with different C++ ABIs. A solution might be to make getenv a header-only impl.
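As a rough illustration, a header-only variant might look like the following. This is only a sketch: it assumes the TE-style interface T getenv(const char *variable), and getenv_header_only is a hypothetical name, not the actual TE implementation.

#include <cstdlib>
#include <sstream>
#include <string>

// Hypothetical header-only helper: defined entirely inline in a header,
// so the core lib and framework lib each instantiate it with their own
// C++ ABI and never link against a shared symbol.
template <typename T>
inline T getenv_header_only(const char *variable, const T &default_value = T()) {
  const char *env = std::getenv(variable);
  if (env == nullptr) return default_value;  // unset -> default
  T value{};
  std::istringstream stream{std::string(env)};
  stream >> value;
  return stream.fail() ? default_value : value;
}

// Specialization so that NVTE_..._USE_CUDNN=0 (or "false") disables the flag.
template <>
inline bool getenv_header_only<bool>(const char *variable, const bool &default_value) {
  const char *env = std::getenv(variable);
  if (env == nullptr) return default_value;
  const std::string s(env);
  return !(s.empty() || s == "0" || s == "false" || s == "FALSE");
}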
Hi, I had not yet introduced changes in the framework part. With the new change on the PyTorch side in the latest commit, we no longer need this env check.
enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward };
Nit: We use the NVTE_ prefix for C-style enums to avoid ambiguity, but it's not necessary for these C++-style enums since they're defined within the transformer_engine::normalization namespace:

enum class NormBackend { Te, Cudnn };
enum class NormType { LayerNorm, RMSNorm };
enum class NormStage { Forward, Backward };
Not relevant yet, but in the future we might want to have separate stages for ForwardTrain and ForwardInfer if cuDNN exposes an optimized inference impl.
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (std::getenv("NVTE_FWD_LAYERNORM_USE_CUDNN")) {
We should use TE's getenv function so that users can specify NVTE_FWD_LAYERNORM_USE_CUDNN=0:

if (getenv<bool>("NVTE_FWD_LAYERNORM_USE_CUDNN")) {

See:

T getenv(const char *variable);
Similar changes should be made in the backward function, as well as RMSNorm. We might also want to consider caching the value to avoid CPU overheads:
static const bool use_cudnn_backend = getenv<bool>("NVTE_FWD_LAYERNORM_USE_CUDNN");
if (use_cudnn_backend) {
Related: #1315 (comment)
if (workspace->data.shape.empty()) {
  CheckInputTensor(x, "x");
  CheckInputTensor(gamma, "gamma");
  CheckInputTensor(beta, "beta");

  CheckOutputTensor(*z, "z");
  CheckOutputTensor(*mu, "mu");
  CheckOutputTensor(*rsigma, "rsigma");
}
It's not immediately obvious why we only check tensor dimensions when the workspace is empty. If it's cheap, it'll be easier to check tensors every time. Otherwise, it would be helpful to add a comment that we assume that this function is called twice (to query workspace and to launch kernel) and that we only check tensors the first time.
Similar changes should be made in the backward function, as well as RMSNorm.
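For example, a comment along these lines would make the contract explicit. This is a sketch: the two-call convention is our reading of the code, stated here as an assumption.

// Assumed contract: this function is called twice -- first with an empty
// workspace to query the required workspace size, then again to launch
// the kernel. Tensors are validated only on the first (query) call so
// that the launch path stays cheap.
if (workspace->data.shape.empty()) {
  CheckInputTensor(x, "x");
  CheckInputTensor(gamma, "gamma");
  CheckInputTensor(beta, "beta");

  CheckOutputTensor(*z, "z");
  CheckOutputTensor(*mu, "mu");
  CheckOutputTensor(*rsigma, "rsigma");
}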
 * where
 * @f[
 * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
 * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma) + \beta
Mismatched parenthesis:

 * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta
We could also make the LaTeX look nicer with:
* y = \frac{x - \mathbb{E}[x]}{\sqrt{\text{Var}[x] + \varepsilon}} \gamma + \beta
Any changes here should also be made in the backward function and in RMSNorm.
 * \param[in] multiprocessorCount Number of SMs in the device.
 * \param[in] zero_centered_gamma If zero_centered_gamma is enabled
We should provide more details since we don't include the formula in the main documentation.

 * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
Similar changes should be made to the backward function as well as RMSNorm.
double atol_bwd = 1e-3;
double rtol_bwd = 1e-3;
Do we need to relax tolerances?

double atol_bwd = 1e-4;
double rtol_bwd = 1e-4;
Ideally we would use tight tolerances like (TransformerEngine/tests/pytorch/utils.py, lines 73 to 84 in c0a539c):

if dtype == torch.float16:
    return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
    return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
    return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
    return dict(rtol=1e-7, atol=1e-7)
if dtype == torch.float8_e4m3fn:
    return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
if dtype == torch.float8_e5m2:
    return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
LaunchParams<KernelParamsType> _launch_params;
std::function<void(LaunchParams<KernelParamsType>&, const bool)> _kernel;

bool _is_layernorm;
Nit: Caching this value makes the code slightly more concise, but I think it's clearer and more general to just store the original enum:

NVTE_Norm_Type _norm_type;
virtual void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr,
                     void* eps_dptr, void* rsigma_dptr, void* workspace_dptr,
                     cudaStream_t stream) = 0;

virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr,
                     void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr,
                     void* workspace_dptr, cudaStream_t stream) = 0;
It's a bit awkward that NormalizationPlanBase includes logic for both the forward and backward kernels. I feel the "right" solution is to implement separate abstract base classes for forward and backward (see the sketch below). That said, this approach does allow some code reuse with the cuDNN backend, so I think it's fine for now. We can revisit in the future if needed, e.g. if we add a separate forward inference stage.
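For reference, a minimal sketch of that split, assuming hypothetical class names (ForwardNormPlan, BackwardNormPlan) and the void*-based signatures quoted above; this is not the PR's actual API:

#include <cuda_runtime.h>

namespace transformer_engine::normalization {

// Hypothetical forward-only plan: consumes x/gamma/beta and produces z
// plus the saved statistics (mean, rsigma) needed by the backward pass.
class ForwardNormPlan {
 public:
  virtual ~ForwardNormPlan() = default;
  virtual void execute(void* x_dptr, void* gamma_dptr, void* beta_dptr,
                       void* z_dptr, void* mean_dptr, void* rsigma_dptr,
                       void* workspace_dptr, cudaStream_t stream) = 0;
};

// Hypothetical backward-only plan: consumes dz plus the saved statistics
// and produces dx/dgamma/dbeta.
class BackwardNormPlan {
 public:
  virtual ~BackwardNormPlan() = default;
  virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
                       void* rsigma_dptr, void* dz_dptr, void* dx_dptr,
                       void* dgamma_dptr, void* dbeta_dptr,
                       void* workspace_dptr, cudaStream_t stream) = 0;
};

}  // namespace transformer_engine::normalization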
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (std::getenv("NVTE_FWD_LAYERNORM_USE_CUDNN")) {
One problem with these envvars is that they're all-or-nothing. This is especially troublesome for testing, where we may want to test TE and cuDNN norm kernels within the same process. Perhaps we could have some function like nvte_enable_cudnn_layernorm that sets a global variable, with the initial value based on an envvar (see the sketch below).
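A minimal sketch of that idea, assuming nvte_enable_cudnn_layernorm is a new, hypothetical API; the envvar parsing here is also illustrative:

#include <atomic>
#include <cstdlib>

namespace {

// Global toggle, initialized once from the envvar; unset and "0" both
// disable the cuDNN backend in this sketch.
std::atomic<bool> use_cudnn_layernorm{
    std::getenv("NVTE_FWD_LAYERNORM_USE_CUDNN") != nullptr &&
    std::atoi(std::getenv("NVTE_FWD_LAYERNORM_USE_CUDNN")) != 0};

}  // namespace

// Hypothetical public API: lets tests flip backends within one process.
extern "C" void nvte_enable_cudnn_layernorm(bool enable) {
  use_cudnn_layernorm.store(enable);
}

// Dispatch sites would read the toggle instead of calling std::getenv.
bool cudnn_layernorm_enabled() { return use_cudnn_layernorm.load(); }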
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    NormTestSuite,
    ::testing::Combine(
        ::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
        ::testing::ValuesIn(test_cases),
        ::testing::Values(false, true)),
It would be helpful to test both cuDNN and TE kernels. We can reduce the number of test shapes to keep the total number of tests manageable.
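One possible shape for this, sketched with a hypothetical Backend test parameter; the fixture and enum names below are illustrative, not the PR's actual test code:

#include <gtest/gtest.h>
#include <tuple>

enum class Backend { Te, Cudnn };            // hypothetical test parameter
enum class NormType { LayerNorm, RMSNorm };  // mirrors the suite above

class NormTestSuite
    : public ::testing::TestWithParam<std::tuple<Backend, NormType>> {};

TEST_P(NormTestSuite, ForwardBackward) {
  // A real body would select the backend before launching the kernels,
  // e.g. via a runtime toggle like the one sketched earlier.
}

// Both backends run over a reduced shape list to keep the count manageable.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, NormTestSuite,
    ::testing::Combine(::testing::Values(Backend::Te, Backend::Cudnn),
                       ::testing::Values(NormType::LayerNorm,
                                         NormType::RMSNorm)));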
Description
TODO:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: