
Conversation

@younesbelkada (Contributor) commented May 31, 2024

What does this PR do?

Fixes: #30523

Click to see the snippet (make sure to run `accelerate config` and select FSDP options beforehand, then run `accelerate launch script.py`)
from functools import partial
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

# verify we have FSDP activation support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

model_id = "HuggingFaceM4/tiny-random-Llama3ForCausalLM"

model = AutoModelForCausalLM.from_pretrained(model_id)

model.train()
model.gradient_checkpointing_enable()

accelerator = Accelerator()
model = accelerator.prepare(model)

# only wrap the decoder layers with activation checkpointing
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

print(model)

# dummy input ids on GPU 0
rand_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)

model(rand_input)

#30743 introduced a breaking change for users that combine Llama-based models with FSDP and activation checkpointing.

Before #30743, we were able to pass arbitrary kwargs to Llama modules and they were silently ignored. When doing FSDP + activation checkpointing, the target gradient checkpointing classes are wrapped in a new class, and additional kwargs are passed along to that class's forward pass.

The script above used to work for transformers <= 4.40.0 and does not work anymore due to #30743; re-introducing **kwargs in the forward method signature fixes the bug.
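
For illustration, here is a minimal sketch of the kind of signature change described above (the class and argument names are indicative assumptions, not the exact diff in this PR):

import torch

# Sketch only: a decoder-layer-style module whose forward accepts and ignores
# extra keyword arguments, so wrapper classes (such as the FSDP activation
# checkpointing wrapper) can inject additional kwargs without triggering a
# TypeError about unexpected keyword arguments.
class DecoderLayerSketch(torch.nn.Module):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        **kwargs,  # extra kwargs passed by wrappers are accepted and ignored
    ):
        # ... attention / MLP computation would go here ...
        return hidden_states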

cc @amyeroberts

@younesbelkada younesbelkada requested a review from LysandreJik May 31, 2024 10:37
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Contributor) left a comment

Thanks for fixing and apologies for breaking this!

Some questions before we can merge

  • Would it make sense to add a test to make sure we don't accidentally break this again?
  • Having **kwargs in the forward method isn't standard amongst transformers models. Is there something special about these models which need this for FSDP? If not, should we be adding to other models?
  • Is there an alternative to using this injection? Relying on kwargs being passed isn't ideal

@younesbelkada (Contributor, Author)

Thanks!

> Would it make sense to add a test to make sure we don't accidentally break this again?

Yes, I'll add a test in this PR to test this behavior and catch bugs in the future!
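
A rough sketch of what such a regression test could look like (the test name and structure here are hypothetical, not the exact test added in this PR). It mirrors the reproduction snippet above; to exercise the real failing path it would still need to run under `accelerate launch` with an FSDP config:

from functools import partial

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def test_llama_activation_checkpointing_forward():
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/tiny-random-Llama3ForCausalLM"
    )
    model.train()
    model.gradient_checkpointing_enable()

    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, LlamaDecoderLayer),
    )

    # Before the fix, the wrapped decoder layers could raise a TypeError about
    # unexpected keyword arguments; the forward pass should now succeed.
    model(torch.LongTensor([[0, 1, 0, 1]]))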

> Having **kwargs in the forward method isn't standard amongst transformers models. Is there something special about these models which need this for FSDP? If not, should we be adding to other models?

Yes, agreed - I think we should add it to all 'most-used' models. FSDP is useful for large models, so I would say we should add it for LLMs (llama, gemma, mistral, mixtral, gpt-neo, etc.) to make things consistent. Happy to do that within this PR!

> Is there an alternative to using this injection? Relying on kwargs being passed isn't ideal

I am not sure; this seems to be something internal to FSDP + CPU offloading, and I don't think we can find a workaround for it :/ Since it used to work before, I think it should keep working in future transformers versions to ensure backward compatibility. What do you think?

@amyeroberts (Contributor)

> Yes, I'll add a test in this PR to test this behavior and catch bugs in the future!
>
> Yes, agreed - I think we should add it to all 'most-used' models. FSDP is useful for large models, so I would say we should add it for LLMs (llama, gemma, mistral, mixtral, gpt-neo, etc.) to make things consistent. Happy to do that within this PR!

Awesome - thank you!

> I am not sure; this seems to be something internal to FSDP + CPU offloading, and I don't think we can find a workaround for it :/ Since it used to work before, I think it should keep working in future transformers versions to ensure backward compatibility. What do you think?

Makes sense - let's leave as-is :)

@amyeroberts (Contributor) commented Jun 26, 2024

@younesbelkada I'm really sorry I missed the re-request for review. I don't have permissions to make changes to this branch, so I copied it here: #31638 and synced it with main. I couldn't push while working locally, but I could make changes through the web editor.

@amyeroberts (Contributor) left a comment

Thanks for fixing @younesbelkada, and apologies for the delay in reviewing.

I was able to make the necessary updates to resolve conflicts with main through the online editor. As this just involved merging in the new input argument, it didn't affect the structure of the PR. I did remove the testing_utils scripts (which I would have asked you to remove in a review :) )

@amyeroberts amyeroberts merged commit 3f93fd0 into main Jun 26, 2024
@amyeroberts amyeroberts deleted the fix-llama-fsdp branch June 26, 2024 13:50
@muellerzr muellerzr mentioned this pull request Mar 6, 2025
Successfully merging this pull request may close these issues.

Llama Attention Call should not pass **kwargs
