Merged
2 changes: 1 addition & 1 deletion src/transformers/generation/logits_process.py
@@ -1915,7 +1915,7 @@ def __init__(self, suppress_tokens, device: str = "cpu"):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
+        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device))
         scores = torch.where(suppress_token_mask, -float("inf"), scores)
Contributor Author commented:
In multi-device cases (e.g. running across 2 devices): in the current implementation of assisted decoding, the assistant model reuses the main model's SuppressTokensLogitsProcessor, which places suppress_tokens on the same device as the input tensor (device 0). The assistant model ingests the main model's encoder_outputs and runs its own decoder (in the Whisper case). Those encoder_outputs may be on device 1 while the main model's suppress_tokens are on device 0, which leads to:

RuntimeError: Expected all tensors to be on the same device, but got test_elements is on xpu:0, different from other tensors on xpu:1 (when checking argument in method wrapper_XPU_isin_Tensor_Tensor)

So, given the current implementation (where the assistant model shares the main model's SuppressTokensLogitsProcessor), I move suppress_tokens to scores.device when calling isin.
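
Below is a minimal sketch of the failure mode and the fix. It is illustrative only: the device ids are hypothetical, and plain torch.isin stands in for the isin_mps_friendly helper used in the real processor.

import torch

def suppress_scores(scores: torch.FloatTensor, suppress_tokens: torch.Tensor) -> torch.FloatTensor:
    # Build the vocabulary index on the same device as the scores.
    vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
    # If suppress_tokens was created on device 0 while scores live on
    # device 1, torch.isin raises the RuntimeError quoted above; moving
    # the tokens to scores.device first keeps both operands together.
    mask = torch.isin(vocab_tensor, suppress_tokens.to(scores.device))
    return torch.where(mask, -float("inf"), scores)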

         return scores
10 changes: 5 additions & 5 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1104,7 +1104,7 @@ def test_whisper_language(self):
     def test_speculative_decoding_whisper_non_distil(self):
         # Load data:
         dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data

         # Load model:
         model_id = "openai/whisper-large-v2"
@@ -1133,8 +1133,8 @@ def test_speculative_decoding_whisper_non_distil(self):
             num_beams=1,
         )

-        transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
-        transcription_ass = pipe(sample)["text"]
+        transcription_ass = pipe(sample.clone().detach(), generate_kwargs={"assistant_model": assistant_model})["text"]
+        transcription_non_ass = pipe(sample)["text"]
Comment on lines +1136 to +1137
Contributor commented:
thanks for catching the incorrect inversion here!
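
For background, a sketch of why the copy idiom on these lines changed. This rests on an assumption: with torchcodec-backed audio decoding in datasets, dataset[0]["audio"].get_all_samples().data yields a torch.Tensor rather than a NumPy array, so NumPy's .copy() is replaced by the tensor idiom .clone().detach().

import torch

# `sample` stands in for the decoded audio tensor used by the tests.
sample = torch.randn(16000)
independent = sample.clone().detach()  # tensor equivalent of NumPy's .copy()
# The clone owns separate storage, so mutating it leaves `sample` intact.
assert independent.data_ptr() != sample.data_ptr()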


         self.assertEqual(transcription_ass, transcription_non_ass)
         self.assertEqual(
@@ -1422,13 +1422,13 @@ def test_whisper_prompted(self):
         )

         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data

         # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
         whisper_prompt = "Mr. Quillter."
         prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)

-        unprompted_result = pipe(sample.copy())["text"]
+        unprompted_result = pipe(sample.clone().detach())["text"]
         prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]

         # fmt: off