Skip to content

Commit 491ca91

Browse files
authored
Allow policy to be loaded on CPU. (#98)
1 parent 750e845 commit 491ca91

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

rsl_rl/runners/on_policy_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,8 @@ def save(self, path: str, infos=None):
423423
if self.logger_type in ["neptune", "wandb"] and not self.disable_logs:
424424
self.writer.save_model(path, self.current_learning_iteration)
425425

426-
def load(self, path: str, load_optimizer: bool = True):
427-
loaded_dict = torch.load(path, weights_only=False)
426+
def load(self, path: str, load_optimizer: bool = True, map_location: str | None = None):
427+
loaded_dict = torch.load(path, weights_only=False, map_location=map_location)
428428
# -- Load model
429429
resumed_training = self.alg.policy.load_state_dict(loaded_dict["model_state_dict"])
430430
# -- Load RND model if used

0 commit comments

Comments
 (0)