System Info
- `transformers` version: 4.44.1
- Platform: Linux-5.15.0-88-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.24.6
- Safetensors version: 0.4.4
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.4.0a0+f70bd71a48.nv24.06 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: distributed
- Using GPU in script?: YES
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
There's an implementation difference between HF transformers' RMSNorm and NVIDIA transformer_engine's RMSNorm.
Version: transformer-engine 1.7.0+4e7caa1
First, define the `HFRMSNorm` code, which is copied from the `modeling_llama` implementation in the transformers library.
```python
import torch
from torch import nn


class HFRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, config=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```
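To make the difference concrete, here is a minimal sketch (values are illustrative) of the two orderings: HF casts the normalized hidden states back to the input dtype and multiplies by the weight in bf16, while an FP32-throughout variant, which is presumably what transformer_engine does, multiplies first and rounds only once at the end.

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(4096, dtype=torch.float32)    # stands in for the normalized hidden states (already FP32)
weight = torch.randn(4096, dtype=torch.bfloat16)   # gamma stored in bf16, as in the test below

hf_style = weight * hidden.to(torch.bfloat16)                      # HF: cast to bf16 first, multiply in bf16 (two roundings)
te_style = (weight.to(torch.float32) * hidden).to(torch.bfloat16)  # FP32 multiply, rounded once at the end

print((hf_style.float() - te_style.float()).abs().max())  # usually nonzero; this ordering is the source of the mismatch
```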
Next, run the test code:
```python
import unittest

import torch
import torch.nn as nn
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as TELayerNorm

from copy_from_hf import HFRMSNorm


class TestLayerNormComparison(unittest.TestCase):
    def setUp(self):
        self.hidden_size = 4096
        self.batch_size = 1
        self.seq_length = 1024
        self.eps = 1e-5
        self.shared_weight = nn.Parameter(torch.randn(self.hidden_size, dtype=torch.bfloat16))
        self.te_layernorm = TELayerNorm(self.hidden_size, eps=self.eps, zero_centered_gamma=False).to(torch.bfloat16)
        self.hf_rmsnorm = HFRMSNorm(self.hidden_size, eps=self.eps).to(torch.bfloat16)
        self.te_layernorm.weight = self.shared_weight
        self.hf_rmsnorm.weight = self.shared_weight
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.te_layernorm.to(self.device)
        self.hf_rmsnorm.to(self.device)

    def test_layernorm_comparison(self):
        input_tensor = torch.randn(self.batch_size, self.seq_length, self.hidden_size,
                                   dtype=torch.bfloat16, device=self.device)
        with torch.no_grad():
            te_output = self.te_layernorm(input_tensor)
            hf_output = self.hf_rmsnorm(input_tensor)
        assert torch.allclose(te_output, hf_output, atol=1e-2)


if __name__ == '__main__':
    unittest.main()
```
The assertion will fail.
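To see how large the mismatch is rather than just a failed assert, the comparison can be made more verbose (a debugging sketch; the tolerance is the same as in the test above):

```python
# Inside test_layernorm_comparison, after computing te_output and hf_output:
diff = (te_output.float() - hf_output.float()).abs()
print(f"max abs diff:  {diff.max().item():.6f}")
print(f"mean abs diff: {diff.mean().item():.6f}")

# torch.testing.assert_close reports mismatch statistics on failure,
# which is more informative than a bare assert on torch.allclose.
torch.testing.assert_close(te_output, hf_output, atol=1e-2, rtol=0)
```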
Expected behavior
If we change the last line of `HFRMSNorm` from `return self.weight * hidden_states.to(input_dtype)` to `return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)`, the assertion should pass.
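Applied to the forward method, the suggested change would look like this (only the return line differs from the original sketch above):

```python
def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    # Apply the weight in FP32 and cast back to the input dtype at the very end.
    return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
```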
We have a discussion here, and I agree that we should do all internal computation in FP32. What's your opinion on the HF side?