-
Notifications
You must be signed in to change notification settings - Fork 31.2k
Fix mistral generate for long prompt / response #27548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Hey @lorabit110 👋 I'm not sure if I follow the need for a fix here :) On my end, the test case you added passes: from transformers import AutoModelForCausalLM
import torch
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
input_ids = [1] + [306, 338] * 2048
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16, use_flash_attention_2=True)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
print(input_ids.shape)
generated_ids = model.generate(input_ids, max_new_tokens=2)
print(generated_ids[0][-2:].tolist() == EXPECTED_OUTPUT_TOKEN_IDS)
# TrueCan you share a short reproducible script that results in the failure? |
Actually, we need to set max_new_tokens to at least 3 to trigger the error. The below code can reproduce the issue: from transformers import AutoModelForCausalLM
import torch
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16, use_flash_attention_2=True)
input_ids = [1] + [306, 338] * 2048
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4)
print(generated_ids[0][-2:].tolist() == EXPECTED_OUTPUT_TOKEN_IDS)
|
|
Can anyone take a look? |
gante
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lorabit110 I see, it makes sense!
I've added a comment to expand your solution to a more general one. I'm also tagging @younesbelkada (who added the FA2 code) for a quick double-check :)
| # Activate slicing cache only if the config has a value `sliding_windows` attribute | ||
| if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: | ||
| slicing_tokens = kv_seq_len - self.config.sliding_window | ||
| slicing_tokens = 1 - self.config.sliding_window |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of 1, we should place query_length (= query_states.shape[-2]) here.
Query length is almost always 1 in autoregressive generation, corresponding to the latest token. However, in some advanced applications like assisted generation, it can be larger than 1. Adding the generation solution here would be nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added assisted generation test case. It seems to work. let me know if we still want to update this to query_length.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot, I had a deep look at this PR. Note before #25242 the generation with large sequence length was working because prepare_inputs_for_generation was always slicing the input_ids to consider the last token only. i.e. input_ids = input_ids[:, -1]
With the changes of #25242 being merged, it changes the assumption we made on mistral modeling code about 1D tokens being passed in case one uses cache.
I can also confirm the generations looks correct after this change (input of ~4800 tokens):
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
text = """
# LoRA
This conceptual guide gives a brief overview of LoRA, a technique that accelerates the fine-tuning of large models while consuming less memory.
To make fine-tuning more efficient, LoRA’s approach is to represent the weight updates with two smaller matrices (called update matrices) through low-rank decomposition. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn’t receive any further adjustments. To produce the final results, both the original and the adapted weights are combined.
This approach has a number of advantages:
LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters.
The original pre-trained weights are kept frozen, which means you can have multiple lightweight and portable LoRA models for various downstream tasks built on top of them.
LoRA is orthogonal to many other parameter-efficient methods and can be combined with many of them.
Performance of models fine-tuned using LoRA is comparable to the performance of fully fine-tuned models.
LoRA does not add any inference latency because adapter weights can be merged with the base model.
In principle, LoRA can be applied to any subset of weight matrices in a neural network to reduce the number of trainable parameters. However, for simplicity and further parameter efficiency, in Transformer models LoRA is typically applied to attention blocks only. The resulting number of trainable parameters in a LoRA model depends on the size of the low-rank update matrices, which is determined mainly by the rank r and the shape of the original weight matrix.
## Merge LoRA weights into the base model
While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA model. To eliminate latency, use the merge_and_unload() function to merge the adapter weights with the base model which allows you to effectively use the newly merged model as a standalone model.
This works because during training, the smaller weight matrices (A and B in the diagram above) are separate. But once training is complete, the weights can actually be merged into a new weight matrix that is identical.
## Utils for LoRA
Use merge_adapter() to merge the LoRa layers into the base model while retaining the PeftModel. This will help in later unmerging, deleting, loading different adapters and so on.
Use unmerge_adapter() to unmerge the LoRa layers from the base model while retaining the PeftModel. This will help in later merging, deleting, loading different adapters and so on.
Use unload() to get back the base model without the merging of the active lora modules. This will help when you want to get back the pretrained base model in some applications when you want to reset the model to its original state. For example, in Stable Diffusion WebUi, when the user wants to infer with base model post trying out LoRAs.
Use delete_adapter() to delete an existing adapter.
Use add_weighted_adapter() to combine multiple LoRAs into a new adapter based on the user provided weighing scheme.
## Common LoRA parameters in PEFT
As with other methods supported by PEFT, to fine-tune a model using LoRA, you need to:
Instantiate a base model.
Create a configuration (LoraConfig) where you define LoRA-specific parameters.
Wrap the base model with get_peft_model() to get a trainable PeftModel.
Train the PeftModel as you normally would train the base model.
LoraConfig allows you to control how LoRA is applied to the base model through the following parameters:
r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters.
target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
alpha: LoRA scaling factor.
bias: Specifies if the bias parameters should be trained. Can be 'none', 'all' or 'lora_only'.
modules_to_save: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model’s custom head that is randomly initialized for the fine-tuning task.
layers_to_transform: List of layers to be transformed by LoRA. If not specified, all layers in target_modules are transformed.
layers_pattern: Pattern to match layer names in target_modules, if layers_to_transform is specified. By default PeftModel will look at common layer pattern (layers, h, blocks, etc.), use it for exotic and custom models.
rank_pattern: The mapping from layer names or regexp expression to ranks which are different from the default rank specified by r.
alpha_pattern: The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by lora_alpha.
## LoRA examples
For an example of LoRA method application to various downstream tasks, please refer to the following guides:
- Image classification using LoRA
- Semantic segmentation
While the original paper focuses on language models, the technique can be applied to any dense layers in deep learning models. As such, you can leverage this technique with diffusion models. See Dreambooth fine-tuning with LoRA task guide for an example.
""" * 4
text = text + """
# Conclusion
"""
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", low_cpu_mem_usage=True, torch_dtype=torch.float16, use_flash_attention_2=True).to(0)
inputs = tokenizer(text, return_tensors="pt").to(0)
print(inputs.input_ids.shape)
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> In this guide, we have seen how to apply LoRA method to various downstream tasks, please refer to the LoRA method to various downstream tasks ...Note the changes #25242 also broke this 1D cache tokens (i.e. key_seq_len = 1) assumptions that are being made in other libraries such as autoawq: casper-hansen/AutoAWQ#146 so I wonder if we shouldn't think of a more global solution to revert back to pass key_seq_len = 1 for cache
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For more context, inside the logic in prepare_inputs_for_generation:
if past_key_values:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1If one passes a very large context (let's say 4800) we fall back into the condition
if input_ids.shape[1] > past_length:And past_length is always equal to 4096 for mistral, therefore the new sliced input_ids shape will be equal to batch_size, 724 which breaks the assumptions we made in the flash attention layer about input_ids having a key_seq_len being equal to 1 for cache.
The fix is correct thanks @lorabit110 , but flagging it just in case we can think of a more global solution !
|
@younesbelkada I see. I'm missing one variable here, which is how many tokens the model has seen so far through its cache -- does any of the inputs to Without it, I'm not sure how to properly slice |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks this is more resilient!
Would say let's add a test for @gante's example (assisted decoding) but the cache should always be sliced to have a max length of window_size, negative index makes sense to me.
|
Let's make sure the models are quantized for testing |
|
Thanks everyone for the discussion. I will need to spend sometime to figure out how to implement assisted decoding test case. |
|
All comments have been addressed. I need a maintainer to approve and run a workflow in order to merge the PR. |
younesbelkada
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making the test more memory efficient!
@ArthurZucker @gante I think we can merge no? It also fixes #27682
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
|
Thanks @lorabit110 for the fix 😉 |

What does this PR do?
Fix the below issue:
When use mistral model to generate texts, if prompt + max_tokens > 4095 and use_cache=True, you would get the below error.
ValueError: past key much have a shape of (
batch_size, num_heads, self.config.sliding_window-1, head_dim), got torch.Size([1, 8, 2989, 128]).This PR fixes the logic that determine which part of the cached key and value should be used for predicting future tokens.
Fixes #27682
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Bam4d