
Commit 6bfd978

Made revision according to comments

1 parent 9b801e7 commit 6bfd978


py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 5 deletions
@@ -40,7 +40,6 @@
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
-from tqdm import tqdm
 
 import tensorrt as trt
 from packaging import version
@@ -341,13 +340,21 @@ def _construct_trt_network_def(self) -> None:
 
     @staticmethod
     def find_weight(
-        weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
+        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
     ) -> str:
+        """
+        We need to build map from engine weight name to state_dict weight name.
+        The purpose of this function is to find the corresponding weight name in module state_dict.
+
+        weight_name: the target weight name we want to search for
+        np_map: the map from weight name to np values in INetworkDefinition
+        state_dict: state of the graph module
+        """
         network_weight = np_map[weight_name]
         network_weight = torch.from_numpy(np_map[weight_name]).cuda()
-        for sd_w_name, sd_weight in sd.items():
+        for sd_w_name, sd_weight in state_dict.items():
             if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
-                del sd[sd_w_name]
+                del state_dict[sd_w_name]
                 return sd_w_name
         return ""

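For readers skimming the diff, here is a minimal self-contained sketch of the matching logic that find_weight performs, with a toy check_weight_equal standing in for TRTInterpreter's real helper and the .cuda() call dropped so it runs on CPU; the names, inputs, and comparison rule are illustrative assumptions, not the interpreter's actual implementation.

import numpy as np
import torch


def check_weight_equal(sd_weight: torch.Tensor, network_weight: torch.Tensor) -> bool:
    # Toy stand-in for TRTInterpreter.check_weight_equal: same shape, close values.
    return sd_weight.shape == network_weight.shape and torch.allclose(
        sd_weight, network_weight
    )


def find_weight(weight_name: str, np_map: dict, state_dict: dict) -> str:
    # Convert the engine-side numpy weight to a torch tensor for comparison
    # (the real code moves it to CUDA; omitted here so the sketch runs on CPU).
    network_weight = torch.from_numpy(np_map[weight_name])
    for sd_w_name, sd_weight in state_dict.items():
        if check_weight_equal(sd_weight, network_weight):
            # Remove the matched entry so it cannot be matched twice, then
            # report the state_dict name that corresponds to the engine weight.
            del state_dict[sd_w_name]
            return sd_w_name
    return ""


# Illustrative usage: map an engine weight name back to its state_dict name.
np_map = {"engine.w0": np.ones((2, 2), dtype=np.float32)}
state_dict = {"linear.weight": torch.ones(2, 2)}
print(find_weight("engine.w0", np_map, state_dict))  # prints: linear.weight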
@@ -475,7 +482,7 @@ def _save_weight_mapping(self) -> None:
             np_map[engine_weight_name] = weight
 
         # Stage 2: Value mapping
-        for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()):
+        for engine_weight_name, sd_weight_name in weight_name_map.items():
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
