
Conversation

@lorabit110
Contributor

@lorabit110 lorabit110 commented Nov 16, 2023

What does this PR do?

Fix the following issue:
When using the Mistral model to generate text, if prompt + max_tokens > 4095 and use_cache=True, you get the error below.
ValueError: past key must have a shape of (batch_size, num_heads, self.config.sliding_window-1, head_dim), got torch.Size([1, 8, 2989, 128]).

This PR fixes the logic that determines which part of the cached keys and values should be used when predicting future tokens.

Fixes #27682
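
For context, here is a minimal standalone sketch of the sliding-window cache slicing this change touches. The shapes and the sliding_window value mirror Mistral's defaults, but the snippet is purely illustrative and not the exact modeling code:

import torch

# Toy cache tensor with shape (batch_size, num_heads, kv_seq_len, head_dim).
sliding_window = 4096
past_key = torch.randn(1, 8, 5000, 128)

# The flash attention path expects the cached key/value to hold at most
# sliding_window - 1 past tokens. Slicing with a negative start index keeps
# the most recent entries no matter how long the cache has grown:
sliced_key = past_key[:, :, 1 - sliding_window:, :].contiguous()
print(sliced_key.shape)  # torch.Size([1, 8, 4095, 128])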

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Bam4d

@amyeroberts
Contributor

cc @gante @ArthurZucker

@gante
Contributor

gante commented Nov 17, 2023

Hey @lorabit110 👋

I'm not sure if I follow the need for a fix here :) On my end, the test case you added passes:

from transformers import AutoModelForCausalLM
import torch

EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]

input_ids = [1] + [306, 338] * 2048
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16, use_flash_attention_2=True)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
print(input_ids.shape)
generated_ids = model.generate(input_ids, max_new_tokens=2)
print(generated_ids[0][-2:].tolist() == EXPECTED_OUTPUT_TOKEN_IDS)
# True

Can you share a short reproducible script that results in the failure?

@lorabit110
Contributor Author

lorabit110 commented Nov 17, 2023


@gante

Actually, we need to set max_new_tokens to at least 3 to trigger the error. The code below reproduces the issue:

from transformers import AutoModelForCausalLM
import torch

EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16, use_flash_attention_2=True)
input_ids = [1] + [306, 338] * 2048
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4)
print(generated_ids[0][-2:].tolist() == EXPECTED_OUTPUT_TOKEN_IDS)
[Screenshot attached in the original comment: error output from running the script above]

@lorabit110
Contributor Author

Can anyone take a look?

Contributor

@gante gante left a comment


@lorabit110 I see, it makes sense!

I've added a comment to expand your solution to a more general one. I'm also tagging @younesbelkada (who added the FA2 code) for a quick double-check :)

  # Activate slicing cache only if the config has a value `sliding_windows` attribute
  if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
-     slicing_tokens = kv_seq_len - self.config.sliding_window
+     slicing_tokens = 1 - self.config.sliding_window
Copy link
Contributor

@gante gante Nov 21, 2023


Instead of 1, we should use query_length (= query_states.shape[-2]) here.

The query length is almost always 1 in autoregressive generation, corresponding to the latest token. However, in some advanced applications like assisted generation, it can be larger than 1. Adding the general solution here would be nice!
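
As a toy illustration of that suggestion (illustrative shapes only, not necessarily the change that was merged), slicing by query_length instead of 1 would look like:

import torch

sliding_window = 4096
query_length = 3  # > 1 happens e.g. with assisted generation
past_key = torch.randn(1, 8, 5000, 128)

# Keep only the most recent (sliding_window - query_length) cached entries,
# leaving room for the query_length new tokens being processed.
slicing_tokens = query_length - sliding_window
sliced_key = past_key[:, :, slicing_tokens:, :].contiguous()
print(sliced_key.shape)  # torch.Size([1, 8, 4093, 128])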

Contributor Author


Added an assisted generation test case. It seems to work. Let me know if we still want to update this to query_length.
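
For reference, assisted generation is triggered in transformers by passing an assistant_model to generate. The sketch below shows the general shape of such a call; the choice of assistant checkpoint and the token lengths are placeholders, not the exact test added in this PR:

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16, use_flash_attention_2=True
)
# Reusing the same checkpoint as its own assistant keeps the sketch short;
# a real test would typically use a smaller draft model.
assistant = model

input_ids = torch.tensor([[1] + [306, 338] * 2048]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, assistant_model=assistant, max_new_tokens=4)
print(generated_ids[0][-4:].tolist())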

Contributor

@younesbelkada younesbelkada left a comment


Thanks a lot, I had a deep look at this PR. Note that before #25242, generation with large sequence lengths worked because prepare_inputs_for_generation always sliced input_ids to consider only the last token, i.e. input_ids = input_ids[:, -1].
With the changes from #25242 merged, the assumption made in the Mistral modeling code that only the last token is passed when the cache is used no longer holds.

I can also confirm the generations look correct after this change (input of ~4800 tokens):

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

text = """
# LoRA

This conceptual guide gives a brief overview of LoRA, a technique that accelerates the fine-tuning of large models while consuming less memory.

To make fine-tuning more efficient, LoRA’s approach is to represent the weight updates with two smaller matrices (called update matrices) through low-rank decomposition. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn’t receive any further adjustments. To produce the final results, both the original and the adapted weights are combined.

This approach has a number of advantages:

LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters.
The original pre-trained weights are kept frozen, which means you can have multiple lightweight and portable LoRA models for various downstream tasks built on top of them.
LoRA is orthogonal to many other parameter-efficient methods and can be combined with many of them.
Performance of models fine-tuned using LoRA is comparable to the performance of fully fine-tuned models.
LoRA does not add any inference latency because adapter weights can be merged with the base model.
In principle, LoRA can be applied to any subset of weight matrices in a neural network to reduce the number of trainable parameters. However, for simplicity and further parameter efficiency, in Transformer models LoRA is typically applied to attention blocks only. The resulting number of trainable parameters in a LoRA model depends on the size of the low-rank update matrices, which is determined mainly by the rank r and the shape of the original weight matrix.

## Merge LoRA weights into the base model

While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA model. To eliminate latency, use the merge_and_unload() function to merge the adapter weights with the base model which allows you to effectively use the newly merged model as a standalone model.
This works because during training, the smaller weight matrices (A and B in the diagram above) are separate. But once training is complete, the weights can actually be merged into a new weight matrix that is identical.

## Utils for LoRA
Use merge_adapter() to merge the LoRa layers into the base model while retaining the PeftModel. This will help in later unmerging, deleting, loading different adapters and so on.

Use unmerge_adapter() to unmerge the LoRa layers from the base model while retaining the PeftModel. This will help in later merging, deleting, loading different adapters and so on.

Use unload() to get back the base model without the merging of the active lora modules. This will help when you want to get back the pretrained base model in some applications when you want to reset the model to its original state. For example, in Stable Diffusion WebUi, when the user wants to infer with base model post trying out LoRAs.

Use delete_adapter() to delete an existing adapter.

Use add_weighted_adapter() to combine multiple LoRAs into a new adapter based on the user provided weighing scheme.

## Common LoRA parameters in PEFT
As with other methods supported by PEFT, to fine-tune a model using LoRA, you need to:

Instantiate a base model.
Create a configuration (LoraConfig) where you define LoRA-specific parameters.
Wrap the base model with get_peft_model() to get a trainable PeftModel.
Train the PeftModel as you normally would train the base model.
LoraConfig allows you to control how LoRA is applied to the base model through the following parameters:

r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters.
target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
alpha: LoRA scaling factor.
bias: Specifies if the bias parameters should be trained. Can be 'none', 'all' or 'lora_only'.
modules_to_save: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model’s custom head that is randomly initialized for the fine-tuning task.
layers_to_transform: List of layers to be transformed by LoRA. If not specified, all layers in target_modules are transformed.
layers_pattern: Pattern to match layer names in target_modules, if layers_to_transform is specified. By default PeftModel will look at common layer pattern (layers, h, blocks, etc.), use it for exotic and custom models.
rank_pattern: The mapping from layer names or regexp expression to ranks which are different from the default rank specified by r.
alpha_pattern: The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by lora_alpha.
## LoRA examples
For an example of LoRA method application to various downstream tasks, please refer to the following guides:

- Image classification using LoRA
- Semantic segmentation
While the original paper focuses on language models, the technique can be applied to any dense layers in deep learning models. As such, you can leverage this technique with diffusion models. See Dreambooth fine-tuning with LoRA task guide for an example.
""" * 4

text = text + """
# Conclusion

"""

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", low_cpu_mem_usage=True, torch_dtype=torch.float16, use_flash_attention_2=True).to(0)

inputs = tokenizer(text, return_tensors="pt").to(0)
print(inputs.input_ids.shape)

outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> In this guide, we have seen how to apply LoRA method to various downstream tasks, please refer to the LoRA method to various downstream tasks ...

Note that the changes in #25242 also broke the 1D cache token (i.e. key_seq_len = 1) assumption made in other libraries such as AutoAWQ (casper-hansen/AutoAWQ#146), so I wonder whether we should think of a more global solution that reverts to passing key_seq_len = 1 when the cache is used.

Contributor

@younesbelkada younesbelkada left a comment


For more context, here is the logic inside prepare_inputs_for_generation:

        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

If one passes a very large context (let's say ~4800 tokens), we fall into the condition

if input_ids.shape[1] > past_length:

and past_length is always equal to 4096 for Mistral, so the newly sliced input_ids has shape (batch_size, 724), which breaks the assumption made in the flash attention layer that input_ids has a key_seq_len equal to 1 when the cache is used.
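
A minimal numeric sketch of that slicing (the 4820-token input length is chosen simply to reproduce the 724 figure above):

import torch

past_length = 4096  # per the comment above, always 4096 for Mistral in this scenario
input_ids = torch.ones(1, 4820, dtype=torch.long)

# Mirrors the branch in prepare_inputs_for_generation quoted above.
if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    # Default to old behavior: keep only final ID
    remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
print(input_ids.shape)  # torch.Size([1, 724]) -- more than one new token reaches the attention layer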

The fix is correct, thanks @lorabit110, but I'm flagging it just in case we can think of a more global solution!

@gante
Contributor

gante commented Nov 21, 2023

@younesbelkada I see. I'm missing one variable here, which is how many tokens the model has seen so far through its cache -- do any of the inputs to prepare_inputs_for_generation carry this information?

Without it, I'm not sure how to properly slice input_ids without baking in the assumption that only one token can be consumed at a time (which would disable the use of Mistral with assisted generation and the newly added ability to pass past_key_values across generations).

@gante gante requested a review from ArthurZucker November 21, 2023 13:49
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks, this is more resilient!
I would say let's add a test for @gante's example (assisted decoding), but the cache should always be sliced to a maximum length of window_size, so the negative index makes sense to me.

@ArthurZucker
Collaborator

Let's make sure the models are quantized for testing

@lorabit110
Contributor Author

Thanks everyone for the discussion. I will need to spend some time figuring out how to implement the assisted decoding test case.

@lorabit110
Contributor Author

All comments have been addressed. I need a maintainer to approve and run a workflow in order to merge the PR.

Contributor

@younesbelkada younesbelkada left a comment


Thanks for making the test more memory efficient!
@ArthurZucker @gante I think we can merge, no? It also fixes #27682.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker ArthurZucker merged commit b09912c into huggingface:main Nov 27, 2023
@ArthurZucker
Collaborator

Thanks @lorabit110 for the fix 😉
