Skip to content

Commit b288f80

Browse files
committed
Fixed an accuracy issue in fast refit
1 parent f185abb commit b288f80

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,16 @@ def construct_refit_mapping_from_weight_name_map(
117117
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
118118
if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
119119
# Batch Norm Layer
120-
params = {}
120+
params = {
121+
"weight": 1.0,
122+
"bias": 0.0,
123+
"running_mean": 0.0,
124+
"running_var": 1.0,
125+
}
121126
for w in sd_weight_name:
122-
params[w.split(".")[-1]] = state_dict[w]
123-
scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
127+
if w in state_dict:
128+
params[w.split(".")[-1]] = state_dict[w]
129+
scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-5)
124130
shift = params["bias"] - params["running_mean"] * scale
125131
# Set scale to scale or shift to shift
126132
engine_weight_map[engine_weight_name] = eval(
@@ -171,6 +177,11 @@ def _refit_single_trt_engine_with_gm(
171177
mapping = construct_refit_mapping_from_weight_name_map(
172178
weight_name_map, new_gm.state_dict()
173179
)
180+
181+
# Debug Use
182+
# correct = construct_refit_mapping(new_gm, input_list, settings)
183+
# {k: np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2) for k in mapping if k in correct}
184+
174185
for layer_name in weight_list:
175186
if layer_name not in mapping:
176187
logger.warning(f"{layer_name} is not found in weight mapping.")

0 commit comments

Comments (0)