
Mistral with Flash Attention v2 raises an error on long sequence input with max_new_tokens #27682


Description

@binarycrayon

System Info


WSL on Windows 11

Hardware

RTX 3090 Ti

Software

Python: 3.10.13
Transformers: 4.35

Input Sequence Length

An input of 15,901 characters throws the error.
An input of 13,266 characters works fine.

Model Initialization

import torch
from transformers import AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    load_in_8bit=True,
    device_map="cuda:0",
    use_flash_attention_2=True,
)

Who can help?

@younesbelkada @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    use_flash_attention_2=True,
)
# model.config.pad_token_id = tokenizer.pad_token_id

text = "<long text here>"

inputs = tokenizer(text, return_tensors="pt")
# print(model.device)  # Check model device
inputs = {k: v.to(model.device) for k, v in inputs.items()}

output = model.generate(**inputs, max_new_tokens=128)
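
Since the input lengths above are measured in characters while Mistral's sliding-window logic operates on tokens, it may help to log the tokenized prompt length before calling model.generate. This diagnostic is my addition (not part of the original report) and reuses the variables from the snippet above:

# Diagnostic sketch: the sliding-window check counts tokens, not characters.
num_tokens = inputs["input_ids"].shape[-1]
print(f"prompt length: {num_tokens} tokens")
print(f"model sliding window: {model.config.sliding_window} tokens")  # 4096 for Mistral-7B-v0.1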

Expected behavior

I expected inference to work. Instead, generation raises the error below when the input has 15,901 characters, while it works fine when the input has 13,266 characters.
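
As a rough sanity check (a ballpark assumption of about 3-4 characters per token for English text, not a measurement): 15,901 characters is on the order of 4,000-5,000 tokens, which would cross Mistral-7B-v0.1's 4,096-token sliding window, while 13,266 characters can plausibly stay under it. That would explain why only the longer input reaches the sliding-window cache path that raises below.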

Issue

ValueError                                Traceback (most recent call last)
/notebooks/Untitled.ipynb Cell 15 line 8

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/transformers/generation/utils.py:1754, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1737     return self.assisted_decoding(
   1738         input_ids,
   1739         assistant_model=assistant_model,
   (...)
   1750         **model_kwargs,
   1751     )
   1752 if generation_mode == GenerationMode.GREEDY_SEARCH:
   1753     # 11. run greedy search
-> 1754     return self.greedy_search(
   1755         input_ids,
   1756         logits_processor=logits_processor,
   1757         stopping_criteria=stopping_criteria,
   1758         pad_token_id=generation_config.pad_token_id,
   1759         eos_token_id=generation_config.eos_token_id,
   1760         output_scores=generation_config.output_scores,
   1761         return_dict_in_generate=generation_config.return_dict_in_generate,
   1762         synced_gpus=synced_gpus,
   1763         streamer=streamer,
   1764         **model_kwargs,
   1765     )
   1767 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
   1768     if not model_kwargs["use_cache"]:

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/transformers/generation/utils.py:2615, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2612 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2614 # forward pass to get next token
-> 2615 outputs = self(
   2616     **model_inputs,
   2617     return_dict=True,
   2618     output_attentions=output_attentions,
   2619     output_hidden_states=output_hidden_states,
   2620 )
   2622 if synced_gpus and this_peer_finished:
   2623     continue  # don't waste resources running the code we don't need

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:1007, in MistralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1004 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1006 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1007 outputs = self.model(
   1008     input_ids=input_ids,
   1009     attention_mask=attention_mask,
   1010     position_ids=position_ids,
   1011     past_key_values=past_key_values,
   1012     inputs_embeds=inputs_embeds,
   1013     use_cache=use_cache,
   1014     output_attentions=output_attentions,
   1015     output_hidden_states=output_hidden_states,
   1016     return_dict=return_dict,
   1017 )
   1019 hidden_states = outputs[0]
   1020 logits = self.lm_head(hidden_states)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:895, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    885     layer_outputs = self._gradient_checkpointing_func(
    886         decoder_layer.__call__,
    887         hidden_states,
    (...)
    892         use_cache,
    893     )
    894 else:
--> 895     layer_outputs = decoder_layer(
    896         hidden_states,
    897         attention_mask=attention_mask,
    898         position_ids=position_ids,
    899         past_key_value=past_key_value,
    900         output_attentions=output_attentions,
    901         use_cache=use_cache,
    902     )
    904 hidden_states = layer_outputs[0]
    906 if use_cache:

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:624, in MistralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    621 hidden_states = self.input_layernorm(hidden_states)
    623 # Self Attention
--> 624 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    625     hidden_states=hidden_states,
    626     attention_mask=attention_mask,
    627     position_ids=position_ids,
    628     past_key_value=past_key_value,
    629     output_attentions=output_attentions,
    630     use_cache=use_cache,
    631 )
    632 hidden_states = residual + hidden_states
    634 # Fully Connected

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/dev/alignment-handbook/CondaENV/env/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:376, in MistralFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    373     past_value = past_value[:, :, slicing_tokens:, :].contiguous()
    375 if past_key.shape[-2] != self.config.sliding_window - 1:
--> 376     raise ValueError(
    377         f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
    378         f" {past_key.shape}"
    379     )
    381 past_key_value = (past_key, past_value)
    383 if attention_mask is not None:
ValueError: past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got torch.Size([1, 8, 3628, 128])
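
For context, the raise comes from the sliding-window cache-slicing branch of MistralFlashAttention2.forward shown in the last frame. Below is a minimal standalone sketch of that check (an illustration, not the library's exact code), assuming the slice keeps the last sliding_window - 1 cached positions as the error message implies, and plugging in the shapes from this report:

import torch

# Shapes from the report: batch_size=1, num_kv_heads=8, 3628 cached positions,
# head_dim=128; Mistral-7B-v0.1 configures sliding_window=4096.
sliding_window = 4096
past_key = torch.zeros(1, 8, 3628, 128)

# The slice keeps the last sliding_window - 1 positions. A negative slice is a
# no-op when the cached dimension is already shorter than sliding_window - 1,
# so past_key stays at 3628 entries here.
past_key = past_key[:, :, -(sliding_window - 1):, :].contiguous()

# The check then expects exactly sliding_window - 1 = 4095 positions and fails
# with the same shape seen in the traceback above.
if past_key.shape[-2] != sliding_window - 1:
    raise ValueError(
        f"past key must have a shape of (`batch_size, num_heads, "
        f"self.config.sliding_window-1, head_dim`), got {past_key.shape}"
    )

In other words, the branch assumes the cache already holds at least sliding_window - 1 entries whenever the sequence length exceeds the window; the reported 3628-entry cache violates that assumption.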
