File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
py/torch_tensorrt/dynamo/conversion Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change
1
+ import gc
1
2
import io
2
3
import logging
3
4
import os
@@ -366,7 +367,7 @@ def check_weight_equal(
366
367
network_weight = torch .from_numpy (network_weight ).cuda ()
367
368
try :
368
369
return sd_weight .shape == network_weight .shape and torch .all (
369
- torch .abs (sd_weight - network_weight ) < 0.1
370
+ torch .abs (sd_weight - network_weight ) < 0.01
370
371
)
371
372
except Exception :
372
373
return torch .all (sd_weight == network_weight )
@@ -425,7 +426,7 @@ def _save_weight_mapping(self) -> None:
425
426
)
426
427
}
427
428
"""
428
- _LOGGER .info ("building weight name mapping..." )
429
+ _LOGGER .info ("Building weight name mapping..." )
429
430
# Stage 1: Name mapping
430
431
sd = self .module .state_dict ()
431
432
torch_device = to_torch_device (self .compilation_settings .device )
@@ -501,6 +502,7 @@ def _save_weight_mapping(self) -> None:
501
502
self .weight_name_map = weight_name_map
502
503
503
504
del np_map , sd
505
+ gc .collect ()
504
506
torch .cuda .empty_cache ()
505
507
506
508
def run (
You can’t perform that action at this time.
0 commit comments