MIOpen bf16 batchnorm enablement #1665

Status: Draft. 27 commits to merge into base: rocm6.3_internal_testing.

Changes from all commits (27 commits):
- f1aecc0 fitst steps to enable bf16 batchnorm (dnikolaev-amd, Oct 14, 2024)
- fea29cb enable forward fp16 batchnorm (dnikolaev-amd, Oct 17, 2024)
- d12900c start with bf16 bn backward (dnikolaev-amd, Oct 17, 2024)
- 685d4fe bf16 eval (dnikolaev-amd, Oct 18, 2024)
- 2df6446 fwd and bwd fixes (dnikolaev-amd, Oct 18, 2024)
- f1bd902 set dtype on python level (dnikolaev-amd, Oct 21, 2024)
- e1d267a play with tensors on python level (dnikolaev-amd, Oct 21, 2024)
- 38c781f extra logging (dnikolaev-amd, Oct 22, 2024)
- f7b72f3 cleanup (dnikolaev-amd, Oct 22, 2024)
- ace6e2b fix dtype (dnikolaev-amd, Oct 22, 2024)
- 35fd62a remove odesc (dnikolaev-amd, Oct 22, 2024)
- de15416 enable fp16 nhwc batchnorm instead of nchw (dnikolaev-amd, Oct 22, 2024)
- fb5ecc5 cleanup (dnikolaev-amd, Oct 22, 2024)
- ca2625f extra logging (dnikolaev-amd, Oct 24, 2024)
- f896484 it works (dnikolaev-amd, Oct 24, 2024)
- ce59711 enable CK FP16 NHWC batchnorm on MIOpen (dnikolaev-amd, Oct 24, 2024)
- 097e41e enable NCHW (dnikolaev-amd, Oct 28, 2024)
- f9db16f benchmark errors (dnikolaev-amd, Oct 28, 2024)
- bda3fcc split forward and inferecnce (dnikolaev-amd, Oct 28, 2024)
- e6cb40c benchmark works (dnikolaev-amd, Oct 29, 2024)
- 201b948 remove extra logging and add PYTORCH_MIOPEN_EXTRA_LOGGING env var (dnikolaev-amd, Oct 29, 2024)
- b8cb39f more logging with PYTORCH_MIOPEN_EXTRA_LOGGING env var (dnikolaev-amd, Oct 29, 2024)
- 64b4b3e some logging fixes (dnikolaev-amd, Oct 29, 2024)
- 77d8e92 cleanup Normalization.cu from logging (dnikolaev-amd, Oct 30, 2024)
- 9e47d81 enable v2 fwd train (dnikolaev-amd, Nov 1, 2024)
- e8483ea v2 inference and backward (dnikolaev-amd, Nov 1, 2024)
- f0effc8 fix logging (dnikolaev-amd, Nov 4, 2024)
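Taken together, these commits route bf16 (and fp16) batch norm through MIOpen. A minimal sketch of exercising that path, assuming a ROCm build of PyTorch (module, shapes, and flag values here are illustrative, not taken from the PR):

```python
# Minimal sketch (illustrative, not from this PR): exercise bf16 NHWC
# batchnorm on the MIOpen backend. Assumes a ROCm build of PyTorch.
import os

# PYTORCH_MIOPEN_EXTRA_LOGGING is read at library load in Normalization.cpp,
# so set it before importing torch; PYTORCH_MIOPEN_SUGGEST_NHWC opts the
# channels-last layout into the MIOpen path.
os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
os.environ["PYTORCH_MIOPEN_EXTRA_LOGGING"] = "1"

import torch

bn = torch.nn.BatchNorm2d(64).cuda().to(memory_format=torch.channels_last)
x = torch.randn(8, 64, 32, 32, device="cuda", requires_grad=True,
                dtype=torch.bfloat16).to(memory_format=torch.channels_last)

out = bn(x)           # forward: should now dispatch to BatchNormBackend::Miopen
out.sum().backward()  # backward: the bf16 path added by the "bn backward" commits
```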
13 changes: 13 additions & 0 deletions .gitignore
@@ -1,3 +1,16 @@
# aten/src/THH/
# c10/hip
# aten/src/ATen/hip
# aten/src/ATen/native/hip
# aten/src/ATen/native/cudnn/hip
# aten/src/ATen/native/nested/hip
# aten/src/ATen/native/quantized/cudnn/hip
# aten/src/ATen/native/quantized/hip
# aten/src/ATen/native/transformers/hip
# aten/src/ATen/test/hip
# aten/src/ATen/test/test_install/hip
# binaries/hip
# aten/src/ATen/native/sparse/hip/
# READ THIS BEFORE YOU REFACTOR ME
#
# setup.py uses the list of patterns in this file to decide
164 changes: 164 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,164 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "bf16 nhwc v2",
"type": "debugpy",
"request": "launch",
"pythonArgs": ["-u"],
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nhwc_miopen_cuda_bfloat16"
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
"PYTORCH_MIOPEN_USE_API_V2": "1",
// "MIOPEN_LOG_LEVEL": "6",
// "MIOPEN_ENABLE_LOGGING": "1",
// "AMD_LOG_LEVEL": "6",
}
},
{
"name": "bf16 nhwc v1",
"type": "debugpy",
"request": "launch",
"pythonArgs": ["-u"],
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nhwc_miopen_cuda_bfloat16"
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
// "PYTORCH_MIOPEN_USE_API_V2": "1",
// "MIOPEN_LOG_LEVEL": "6",
// "MIOPEN_ENABLE_LOGGING": "1",
// "AMD_LOG_LEVEL": "6",
}
},
{
"name": "bf16 nchw",
"type": "debugpy",
"request": "launch",
"pythonArgs": ["-u"],
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nchw_miopen_cuda_bfloat16"
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
"PYTORCH_MIOPEN_USE_API_V2": "1",
// "MIOPEN_LOG_LEVEL": "6",
// "MIOPEN_ENABLE_LOGGING": "1",
// "AMD_LOG_LEVEL": "6",
}
},
{
"name": "fp16 nhwc",
"type": "debugpy",
"request": "launch",
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nhwc_miopen_cuda_float16"
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
"PYTORCH_MIOPEN_USE_API_V2": "1",
}
},
{
"name": "fp16 nchw",
"type": "debugpy",
"request": "launch",
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nchw_miopen_cuda_float16"
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
"PYTORCH_MIOPEN_USE_API_V2": "1",
}
},
{
"name": "fp32 nChw",
"type": "debugpy",
"request": "launch",
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nchw_miopen_cuda_float32",
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
"PYTORCH_MIOPEN_USE_API_V2": "1",
}
},
{
"name": "fp32 nHwc",
"type": "debugpy",
"request": "launch",
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nhwc_miopen_cuda_float32"
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
"PYTORCH_MIOPEN_USE_API_V2": "1",
}
},
{
"name": "eval",
"type": "debugpy",
"request": "launch",
"cwd": "${workspaceFolder}/test",
"program": "test_nn.py",
"console": "integratedTerminal",
"args": [
"-v",
"-k",
"test_batchnorm_nhwc_cuda"
],
"env": {
"MIOPEN_ENABLE_LOGGING_CMD": "1",
"PYTORCH_MIOPEN_EXTRA_LOGGING": "1",
"PYTORCH_MIOPEN_USE_API_V2": "1",
}
}
]
}
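Each launch configuration runs one batchnorm test under the listed environment variables. A rough command-line equivalent of the "bf16 nhwc v2" configuration, as a sketch using only values that appear in the config above:

```python
# Rough command-line equivalent of the "bf16 nhwc v2" launch configuration.
import os
import subprocess

env = dict(os.environ,
           MIOPEN_ENABLE_LOGGING_CMD="1",
           PYTORCH_MIOPEN_EXTRA_LOGGING="1",
           PYTORCH_MIOPEN_USE_API_V2="1")

subprocess.run(
    ["python", "-u", "test_nn.py", "-v", "-k",
     "test_batchnorm_nhwc_miopen_cuda_bfloat16"],
    cwd="test", env=env, check=True)
```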
96 changes: 78 additions & 18 deletions aten/src/ATen/native/Normalization.cpp
@@ -61,6 +61,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <utility>
#include <vector>
#include <iostream>

static const int MIOPEN_DIM_MAX = 5;

@@ -153,7 +154,6 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
}
return std::make_tuple(output, save_mean, save_invstd);
}

const int64_t ndim = input.dim();
// Helper to convert 1d tensors to an nd tensor that broadcasts with input
// All elements go into the channel dimension
@@ -484,10 +484,13 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
return std::make_tuple(grad_input, grad_weight, grad_bias);
}

bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false);

BatchNormBackend _select_batch_norm_backend(
const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
const Tensor& running_var, bool training, double eps) {

if (at::native::PYTORCH_MIOPEN_EXTRA_LOGGING)
std :: cout << "********************* _select_batch_norm_backend" << std::endl;
auto& ctx = at::globalContext();
bool cudnn_enabled = ctx.userEnabledCuDNN();

@@ -514,25 +517,44 @@
// See #64427
// non static variable is used to be able to change environment variable in runtime for testing
bool PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(false);


if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "**+** SUGGEST_NHWC=" << PYTORCH_MIOPEN_SUGGEST_NHWC
<< " cudnn_enabled=" << cudnn_enabled
<< " dim=" << input.dim()
<< " memory_format=" << input.suggest_memory_format()
<< " input.dtype=" << input.scalar_type()
<< " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type()
<< " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type()
<< " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " running_var.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " training=" << training
<< std::endl;
if (
input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
&& input.scalar_type() != at::kBFloat16
&& (weight.scalar_type() != at::kHalf)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& (input.dim() >= 3)
&& detail::getCUDAHooks().compiledWithMIOpen()
&& cudnn_enabled
&& (input.suggest_memory_format() == MemoryFormat::Contiguous
|| (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC))
&& input.dim() <= MIOPEN_DIM_MAX
&& (input.dim() >= 3)
&&
(
(input.scalar_type() == at::kFloat && input.suggest_memory_format() == MemoryFormat::Contiguous && weight.scalar_type() == at::kFloat)
||
(input.scalar_type() == at::kFloat && input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC && weight.scalar_type() == at::kFloat)
||
(input.scalar_type() == at::kHalf) // && input.suggest_memory_format() == MemoryFormat::ChannelsLast /* && weight.scalar_type() == at::kFloat*/)
||
(input.scalar_type() == at::kBFloat16) // && input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC && weight.scalar_type() == at::kBFloat16)
)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined()) || (!running_mean.defined() && !running_var.defined() && training))
) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "***** BatchNormBackend::Miopen" << std::endl;
return BatchNormBackend::Miopen;
}

if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "***** BatchNormBackend::Native" << std::endl;
return BatchNormBackend::Native;
}

@@ -546,6 +568,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
// See [Note: hacky wrapper removal for optional tensor]
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "********************* _batch_norm_impl_index"
<< " input=" << input.scalar_type()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " training=" << training
// << " momentum=" << momentum
// << " eps=" << eps
<< " cudnn_enabled=" << cudnn_enabled
<< std::endl;

c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
@@ -605,10 +641,12 @@

Tensor reserve = at::empty({0}, input.options().dtype(kByte));

if (backend == BatchNormBackend::Miopen) {
return std::tuple_cat(
if (backend == BatchNormBackend::Miopen) {
return std::tuple_cat(
at::miopen_batch_norm(
input.contiguous(input.suggest_memory_format()), weight.contiguous(), bias.contiguous(),
input.contiguous(input.suggest_memory_format()),
weight.contiguous(),
bias.contiguous(),
running_mean.defined() ? running_mean.contiguous() : running_mean,
running_var.defined() ? running_var.contiguous() : running_var,
training, momentum, eps),
@@ -625,9 +663,17 @@

std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
int64_t impl_index,
const Tensor& input, const Tensor& grad_output, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, const std::optional<Tensor>& save_mean_opt /* optional */, const std::optional<Tensor>& save_var_transform_opt /* optional */,
const Tensor& input,
const Tensor& grad_output,
const std::optional<Tensor>& weight_opt /* optional */,
const std::optional<Tensor>& running_mean_opt /* optional */,
const std::optional<Tensor>& running_var_opt /* optional */,
const std::optional<Tensor>& save_mean_opt /* optional */,
const std::optional<Tensor>& save_var_transform_opt /* optional */,
bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor &reservedSpace) {
// See [Note: hacky wrapper removal for optional tensor]
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std :: cout << "********************* _batch_norm_impl_index_backward" << std::endl;
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
@@ -674,6 +720,20 @@ Tensor batch_norm(
const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
bool training, double momentum, double eps, bool cudnn_enabled) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "********************* batch_norm"
<< " input=" << input.scalar_type()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " training=" << training
// << " momentum=" << momentum
// << " eps=" << eps
<< " cudnn_enabled=" << cudnn_enabled
<< std::endl;

const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
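In summary, the revised _select_batch_norm_backend condition widens the MIOpen route from fp32-only to fp16 and bf16 inputs. A rough Python model of the dtype/layout combinations it now accepts (illustrative only; the device, MIOpen-build, cudnn_enabled, and weight/bias-defined checks are omitted):

```python
# Rough model (not code from the PR) of the dtype/layout combinations the
# revised _select_batch_norm_backend accepts; device/build checks omitted.
def selects_miopen(input_dtype, weight_dtype, memory_format,
                   suggest_nhwc, dim, training, has_running_stats):
    if not (3 <= dim <= 5):  # MIOPEN_DIM_MAX == 5
        return False
    if not (has_running_stats or training):
        return False
    if input_dtype == "float32":
        # fp32 still requires fp32 weights and an accepted layout
        return weight_dtype == "float32" and (
            memory_format == "contiguous"
            or (memory_format == "channels_last" and suggest_nhwc))
    # fp16 and bf16 inputs are now accepted regardless of layout
    return input_dtype in ("float16", "bfloat16")

# e.g. bf16 NCHW now routes to MIOpen:
assert selects_miopen("bfloat16", "float32", "contiguous",
                      suggest_nhwc=False, dim=4, training=True,
                      has_running_stats=False)
```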