[C] Normalization Refactor + Adding CUDNN backend #1315

Draft: phu0ngng wants to merge 59 commits into main

Conversation

@phu0ngng (Collaborator) commented Nov 5, 2024

Description

TODO:

  • Adapt normalization in JAX and Paddle.
  • Benchmark performance of new APIs.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Phuong Nguyen <[email protected]>
dgamma_part.dtype());
dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), dbeta_part.shape(),
dbeta_part.dtype());
if (!std::getenv("NVTE_BWD_LAYERNORM_USE_CUDNN")) {
Collaborator:

We would have unexpected behavior if the user sets NVTE_BWD_LAYERNORM_USE_CUDNN=0:

Suggested change
if (!std::getenv("NVTE_BWD_LAYERNORM_USE_CUDNN")) {
if (!std::getenv("NVTE_BWD_LAYERNORM_USE_CUDNN")
|| !std::atoi(std::getenv("NVTE_BWD_LAYERNORM_USE_CUDNN"))) {

Alternatively, we could use TE's getenv function:

T getenv(const char *variable);

However, this is delicate since it can run into issues if the core lib and framework lib are compiled with different C++ ABIs. A solution might be to make getenv a header-only impl.
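
For reference, a minimal sketch of what a header-only variant could look like, assuming the same semantics as TE's existing getenv<T> (an unset variable falls back to a default value). The names and details here are illustrative, not the actual TE implementation:

  // Illustrative header-only getenv<T> sketch; not the actual TE implementation.
  #include <cstdlib>
  #include <sstream>
  #include <string>

  namespace transformer_engine {

  template <typename T>
  inline T getenv(const char *variable, const T default_value = T()) {
    const char *env = std::getenv(variable);
    if (env == nullptr || env[0] == '\0') return default_value;
    T value{};
    std::istringstream ss{std::string(env)};
    ss >> value;  // e.g. "0" -> false and "1" -> true when T = bool
    return ss.fail() ? default_value : value;
  }

  }  // namespace transformer_engine

Being header-only, this would be instantiated in whichever translation unit uses it, sidestepping the cross-ABI concern between the core and framework libs.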

Collaborator (Author):

Hi, I had not yet introduced changes in the framework part.
With the new change on the PyTorch side in the latest commit, we no longer need this env check.

@timmoon10 self-requested a review November 8, 2024 00:47
phu0ngng and others added 8 commits November 8, 2024 17:32
Signed-off-by: Phuong Nguyen <[email protected]>
Comment on lines +138 to +140
enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward };
Collaborator:

Nit: We use the NVTE_ prefix for C-style enums to avoid ambiguity, but it's not necessary for these C++-style enums since they're defined within the transformer_engine::normalization namespace:

Suggested change
enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward };
enum class NormBackend { Te, Cudnn };
enum class NormType { LayerNorm, RMSNorm };
enum class NormStage { Forward, Backward };

Not relevant yet, but in the future we might want to have separate stages for ForwardTrain and ForwardInfer if cuDNN exposes an optimized inference impl.
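
For reference, that future split could look something like this (purely illustrative):

  enum class NormStage { ForwardTrain, ForwardInfer, Backward };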


NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (std::getenv("NVTE_FWD_LAYERNORM_USE_CUDNN")) {
Collaborator:

We should use TE's getenv function so that users can specify NVTE_FWD_LAYERNORM_USE_CUDNN=0:

Suggested change
if (std::getenv("NVTE_FWD_LAYERNORM_USE_CUDNN")) {
if (getenv<bool>("NVTE_FWD_LAYERNORM_USE_CUDNN")) {

See:

T getenv(const char *variable);

Similar changes should be made in the backward function, as well as RMSNorm. We might also want to consider caching the value to avoid CPU overheads:

  static const bool use_cudnn_backend = getenv<bool>("NVTE_FWD_LAYERNORM_USE_CUDNN");
  if (use_cudnn_backend) {

Related: #1315 (comment)

Comment on lines +42 to +50
if (workspace->data.shape.empty()) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");

CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
}
Collaborator:

It's not immediately obvious why we only check tensor dimensions when the workspace is empty. If the checks are cheap, it would be simpler to run them every time. Otherwise, it would be helpful to add a comment noting that this function is expected to be called twice (once to query the workspace size and once to launch the kernel) and that we only check tensors on the first call.

Similar changes should be made in the backward function, as well as RMSNorm.
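
For instance, a comment along these lines (wording is only a suggestion) would make the two-call convention explicit:

  // NOTE: this function is expected to be called twice: first with an empty
  // workspace to query the required workspace size, then again with the
  // allocated workspace to launch the kernel. Tensors are only validated on
  // the first (query) call to avoid redundant checks.
  if (workspace->data.shape.empty()) {
    CheckInputTensor(x, "x");
    CheckInputTensor(gamma, "gamma");
    CheckInputTensor(beta, "beta");

    CheckOutputTensor(*z, "z");
    CheckOutputTensor(*mu, "mu");
    CheckOutputTensor(*rsigma, "rsigma");
  }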

* where
* @f[
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma) + \beta
Collaborator:

Mismatched parenthesis:

Suggested change
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma) + \beta
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta

We could also make the LaTeX look nicer with:

 * y = \frac{x - \mathbb{E}[x]}{\sqrt{\text{Var}[x] + \varepsilon}} \gamma + \beta

Any changes here should also be made in the backward function and in RMSNorm.

* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma If zero_centered_gamma is enabled
Collaborator:

We should provide more details since we don't include the formula in the main documentation.

Suggested change
* \param[in] zero_centered_gamma If zero_centered_gamma is enabled
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$

Similar changes should be made to the backward function as well as RMSNorm.

Comment on lines +282 to +283
double atol_bwd = 1e-3;
double rtol_bwd = 1e-3;
Collaborator:

Do we need to relax tolerances?

Suggested change
double atol_bwd = 1e-3;
double rtol_bwd = 1e-3;
double atol_bwd = 1e-4;
double rtol_bwd = 1e-4;

Ideally we would use tight tolerances like:

if dtype == torch.float16:
    return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
    return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
    return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
    return dict(rtol=1e-7, atol=1e-7)
if dtype == torch.float8_e4m3fn:
    return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
if dtype == torch.float8_e5m2:
    return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152

LaunchParams<KernelParamsType> _launch_params;
std::function<void(LaunchParams<KernelParamsType>&, const bool)> _kernel;

bool _is_layernorm;
Collaborator:

Nit: Caching this value makes the code slightly more concise, but I think it's clearer and more general to just store the original enum:

Suggested change
bool _is_layernorm;
NVTE_Norm_Type _norm_type;
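
Call sites that currently branch on _is_layernorm could then compare against the stored enum, e.g. (illustrative only):

  if (_norm_type == NVTE_Norm_Type::LayerNorm) {
    // LayerNorm-specific handling, e.g. mean/beta tensors
  }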

Comment on lines +205 to +211
virtual void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr,
void* eps_dptr, void* rsigma_dptr, void* workspace_dptr,
cudaStream_t stream) = 0;

virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr,
void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) = 0;
Collaborator:

It's a bit awkward that NormalizationPlanBase includes logic for both the forward and backward kernels. I feel the "right" solution is to implement separate abstract base classes for forward and backward. That said, this approach does allow some code reuse with the cuDNN backend, so I think it's fine for now. We can revisit in the future if needed, e.g. if we add a separate forward inference stage.
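
For the record, a rough sketch of what that split could look like, reusing the execute signatures quoted above (the class names are hypothetical):

  class NormalizationFwdPlanBase {
   public:
    virtual ~NormalizationFwdPlanBase() = default;
    // Forward pass: normalizes x with gamma/beta and writes z, mu, rsigma.
    virtual void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
                         void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
                         void* workspace_dptr, cudaStream_t stream) = 0;
  };

  class NormalizationBwdPlanBase {
   public:
    virtual ~NormalizationBwdPlanBase() = default;
    // Backward pass: computes dx, dgamma, dbeta from dz and saved statistics.
    virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr,
                         void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr,
                         void* workspace_dptr, cudaStream_t stream) = 0;
  };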


NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (std::getenv("NVTE_FWD_LAYERNORM_USE_CUDNN")) {
Collaborator:

One problem with these envvars is that they're all-or-nothing. This is especially troublesome for testing, where we may want to test TE and cuDNN norm kernels within the same process. Perhaps we could have some function like nvte_enable_cudnn_layernorm that sets a global variable, and the initial value is based on an envvar.
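
A sketch of that idea (the function and variable names here are hypothetical):

  namespace {
  // Initial value comes from the envvar; tests can flip it at runtime.
  bool use_cudnn_layernorm = getenv<bool>("NVTE_FWD_LAYERNORM_USE_CUDNN");
  }  // namespace

  void nvte_enable_cudnn_layernorm(bool enable) { use_cudnn_layernorm = enable; }

  bool nvte_layernorm_uses_cudnn() { return use_cudnn_layernorm; }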

Comment on lines 329 to 337
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
LNTestSuite,
NormTestSuite,
::testing::Combine(
::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
Collaborator:

It would be helpful to test both cuDNN and TE kernels. We can reduce the number of test shapes to keep the total number of tests manageable.
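
One way to do that (sketch only; the extra tuple index and the envvar toggles are hypothetical and meant to sit inside the existing test body) would be to add a backend parameter and switch the envvars per test case:

  // Hypothetical extra test parameter selecting the backend for this case.
  const bool use_cudnn = std::get<5>(GetParam());
  if (use_cudnn) {
    setenv("NVTE_FWD_LAYERNORM_USE_CUDNN", "1", 1);
    setenv("NVTE_BWD_LAYERNORM_USE_CUDNN", "1", 1);
  } else {
    unsetenv("NVTE_FWD_LAYERNORM_USE_CUDNN");
    unsetenv("NVTE_BWD_LAYERNORM_USE_CUDNN");
  }

Note this only works if the backend selection is re-read on each call rather than cached in a static, which ties into the earlier comment about an explicit runtime toggle.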

@timmoon10 self-requested a review November 14, 2024 03:04