File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -87,9 +87,8 @@ def execute_model_spmd(
8787 # TODO(swang): This is needed right now because Ray Compiled Graph
8888 # executes on a background thread, so we need to reset torch's
8989 # current device.
90- import torch
9190 if not self .compiled_dag_cuda_device_set :
92- torch . cuda .set_device (self .worker .device )
91+ current_platform .set_device (self .worker .device )
9392 self .compiled_dag_cuda_device_set = True
9493
9594 output = self .worker ._execute_model_spmd (execute_model_req ,
@@ -113,8 +112,7 @@ def setup_device_if_necessary(self):
113112 # Not needed
114113 pass
115114 else :
116- import torch
117- torch .cuda .set_device (self .worker .device )
115+ current_platform .set_device (self .worker .device )
118116
119117 self .compiled_dag_cuda_device_set = True
120118
You can’t perform that action at this time.
0 commit comments