Skip to content

Commit f73f5e6

Browse files
authored
Avoid check expected exception when it is on CUDA (#34408)
* update * update --------- Co-authored-by: ydshieh <[email protected]>
1 parent e447185 commit f73f5e6

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

tests/pipelines/test_pipelines_summarization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ def run_pipeline_test(self, summarizer, _):
8585
and len(summarizer.model.trainable_weights) > 0
8686
and "GPU" in summarizer.model.trainable_weights[0].device
8787
):
88-
with self.assertRaises(Exception):
89-
outputs = summarizer("This " * 1000)
88+
if str(summarizer.device) == "cpu":
89+
with self.assertRaises(Exception):
90+
outputs = summarizer("This " * 1000)
9091
outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)
9192

9293
@require_torch

tests/pipelines/test_pipelines_text_generation.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -493,17 +493,19 @@ def run_pipeline_test(self, text_generator, _):
493493
and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
494494
):
495495
# Handling of large generations
496-
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
497-
text_generator("This is a test" * 500, max_new_tokens=20)
496+
if str(text_generator.device) == "cpu":
497+
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
498+
text_generator("This is a test" * 500, max_new_tokens=20)
498499

499500
outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
500501
# Hole strategy cannot work
501-
with self.assertRaises(ValueError):
502-
text_generator(
503-
"This is a test" * 500,
504-
handle_long_generation="hole",
505-
max_new_tokens=tokenizer.model_max_length + 10,
506-
)
502+
if str(text_generator.device) == "cpu":
503+
with self.assertRaises(ValueError):
504+
text_generator(
505+
"This is a test" * 500,
506+
handle_long_generation="hole",
507+
max_new_tokens=tokenizer.model_max_length + 10,
508+
)
507509

508510
@require_torch
509511
@require_accelerate

0 commit comments

Comments
 (0)