Skip to content

Commit

Permalink
add SimpleRMSNorm (#465)
Browse files Browse the repository at this point in the history
Summary:
This PR implements SimpleRMSNorm, as proposed in:
Scaling TransNormer to 175 Billion Parameters
and is:
"
In TransNormerLLM, we replace the RMSNorm with a new simple normalization
function called SimpleRMSNorm, abbreviated as SRMSNorm:
SRMSNorm(x) = x / ∥x∥2/√d

We empirically find that using SRMSNorm does not lead to any performance loss, as demonstrated in
the ablation study [below]:

Norm Type Params Updates Loss PPL
SRMSNorm 385M 100K 2.247 4.765
RMSNorm 385M 100K 2.247 4.766
LayerNorm 385M 100K 2.247 4.765
"

note that their architecture is not a TransFormer but a TransNormer...therefore, I tested this on gpt2 transformer and saw equivalent results between LayerNorm and SimpleRMSNorm as below:

<img width="494" alt="simpleRMS_gpt2" src="https://github.com/facebookresearch/multimodal/assets/46302957/7239ed80-60c8-4dec-ad89-62c180bb6b2a">

In addition, SimpleRMSNorm is ~ 34% faster vs regular RMSNorm (eager mode comparison).

Pull Request resolved: #465

Test Plan: Tested on GPT2 training as shown above, and have added 4 unit tests (2 for BF16 and 2 for FP32 dtypes).

Reviewed By: ebsmothers

Differential Revision: D49638459

Pulled By: pbontrager

fbshipit-source-id: 203b2bdd95dd79a5817060d85fc5920c6523733a
  • Loading branch information
lessw2020 authored and facebook-github-bot committed Sep 26, 2023
1 parent d0cf041 commit 0793eb4
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/modules/layers/test_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Fp32GroupNorm,
Fp32LayerNorm,
RMSNorm,
SimpleRMSNorm,
)


Expand Down Expand Up @@ -61,3 +62,63 @@ def test_rms_norm_core_algo():
assert_expected(output_ones, input_ones)
assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05)
assert output_fixed.dtype == torch.float32


def test_simple_rmsnorm():
dims = 12
srms_norm = SimpleRMSNorm(dims)

input_bf16_ones = torch.ones(dims, dtype=torch.bfloat16)

input_fixed_fp32 = torch.tensor(
[
0.999,
1.1111,
2.222,
3.333,
4.444,
5.555,
6.678,
7.987,
8.123,
9.101010,
110.00,
120.2589,
],
dtype=torch.float32,
)

expected_output_bf16_ones = torch.tensor(
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
dtype=torch.bfloat16,
)
expected_output_fixed = torch.tensor(
[
0.0211,
0.0235,
0.0469,
0.0704,
0.0939,
0.1174,
0.1411,
0.1687,
0.1716,
0.1923,
2.3238,
2.5405,
],
dtype=torch.float32,
)

actual_output_bf16_ones = srms_norm(input_bf16_ones)
actual_output_fixed = srms_norm(input_fixed_fp32)

# verify ones output and dtype
assert_expected(
actual_output_bf16_ones, expected_output_bf16_ones, atol=1e-04, rtol=1e-05
)
assert actual_output_bf16_ones.dtype == torch.bfloat16

# verify fixed output and dtype
assert_expected(actual_output_fixed, expected_output_fixed, atol=1e-04, rtol=1e-05)
assert actual_output_fixed.dtype == torch.float32
22 changes: 22 additions & 0 deletions torchmultimodal/modules/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,25 @@ def _norm(self, x: Tensor) -> Tensor:
def forward(self, x: Tensor) -> Tensor:
x_normed = self._norm(x.float()).type_as(x)
return x_normed * self.scale


class SimpleRMSNorm(nn.Module):
"""Simple RMSNorm
SRMSNorm(x) = (x / ∥x∥2) /√d
as proposed in:
Scaling TransNormer to 175 Billion Parameters
https://arxiv.org/abs/2307.14995
Usage: use as drop in replacement for RMSNorm.
"""

def __init__(self, dim: int, eps: float = 1e-12):
super().__init__()
self.scaling = dim**0.5
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
denom = x.norm(p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(x)
return (x / denom) * self.scaling

0 comments on commit 0793eb4

Please sign in to comment.