Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam (NVIDIA#1078)

* Add precision aware fused adam

Signed-off-by: kunlunl <[email protected]>

* Minor changes based on review comments.

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Kunlun Li <[email protected]>

---------

Signed-off-by: kunlunl <[email protected]>
Signed-off-by: Kunlun Li <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
3 people authored Nov 1, 2024
1 parent 23caab3 commit 05c0fb0
Showing 3 changed files with 542 additions and 71 deletions.
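For reference, the new tests below exercise the precision-aware options added to te.optimizers.FusedAdam (master_weights, master_weight_dtype, exp_avg_dtype, exp_avg_sq_dtype, use_decoupled_grad) together with get_unscaled_state. A minimal usage sketch based on those test cases follows; the model, learning rate, and the particular dtype combination are illustrative and not part of the commit:

```python
# Sketch of the precision-aware FusedAdam options exercised by the new tests.
# Model, sizes, and hyperparameters are illustrative; the optimizer keyword
# arguments mirror those used in tests/pytorch/test_fused_optimizer.py below.
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()

optimizer = te.optimizers.FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,             # keep a higher-precision copy of the params
    master_weight_dtype=torch.half,  # fp16 master weights (new in this commit)
    exp_avg_dtype=torch.uint8,       # fp8 storage for the first moment (needs FP8 support)
    exp_avg_sq_dtype=torch.half,     # fp16 storage for the second moment
    use_decoupled_grad=True,         # read grads from param.decoupled_grad
)

for _ in range(3):
    for p in model.parameters():
        if p.requires_grad:
            # With use_decoupled_grad=True the tests assign .decoupled_grad
            # directly instead of relying on .grad from autograd.
            p.decoupled_grad = torch.rand_like(p, dtype=torch.float32)
    optimizer.step()

# Low-precision optimizer states can be read back as unscaled fp32 tensors.
master = [optimizer.get_unscaled_state(p, "master_param") for p in model.parameters()]
```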
233 changes: 225 additions & 8 deletions tests/pytorch/test_fused_optimizer.py
@@ -4,6 +4,7 @@

from itertools import product
import copy
from contextlib import nullcontext

import pytest
import torch
@@ -174,6 +175,216 @@ def test_frozen_model(self):

torch.testing.assert_close(ref_param, tst_param)

def gen_precision_aware_test(
self,
use_fp8_params,
param_dtype,
use_master_weights,
master_weight_dtype,
grad_dtype,
exp_avg_dtype,
exp_avg_sq_dtype,
model_rtol=None,
model_atol=None,
master_rtol=None,
master_atol=None,
skip_assert=False,
):
build_model_context = nullcontext
build_model_context_args = {}
if use_fp8_params:
build_model_context = fp8_model_init
build_model_context_args["enabled"] = True

with build_model_context(**build_model_context_args):
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
layer_number=1,
params_dtype=param_dtype,
fuse_qkv_params=True,
).cuda()

ref_params = []
model_params = []

for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
model_params.append(p)

options = {
"lr": 1,
"betas": (0.1, 0.25),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params,
master_weights=use_master_weights,
master_weight_dtype=master_weight_dtype,
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
**options,
)

def test_one_iteration(ref_optimizer, tst_optimizer):
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
ref_optimizer.step()
tst_optimizer.step()
if use_master_weights:
master_weights_to_fp32 = [
tst_optim.get_unscaled_state(p, "master_param") for p in model_params
]
if not skip_assert:
torch.testing.assert_close(
ref_params,
master_weights_to_fp32,
rtol=master_rtol,
atol=master_atol,
equal_nan=True,
)
ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
if not skip_assert:
torch.testing.assert_close(
ref_params_to_model_dtype,
model_params,
rtol=model_rtol,
atol=model_atol,
equal_nan=True,
)

for i in range(self.iters):
test_one_iteration(ref_optim, tst_optim)

state_dict = tst_optim.state_dict()
tst_optim = te.optimizers.FusedAdam(
model_params,
master_weights=use_master_weights,
master_weight_dtype=master_weight_dtype,
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
**options,
)
tst_optim.load_state_dict(state_dict)

for i in range(self.iters):
test_one_iteration(ref_optim, tst_optim)

def test_fp32_no_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.float32,
use_master_weights=False,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.half,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_grad(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.bfloat16,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.half,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.uint8,
exp_avg_sq_dtype=torch.float32,
master_rtol=1e-2,
master_atol=1e-2,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.half,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.uint8,
skip_assert=True,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16
@@ -185,12 +396,10 @@ def test_bf16_model_weight_cast(self):
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
@@ -200,12 +409,17 @@
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)

for i in range(self.iters):
self.gen_grad(ref_params, master_params)
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
@@ -224,12 +438,10 @@ def test_fp8_model_weight_cast(self):
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
@@ -239,12 +451,17 @@
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)

for i in range(self.iters):
self.gen_grad(ref_params, master_params)
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
@@ -179,7 +179,7 @@ struct AdamFunctorMaster {
}
};

template <typename T, typename FULL_T, typename index_t>
template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*)
@@ -199,10 +199,10 @@
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]);
GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;

T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]);
PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;

FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
@@ -223,10 +223,10 @@
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
r_g[ii] = static_cast<MATH_T>(g[i]);
r_p[ii] = static_cast<MATH_T>(p[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
@@ -259,9 +259,9 @@
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
p[i] = static_cast<PARAM_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
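The functor change above replaces the single template type T with separate PARAM_T and GRAD_T, while the arithmetic still happens in MATH_T (fp32) and the moments are stored as FULL_T. As a rough, hypothetical Python sketch of the per-element update this enables (plain Adam with bias correction and L2-style weight decay assumed; the real kernel supports several modes and applies this chunk-wise on the GPU):

```python
import torch

def adam_step_mixed_precision(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
                              eps=1e-8, weight_decay=0.0):
    """Rough per-tensor sketch: params, grads, and moments may each use their own
    storage dtype; the math is done in fp32, mirroring MATH_T in the CUDA functor."""
    g32 = g.float() + weight_decay * p.float()      # L2-style decay (one of the kernel's modes)
    m32 = beta1 * m.float() + (1 - beta1) * g32
    v32 = beta2 * v.float() + (1 - beta2) * g32 * g32
    m_hat = m32 / (1 - beta1 ** step)
    v_hat = v32 / (1 - beta2 ** step)
    p32 = p.float() - lr * m_hat / (v_hat.sqrt() + eps)
    # Write back in each tensor's own storage dtype, analogous to the
    # static_cast<PARAM_T> / static_cast<FULL_T> stores above.
    p.copy_(p32.to(p.dtype))
    m.copy_(m32.to(m.dtype))
    v.copy_(v32.to(v.dtype))

# Example: bf16 params, fp32 grads, fp16 optimizer states.
p = torch.randn(8, dtype=torch.bfloat16)
g = torch.randn(8, dtype=torch.float32)
m = torch.zeros(8, dtype=torch.float16)
v = torch.zeros(8, dtype=torch.float16)
adam_step_mixed_precision(p, g, m, v, step=1)
```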
@@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
}
}

const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();

@@ -503,13 +504,15 @@
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
@@ -525,12 +528,13 @@
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(