Description
System Info
A100
tensorrt_llm 0.20.0rc3
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I ran the Gemma 3 4B model with both TensorRT-LLM and Transformers, but the outputs are not identical.
Serving with Transformers:
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

ckpt = "/home/dev/ra_workspace/ra_workspace/gemma3_4b/"
model = Gemma3ForCausalLM.from_pretrained(
    ckpt, torch_dtype=torch.float32, device_map="auto", attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

messages = [
    {"role": "user", "content": [{"type": "text", "text": "سلام چه طوری؟"}]}
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True,
    return_dict=True, return_tensors="pt",
).to(model.device)

generation = model.generate(**inputs, max_new_tokens=256, do_sample=False)
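For reference, the Transformers output is decoded like this before comparing (a minimal sketch following the variable names above; the slicing assumes generate returns the prompt tokens followed by the new tokens, which is the default for decoder-only models):
# Decode only the newly generated tokens, skipping the prompt.
prompt_len = inputs["input_ids"].shape[-1]
reference_text = tokenizer.decode(generation[0][prompt_len:], skip_special_tokens=True)
print(reference_text)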
Serving with TensorRT-LLM:
In this case, we used version 0.20.0rc3. The model was converted and built using the following arguments:
"convert": {
"model-dir": model_files + "/model",
"output-model-dir": "/app/data/tllm_checkpoint",
"dtype": "float",
"ckpt-type": "hf"
},
"build": {
"checkpoint_dir": "/app/data/tllm_checkpoint",
"output_dir": "/app/model_repository/tensorrt_llm/1",
"gemm_plugin": "disable",
"max_batch_size": "32"
}
Serve-time arguments:
"args": {
"triton_max_batch_size": "128",
"tokenizer_dir": model_files + "/model",
"decoupled_mode": "true",
"engine_dir": "/app/model_repository/tensorrt_llm/1",
"batch_scheduler_policy": "max_utilization",
"max_tokens_in_paged_kv_cache": "16384",
"postprocessing_instance_count": "1",
"preprocessing_instance_count": "1",
"bls_instance_count": "1",
"add_special_tokens": "false",
"triton_backend": "tensorrtllm",
"max_queue_delay_microseconds": "1000000",
"max_beam_width": "1",
"batching_strategy": "inflight_fused_batching",
"max_attention_window_size": "1024/,1024/,1024/,1024/,1024/,16384",
"accumulate_tokens": "",
"tensorrt_llm_model_name": "tensorrt_llm",
"max_queue_size": "100000",
"encoder_engine_dir": "",
"exclude_input_in_output": "true",
"encoder_input_features_data_type": "TYPE_BF16",
"logits_datatype": "TYPE_FP32",
"xgrammar_tokenizer_info_path": ""
}
import requests

message = '<bos><start_of_turn>user\nسلام چه طوری؟<end_of_turn>\n<start_of_turn>model\n'
data = {
    "text_input": message,
    "temperature": 0.0,
    "top_k": 1,
    "max_tokens": 1024,
    "stream": False
}
url = 'http://0.0.0.0:8000/v2/models/ensemble/generate_stream'
response = requests.post(url, json=data, headers={"Content-Type": "application/json"})
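To check the two serving paths against each other, we read the generated text back from the Triton response and compare it with reference_text from the Transformers sketch above (a rough sketch; it assumes the non-streaming /v2/models/ensemble/generate endpoint, which returns a JSON body whose text_output field holds the generated text):
resp = requests.post(
    "http://0.0.0.0:8000/v2/models/ensemble/generate",
    json=data,
    headers={"Content-Type": "application/json"},
)
trt_text = resp.json()["text_output"]
# reference_text comes from the Transformers run above.
print("outputs match:", trt_text.strip() == reference_text.strip())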
Expected behavior
The TensorRT-LLM output should exactly match the Transformers output.
actual behavior
The two outputs differ.
additional notes
We have modified some parts of the tensorrt_llm library to make it behave more similarly to transformers. The following changes were made:
1-
The value of eps in the layer_norm within the attention layer was not the same as in Transformers; it should be changed to 1e-6.
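The 1e-6 value can be confirmed directly from the Hugging Face checkpoint config (a small sketch using the ckpt path from above; depending on whether the checkpoint is text-only or multimodal, the field sits on the config itself or under text_config):
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(ckpt)
# Multimodal Gemma 3 checkpoints nest the text settings under text_config.
text_cfg = getattr(cfg, "text_config", cfg)
print(text_cfg.rms_norm_eps)  # expected: 1e-06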
2-
In the file /usr/local/lib/python3.12/dist-packages/tensorrt_llm/functional.py, inside the function:
def rms_norm(input: Tensor,
             normalized_shape: Union[int, Tuple[int]],
             num_groups: int = 1,
             weight: Optional[Tensor] = None,
             eps: float = 1e-06) -> Tensor:
We modified the section under with precision("float32") as follows:
with precision("float32"):
input_dtype = input.dtype
fp32_input = cast(input, "float32")
varx = pow(fp32_input, 2.0)
varx = varx.mean(dim=dim, keepdim=True)
denom = varx + eps
denom = denom.sqrt()
fp32_y = fp32_input / denom
if num_groups > 1:
fp32_y = fp32_y.view(old_shape)
if weight is not None:
fp32_y = fp32_y * weight
y = cast(fp32_y, input_dtype)
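As a sanity check for this change, the same computation can be written as a small PyTorch reference (a sketch with our own function name; weight is assumed to be the TensorRT-LLM norm weight, which already includes the +1.0 offset applied in convert.py, see item 3 below):
import torch

def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mirror of the modified graph: normalize in float32, apply the weight, cast back.
    x_fp32 = x.float()
    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    y = x_fp32 / torch.sqrt(variance + eps)
    return (y * weight.float()).to(x.dtype)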
3-
In the file /usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/gemma/convert.py,
in the following section:
elif any(keyword in name for keyword in (
        "pre_attention_norm.scale",
        "pre_ffw_norm.scale",
        "final_norm.scale",
        "pre_attention_norm/vars/0",
        "pre_ffw_norm/vars/0",
        "rms_normalization/vars/0",
        "input_layernorm",
        "post_attention_layernorm",
        "pre_feedforward_layernorm",
        "post_feedforward_layernorm",
        "model.norm.weight",
        "q_norm.weight",
        "k_norm.weight",
)):
    param = param + 1.0  # upcasted to float32 in case of bfloat16
    add_trt_llm_weight(weights, trt_llm_name, param, 'float32')
The value of trt_llm_config.dtype was changed to "float32".
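For context on the param + 1.0 above: Hugging Face's Gemma RMSNorm applies its stored scale as (1 + weight), while TensorRT-LLM's rms_norm (see functional.py above) multiplies by the weight directly, so the converted weight has to absorb the offset (a minimal sketch; tensor names and sizes are illustrative only):
import torch

hidden = 2560                      # illustrative hidden size only
hf_scale = torch.zeros(hidden)     # scale as stored in the HF checkpoint
x = torch.randn(1, hidden)

# Hugging Face Gemma-style RMSNorm: normalize in float32, scale by (1 + weight).
def hf_gemma_rms_norm(x, scale, eps=1e-6):
    y = x.float() / torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (y * (1.0 + scale.float())).to(x.dtype)

# TensorRT-LLM multiplies by the weight directly, so convert.py stores scale + 1.0.
trt_llm_weight = hf_scale + 1.0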
4-
In /usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/gemma/model.py, we changed:
self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                               eps=config.norm_epsilon,
                               dtype=config.dtype)
to:
self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                               eps=config.norm_epsilon,
                               dtype="float32")
However, despite all these changes, the outputs are still different.