|
40 | 40 | from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
|
41 | 41 | from torch_tensorrt.fx.observer import Observer
|
42 | 42 | from torch_tensorrt.logging import TRT_LOGGER
|
43 |
| -from tqdm import tqdm |
44 | 43 |
|
45 | 44 | import tensorrt as trt
|
46 | 45 | from packaging import version
|
@@ -341,13 +340,21 @@ def _construct_trt_network_def(self) -> None:
|
341 | 340 |
|
@staticmethod
def find_weight(
    weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
) -> str:
    """
    Find the module state_dict weight name corresponding to an engine weight.

    We need to build a map from engine weight name to state_dict weight
    name; this helper does the value-based matching for one engine weight:
    the weight's numpy value from the INetworkDefinition is compared against
    each remaining entry of the module state_dict.

    Args:
        weight_name: the target engine weight name we want to search for
        np_map: map from engine weight name to numpy values in the
            INetworkDefinition
        state_dict: state of the graph module; a matched entry is deleted
            so it cannot be matched to a second engine weight

    Returns:
        The matching state_dict weight name, or "" if no match is found.
    """
    # Removed dead assignment `network_weight = np_map[weight_name]` — it
    # was immediately overwritten by the line below.
    # NOTE(review): assumes a CUDA device is available — TODO confirm this
    # is guaranteed at every call site.
    network_weight = torch.from_numpy(np_map[weight_name]).cuda()
    for sd_w_name, sd_weight in state_dict.items():
        if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
            # Deleting while iterating is safe here only because we return
            # immediately after the mutation.
            del state_dict[sd_w_name]
            return sd_w_name
    return ""
|
353 | 360 |
|
@@ -475,7 +482,7 @@ def _save_weight_mapping(self) -> None:
|
475 | 482 | np_map[engine_weight_name] = weight
|
476 | 483 |
|
477 | 484 | # Stage 2: Value mapping
|
478 |
| - for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()): |
| 485 | + for engine_weight_name, sd_weight_name in weight_name_map.items(): |
479 | 486 | if "SCALE" in engine_weight_name:
|
480 | 487 | # There is no direct connection in batch_norm layer. So skip it
|
481 | 488 | pass
|
|
0 commit comments