The output of Gemma 3 4B for TensorRT and Transformers is not the same, even when using float32 #4815

@Alireza3242

Description

System Info

A100
tensorrt_llm 0.20.0rc3

Who can help?

@kaiyux

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I ran the Gemma 3 4B model with both TensorRT-LLM and Transformers, but the outputs are not identical.

Serving with Transformers:

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

ckpt = "/home/dev/ra_workspace/ra_workspace/gemma3_4b/"

model = Gemma3ForCausalLM.from_pretrained(
    ckpt, torch_dtype=torch.float32, device_map="auto", attn_implementation="eager"
)

tokenizer = AutoTokenizer.from_pretrained(ckpt)

messages = [
    # Persian prompt: "Hi, how are you?"
    {"role": "user", "content": [{"type": "text", "text": "سلام چه طوری؟"}]}
]

inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True,
    return_dict=True, return_tensors="pt",
).to(model.device)

generation = model.generate(**inputs, max_new_tokens=256, do_sample=False)
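
The text being compared can be recovered like this (a small sketch; generate() returns the prompt tokens first, so they are sliced off before decoding):

# Decode only the newly generated tokens, skipping the echoed prompt.
prompt_len = inputs["input_ids"].shape[-1]
print(tokenizer.decode(generation[0][prompt_len:], skip_special_tokens=True))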

Serving with TensorRT-LLM:

In this case, we used version 0.20.0rc3. The model was converted and built using the following arguments:

"convert": {
  "model-dir": model_files + "/model",
  "output-model-dir": "/app/data/tllm_checkpoint",
  "dtype": "float",
  "ckpt-type": "hf"
},
"build": {
  "checkpoint_dir": "/app/data/tllm_checkpoint",
  "output_dir": "/app/model_repository/tensorrt_llm/1",
  "gemm_plugin": "disable",
  "max_batch_size": "32"
}

Serve-time arguments:

"args": {
  "triton_max_batch_size": "128",
  "tokenizer_dir": model_files + "/model",
  "decoupled_mode": "true",
  "engine_dir": "/app/model_repository/tensorrt_llm/1",
  "batch_scheduler_policy": "max_utilization",
  "max_tokens_in_paged_kv_cache": "16384",
  "postprocessing_instance_count": "1",
  "preprocessing_instance_count": "1",
  "bls_instance_count": "1",
  "add_special_tokens": "false",
  "triton_backend": "tensorrtllm",
  "max_queue_delay_microseconds": "1000000",
  "max_beam_width": "1",
  "batching_strategy": "inflight_fused_batching",
  "max_attention_window_size": "1024/,1024/,1024/,1024/,1024/,16384",
  "accumulate_tokens": "",
  "tensorrt_llm_model_name": "tensorrt_llm",
  "max_queue_size": "100000",
  "encoder_engine_dir": "",
  "exclude_input_in_output": "true",
  "encoder_input_features_data_type": "TYPE_BF16",
  "logits_datatype": "TYPE_FP32",
  "xgrammar_tokenizer_info_path": ""
}
import requests

# Same Persian prompt, with the Gemma chat template applied manually.
message = '<bos><start_of_turn>user\nسلام چه طوری؟<end_of_turn>\n<start_of_turn>model\n'
data = {
    "text_input": message,
    "temperature": 0.0,
    "top_k": 1,
    "max_tokens": 1024,
    "stream": False,
}

url = 'http://0.0.0.0:8000/v2/models/ensemble/generate_stream'

response = requests.post(url, json=data, headers={"Content-Type": "application/json"})
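
To get comparable text back out, the event stream can be collected like this (a sketch; it assumes the default ensemble returns its text in a text_output field of each data: event):

import json

pieces = []
for line in response.iter_lines():
    if line.startswith(b"data: "):  # /generate_stream replies as server-sent events
        event = json.loads(line[len(b"data: "):])
        pieces.append(event.get("text_output", ""))
print("".join(pieces))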

Expected behavior

The two backends should produce exactly the same output under greedy decoding (do_sample=False / top_k=1).

actual behavior

The generated outputs from TensorRT-LLM and Transformers do not match.

additional notes

We have modified some parts of the tensorrt_llm library to make it behave more like transformers. The following changes were made:

1-
The eps value used by the layer norm inside the attention layer was not the same as in Transformers; it should be changed to 1e-6.
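
The value Transformers uses can be read straight from the checkpoint config (a quick check against the repro path above):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("/home/dev/ra_workspace/ra_workspace/gemma3_4b/")
# Multimodal Gemma 3 checkpoints keep the text settings under cfg.text_config.
text_cfg = getattr(cfg, "text_config", cfg)
print(text_cfg.rms_norm_eps)  # expected: 1e-6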

2-
In the file /usr/local/lib/python3.12/dist-packages/tensorrt_llm/functional.py, inside the function:

def rms_norm(input: Tensor,
             normalized_shape: Union[int, Tuple[int]],
             num_groups: int = 1,
             weight: Optional[Tensor] = None,
             eps: float = 1e-06) -> Tensor:

We modified the body of the with precision("float32") block as follows:

with precision("float32"):
    input_dtype = input.dtype
    fp32_input = cast(input, "float32")
    varx = pow(fp32_input, 2.0)

    varx = varx.mean(dim=dim, keepdim=True)
    denom = varx + eps
    denom = denom.sqrt()
    fp32_y = fp32_input / denom
    if num_groups > 1:
        fp32_y = fp32_y.view(old_shape)

    if weight is not None:
        fp32_y = fp32_y * weight
    y = cast(fp32_y, input_dtype)
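
For comparison, this mirrors what Transformers does in its Gemma3RMSNorm module, which normalizes in float32 and applies a (1 + weight) scale (a from-memory sketch of modeling_gemma3.py, so treat the exact details as an assumption):

import torch
from torch import nn

class Gemma3RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero-centered scale

    def forward(self, x):
        # Normalize in float32 regardless of the input dtype.
        fx = x.float()
        normed = fx * torch.rsqrt(fx.pow(2).mean(-1, keepdim=True) + self.eps)
        return (normed * (1.0 + self.weight.float())).type_as(x)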

3-
In the file /usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/gemma/convert.py,
in the following section:

elif any(keyword in name for keyword in (
        "pre_attention_norm.scale",
        "pre_ffw_norm.scale",
        "final_norm.scale",
        "pre_attention_norm/vars/0",
        "pre_ffw_norm/vars/0",
        "rms_normalization/vars/0",
        "input_layernorm",
        "post_attention_layernorm",
        "pre_feedforward_layernorm",
        "post_feedforward_layernorm",
        "model.norm.weight",
        "q_norm.weight",
        "k_norm.weight",
)):
    param = param + 1.0  # upcasted to float32 in case of bfloat16
    add_trt_llm_weight(weights, trt_llm_name, param, 'float32')

The value of trt_llm_config.dtype was changed to "float32".
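
The +1.0 offset compensates for Gemma's zero-centered norm weights: Transformers multiplies the normalized activations by (1 + weight), while tensorrt_llm's rms_norm multiplies by the stored weight directly, so folding the offset in at conversion time makes the two formulations agree. A toy check of that identity (plain PyTorch, made-up values):

import torch

x = torch.randn(4, 8)
w = torch.randn(8)  # zero-centered Gemma weight
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

hf_style = normed * (1.0 + w)   # what Transformers computes
folded = w + 1.0                # what convert.py now stores
trt_style = normed * folded     # what rms_norm then computes
assert torch.allclose(hf_style, trt_style)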

4-
In /usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/gemma/model.py,
we changed:

self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                               eps=config.norm_epsilon,
                               dtype=config.dtype)

to:

self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                               eps=config.norm_epsilon,
                               dtype="float32")

However, despite all these changes, the outputs are still different.
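
To localize where the two decodes first drift apart, comparing token ids step by step helps (a debugging sketch; hf_ids and trt_ids are hypothetical names for the greedy token sequences collected from each backend):

def first_divergence(hf_ids, trt_ids):
    # Index of the first token where the two greedy decodes disagree.
    for i, (a, b) in enumerate(zip(hf_ids, trt_ids)):
        if a != b:
            return i
    return min(len(hf_ids), len(trt_ids))

# Re-running both backends up to that step and inspecting the logits there
# shows whether the gap is a real numerical divergence or just a tie-break
# between near-equal candidates.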
