In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) #25242
Conversation
| @sinking-point the PR has "WIP" in the title -- is it still under development, or is it ready to review? | 
| Not ready yet. Still have to fix more models and see what's breaking the other test. I've deprioritised this somewhat as it's quite time consuming, but I'll keep chipping away at it whenever I can. If you need this done quickly, you're welcome to help - lmk and I'll add you as a collaborator on my branch. | 
| Not urgent -- simply double-checking whether it was in need of a review or not :) | 
(force-pushed bffb27b to bcad9c7)
    | @gante This should be ready for review now. Thanks in advance. | 
bcad9c7    to
    a41bf7c      
    Compare
  
@sinking-point this is a great piece of work, thank you so much for working on this 🙏
For me, it's a green light on the logic of the PR 🟢 I've added a few comments to improve on the readability of the changes (and one performance-related comment), so that we can fulfill our role as maintainers on top of your work at super-human levels 🤗
| @ArthurZucker @LysandreJik This PR is a big one and touches a core piece of logic for all generative models, so I'm tagging 2 core maintainers.

Context: Advanced generation techniques (like assisted generation or medusa) may generate more than one token per model forward pass. The original implementation of assisted generation had a lot of custom code, as it breaks one of the assumptions in the models' `prepare_inputs_for_generation`.

Solution: @sinking-point has kindly put forward a proposal to get rid of the custom code in assisted generation. After iterating with me, the plan was to remove the assumption of one token per forward pass.

Postface: To reiterate: this PR gets the green light from me in terms of logic 🟢, and it is a big contribution by @sinking-point. This PR is also important to future-proof our generative techniques -- we will be ready for new types of multiple-token-per-forward-pass strategies as a result of this PR. I'll be off the next few weeks, but I'm sure this PR will get a quick resolution 🤗 | 
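For readers new to the technique mentioned above: assisted generation drafts several candidate tokens with a cheap model, then verifies them all against the main model, which is why more than one token can be produced per forward pass. Below is a toy, self-contained sketch of that accept/reject loop. Both "models" are invented arithmetic stand-ins (not real networks), and the verification is simulated token by token, whereas a real model would score all candidate positions in a single parallel forward pass.

```python
def toy_model_next(seq):
    # Stand-in for the expensive main model's greedy next-token choice.
    return sum(seq) % 5

def toy_assistant_next(seq):
    # Stand-in for a cheap draft model; it sometimes disagrees with the
    # main model (whenever the current length is a multiple of 4).
    return (sum(seq) + (1 if len(seq) % 4 == 0 else 0)) % 5

def assisted_step(seq, k=4):
    # 1) Draft k candidate tokens with the cheap model.
    candidates = []
    draft = list(seq)
    for _ in range(k):
        t = toy_assistant_next(draft)
        candidates.append(t)
        draft.append(t)
    # 2) Verify candidates with the main model. A real implementation does
    # this in one forward pass over all candidate positions at once.
    accepted = []
    prefix = list(seq)
    for t in candidates:
        if toy_model_next(prefix) == t:
            accepted.append(t)          # candidate matches: keep it
            prefix.append(t)
        else:
            accepted.append(toy_model_next(prefix))  # first mismatch: take
            break                                    # the main model's token
    else:
        accepted.append(toy_model_next(prefix))      # all matched: bonus token
    return seq + accepted

print(assisted_step([1, 2, 3]))  # [1, 2, 3, 1, 2]
```

The output is identical to plain greedy decoding with the main model; the draft model only changes how many tokens each verification step can emit, which is the "more than one token per forward pass" property the PR has to accommodate.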
| Thanks @gante , I'll take a look at your comments tomorrow 👍 | 
| Hi @sinking-point! Sorry for the delay - I'm taking over this PR from @gante because he's out on a well-deserved rest right now. Is everything ready for review, or are there any other issues you want to discuss with the team before we take a final look at it? | 
| No worries @Rocketknight1 . Thanks for taking this on. There's one discussion gante opened that I haven't resolved. Could you give your input on this? #25242 (comment) | 
| @sinking-point Replied to the last remaining issue up above! | 
| Thanks @Rocketknight1 , I'll take a look on Monday | 
(force-pushed 72a343d to 8d1fbc2)
    | Some random tests started failing so I rebased onto main where they're fixed, but it looks like I have some more work to do now. | 
| Ick, yeah. I'm not sure what's causing those test failures, but if you can't figure it out, let me know and I'll dive in! | 
| Should be ready to merge if you're happy with it. Thanks! | 
| Looks like doc tests passed @Rocketknight1 , so as you said let's make this a priority before any more models are added. | 
| Understood! It's quite a big PR since it touches so many models, but I'll try to get an internal review in the next few days. | 
| Alternatively, could you require that new generative models' … | 
| Hi @Rocketknight1 , any update on this? | 
| Hey @sinking-point 👋 I'm back from holidays and I'll be doing a quick final check now. Assuming the check comes out positive, we'll tag a core maintainer to greenlight the merge. Our apologies for the slow process, it should be quick now 🤗 | 
LGTM, thank you for iterating 🙏
| @sinking-point regarding the failing test: rebasing the PR should fix it, the bug was fixed last week :) | 
| ping @LysandreJik -- this PR should be ready to be merged after it is rebased. Please read this comment for context :) | 
(force-pushed cbf75a3 to 8ce040d)
    | Thanks @gante :) | 
| This seems OK to me, but I'd like to ask @patrickvonplaten for his opinion and eventual approval, given his experience maintaining this part of the code. | 
Very nice! Code is cleaned up and made more extendable - very much in favor of this change
| Amazing, thank you @LysandreJik and @patrickvonplaten | 
| And thank you @sinking-point for this big contribution 🔥 💛 | 
What does this PR do?
Previously, assisted decoding ignored any additional kwargs that it didn't explicitly handle. This was inconsistent with other generation methods, which pass the `model_kwargs` through `prepare_inputs_for_generation` and forward the returned dict to the model's `forward` call.
The `prepare_inputs_for_generation` method needs to be amended in all models, as previously it only kept the last input ID when `past_key_values` was passed.
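To make the change concrete, here is a minimal, hypothetical sketch of the before/after behaviour of `prepare_inputs_for_generation`. It uses plain Python lists instead of tensors, and the cache format (a dict with a `seq_len` key) is invented for illustration; real models derive the cached length from the shapes inside `past_key_values`.

```python
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
    # Illustrative sketch only -- not the actual transformers implementation.
    if past_key_values is not None:
        past_length = past_key_values["seq_len"]  # tokens already in the cache
        # Old behaviour: keep only the last input ID, assuming exactly one
        # new token per forward pass:
        #     input_ids = input_ids[-1:]
        # New behaviour: keep every token not yet covered by the cache, so
        # assisted decoding can submit several candidate tokens at once.
        input_ids = input_ids[past_length:]
    return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}

# Assisted decoding appends 3 candidate tokens past a 4-token cache:
inputs = prepare_inputs_for_generation(
    [10, 11, 12, 13, 20, 21, 22], past_key_values={"seq_len": 4}
)
print(inputs["input_ids"])  # [20, 21, 22] -- all uncached tokens survive
```

Under the old slicing, only `[22]` would have reached the model, silently dropping two of the three candidate tokens; any extra kwargs are now also carried through in the returned dict.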
Fixes #25020
Who can review?
@gante