Skip to content

Commit cebc22f

Browse files
authored
[Misc]Replace cuda hard code with current_platform in Ray (#14668)
Signed-off-by: noemotiovon <[email protected]>
1 parent 6c6dcd8 commit cebc22f

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

vllm/executor/ray_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,8 @@ def execute_model_spmd(
8787
# TODO(swang): This is needed right now because Ray Compiled Graph
8888
# executes on a background thread, so we need to reset torch's
8989
# current device.
90-
import torch
9190
if not self.compiled_dag_cuda_device_set:
92-
torch.cuda.set_device(self.worker.device)
91+
current_platform.set_device(self.worker.device)
9392
self.compiled_dag_cuda_device_set = True
9493

9594
output = self.worker._execute_model_spmd(execute_model_req,
@@ -113,8 +112,7 @@ def setup_device_if_necessary(self):
113112
# Not needed
114113
pass
115114
else:
116-
import torch
117-
torch.cuda.set_device(self.worker.device)
115+
current_platform.set_device(self.worker.device)
118116

119117
self.compiled_dag_cuda_device_set = True
120118

0 commit comments

Comments
 (0)