Skip to content

Commit df95f00

Browse files
committed
Fixed style issue
1 parent 6bfd978 commit df95f00

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import io
23
import logging
34
import os
@@ -366,7 +367,7 @@ def check_weight_equal(
366367
network_weight = torch.from_numpy(network_weight).cuda()
367368
try:
368369
return sd_weight.shape == network_weight.shape and torch.all(
369-
torch.abs(sd_weight - network_weight) < 0.1
370+
torch.abs(sd_weight - network_weight) < 0.01
370371
)
371372
except Exception:
372373
return torch.all(sd_weight == network_weight)
@@ -425,7 +426,7 @@ def _save_weight_mapping(self) -> None:
425426
)
426427
}
427428
"""
428-
_LOGGER.info("building weight name mapping...")
429+
_LOGGER.info("Building weight name mapping...")
429430
# Stage 1: Name mapping
430431
sd = self.module.state_dict()
431432
torch_device = to_torch_device(self.compilation_settings.device)
@@ -501,6 +502,7 @@ def _save_weight_mapping(self) -> None:
501502
self.weight_name_map = weight_name_map
502503

503504
del np_map, sd
505+
gc.collect()
504506
torch.cuda.empty_cache()
505507

506508
def run(

0 commit comments

Comments
 (0)