Skip to content

Commit 9fd123a

Browse files
authored
ci: mark model_parallel tests as cuda specific (#35269)
`parallelize()` API is deprecated in favor of accelerate's `device_map="auto"` and therefore is not accepting new features. At the same time `parallelize()` implementation is currently CUDA-specific. This commit marks respective ci tests with `@require_torch_gpu`. Fixes: #35252 Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent bd442c6 commit 9fd123a

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tests/test_modeling_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3061,6 +3061,7 @@ def test_multi_gpu_data_parallel_forward(self):
30613061
with torch.no_grad():
30623062
_ = model(**self._prepare_for_class(inputs_dict, model_class))
30633063

3064+
@require_torch_gpu
30643065
@require_torch_multi_gpu
30653066
def test_model_parallelization(self):
30663067
if not self.test_model_parallel:
@@ -3123,6 +3124,7 @@ def get_current_gpu_memory_use():
31233124
gc.collect()
31243125
torch.cuda.empty_cache()
31253126

3127+
@require_torch_gpu
31263128
@require_torch_multi_gpu
31273129
def test_model_parallel_equal_results(self):
31283130
if not self.test_model_parallel:

0 commit comments

Comments
 (0)