@@ -400,7 +400,7 @@ def _construct_trt_network_def(self) -> None:
400400 @staticmethod
401401 def find_weight (
402402 weight_name : str ,
403- np_map : dict [str , Any ],
403+ weight_refit_map : dict [str , Any ],
404404 state_dict : dict [str , Any ],
405405 device : torch .device ,
406406 ) -> str :
@@ -413,7 +413,7 @@ def find_weight(
413413 state_dict: state of the graph module
414414 """
415415 with unset_fake_temporarily ():
416- network_weight = torch . from_numpy ( np_map [weight_name ]) .to (device )
416+ network_weight = weight_refit_map [weight_name ].to (device )
417417 for sd_w_name , sd_weight in state_dict .items ():
418418 if TRTInterpreter .check_weight_equal (sd_weight , network_weight , device ):
419419 del state_dict [sd_w_name ]
@@ -427,8 +427,8 @@ def check_weight_equal(
427427 device : torch .device ,
428428 ) -> Any :
429429 with unset_fake_temporarily ():
430- if not isinstance ( network_weight , torch . Tensor ) :
431- network_weight = torch . from_numpy ( network_weight ) .to (device )
430+ if network_weight . device != device :
431+ network_weight = network_weight .to (device )
432432 try :
433433 return sd_weight .shape == network_weight .shape and torch .all (
434434 torch .abs (sd_weight - network_weight ) < 0.01
@@ -494,11 +494,10 @@ def _save_weight_mapping(self) -> None:
494494 _LOGGER .info ("Building weight name mapping..." )
495495 # Stage 1: Name mapping
496496 torch_device = to_torch_device (self .compilation_settings .device )
497- self .module .to (torch_device )
498- sd = self .module .state_dict ()
497+ sd = {k : v .to (torch_device ) for k , v in self .module .state_dict ().items ()}
499498 weight_name_map : dict [str , Any ] = {}
500- np_map = self .ctx .weight_refit_map
501- constant_mapping = {k : v for k , v in np_map .items () if v .size == 1 }
499+ weight_refit_map = self .ctx .weight_refit_map
500+ constant_mapping = {k : v for k , v in weight_refit_map .items () if v .size == 1 }
502501 net = self .ctx .net
503502 for i in range (net .num_layers ):
504503 layer = net [i ]
@@ -540,7 +539,7 @@ def _save_weight_mapping(self) -> None:
540539 else :
541540 sd_weight_name = f"{ sd_weight_name } .{ torch_attr } "
542541
543- if engine_weight_name in np_map :
542+ if engine_weight_name in weight_refit_map :
544543 weight_name_map [engine_weight_name ] = sd_weight_name
545544
546545 # Stage 2: Value mapping
@@ -549,10 +548,10 @@ def _save_weight_mapping(self) -> None:
549548 # There is no direct connection in batch_norm layer. So skip it
550549 pass
551550 elif sd_weight_name not in sd or not TRTInterpreter .check_weight_equal (
552- sd [sd_weight_name ], np_map [engine_weight_name ], torch_device
551+ sd [sd_weight_name ], weight_refit_map [engine_weight_name ], torch_device
553552 ):
554553 weight_name_map [engine_weight_name ] = TRTInterpreter .find_weight (
555- engine_weight_name , np_map , sd , torch_device
554+ engine_weight_name , weight_refit_map , sd , torch_device
556555 )
557556 if (
558557 weight_name_map [engine_weight_name ] != ""
@@ -563,12 +562,13 @@ def _save_weight_mapping(self) -> None:
563562
564563 weight_name_map [engine_weight_name ] = [
565564 weight_name_map [engine_weight_name ],
566- np_map [engine_weight_name ].dtype ,
565+ weight_refit_map [engine_weight_name ].dtype ,
567566 ]
568567
569568 weight_name_map ["constant_mapping" ] = constant_mapping
570569 self .weight_name_map = weight_name_map
571- del np_map , sd
570+
571+ del weight_refit_map , sd
572572 gc .collect ()
573573 torch .cuda .empty_cache ()
574574
0 commit comments