Efficient Transformers backend support #2858
Conversation
```diff
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
```
In general, I like to remove indirections.
Here, transformers_causal_lm_class is not known by the reader, who has to look up where it's defined, which makes following the flow of code hard.
We know which models support flex attention and which don't. We can hardcode the mapping CausalLM -> TransformersFlashCausalLM (as sketched below).
That removes the need to "guess" and the dependency on the private bit.
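For illustration, a minimal sketch of the hardcoded alternative suggested above; the model-type list and helper name are hypothetical, not TGI's actual dispatch code.

```python
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)

# Hypothetical allow-list of model types known to work with the new backend.
FLEX_ATTENTION_MODEL_TYPES = {"llama", "mistral", "gemma"}


def causal_lm_class_for(model_type: str):
    # Explicit mapping: no need to probe private transformers attributes.
    if model_type in FLEX_ATTENTION_MODEL_TYPES:
        return TransformersFlashCausalLM
    return CausalLM
```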
IMO the dynamic behavior is simpler, as we will roll out support for more and more models in transformers.
But it can obviously be changed if this is a blocker on your side 😁
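For contrast, a rough sketch of the dynamic check being defended here, assuming the transformers model class exposes a private capability flag (the "private bit" mentioned above); the attribute and helper names are assumptions, not the PR's code.

```python
import transformers

from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)


def resolve_causal_lm_class(config):
    # Look up the transformers class named in the config and check its
    # (assumed) flex-attention flag; anything else falls back to CausalLM.
    model_class = getattr(transformers, config.architectures[0], None)
    if model_class is not None and getattr(model_class, "_supports_flex_attn", False):
        return TransformersFlashCausalLM
    return CausalLM
```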
```python
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
```
No kwargs.
They are needed here to easily "absorb" whatever is passed internally in Transformers and not used in TGI's attention. I made the arguments we do use explicit, though.
Yes, then we can mark kwargs as _kwargs (just to say explicitly that there might be arguments we do not use).
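A small sketch of the resulting signature pattern: the arguments visible in the diff stay explicit, and everything else transformers passes internally is absorbed by a clearly named catch-all. The leading parameters and the function name are assumptions.

```python
from typing import Optional

import torch


def tgi_flash_attention_forward(
    module: torch.nn.Module,  # assumed: the attention module transformers passes in
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **_kwargs,  # absorbs transformers-internal arguments that TGI's attention does not use
):
    # ... call TGI's paged/flash attention kernel here ...
    ...
```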
server/text_generation_server/models/transformers_flash_causal_lm.py: two outdated review threads (resolved)
```python
        prefill_cache_indices,
        lm_head_indices,
    ):
        hidden_states = self.model.model.forward(
```
Is that consistent enough? I thought some models define self.transformer instead of self.model.
As of now, yes: all models supported by our refactors are consistent with that naming. However, I agree it is quite an ugly workaround, and I'll open a PR ASAP in Transformers to allow logit slicing with a Tensor (for now we only support int slicing with num_logits_to_keep).
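A minimal sketch of the current workaround, under the naming visible in the diff (self.model.model and self.model.lm_head); the surrounding arguments are illustrative.

```python
def _model_forward_sketch(self, input_ids, position_ids, lm_head_indices, **model_kwargs):
    # Run the inner transformer directly to get hidden states, bypassing
    # num_logits_to_keep (which currently only accepts an int, not a tensor).
    hidden_states = self.model.model.forward(
        input_ids=input_ids,
        position_ids=position_ids,
        **model_kwargs,
    )[0]
    if lm_head_indices is not None:
        # Keep only the positions whose logits are actually needed
        # (e.g. the last token of each prefill sequence).
        hidden_states = hidden_states[lm_head_indices]
    logits = self.model.lm_head(hidden_states)
    return logits
```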
Solved with huggingface/transformers#35757 in Transformers for a cleaner and more robust interface.
If it's consistent now, then by all means let's use that. We don't care about legacy transformers versions :)
```python
    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # NOTE: adapter_data: not supported
```
We need a hard fail at least for every config that would otherwise be silently ignored: speculation and adapter data (the checks might already exist).
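A minimal sketch of the kind of hard fail being requested; the attribute names used for the checks (adapter_data.data, self.speculator) are assumptions about TGI internals, and the real checks may already exist elsewhere.

```python
def forward(self, batch, adapter_data):
    # Fail loudly instead of silently ignoring unsupported features.
    if adapter_data is not None and getattr(adapter_data, "data", None):
        raise NotImplementedError(
            "LoRA adapters are not supported by the transformers backend yet"
        )
    if getattr(self, "speculator", None) is not None:
        raise NotImplementedError(
            "Speculative decoding is not supported by the transformers backend yet"
        )
    # ... normal forward pass ...
```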
```python
        return logits

    def forward(
```
It seems everything below is a copy of what was there before.
What are the key differences, if any?
The calls to self.model.forward() are replaced by self._model_forward().
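Schematically, the indirection looks like the sketch below; class bodies and arguments are simplified placeholders, not the PR's exact code.

```python
class FlashCausalLM:
    def forward(self, batch, adapter_data):
        # ... build input_ids, position_ids, lm_head_indices, etc. from the batch ...
        return self._model_forward(batch)

    def _model_forward(self, batch):
        # Native TGI flash models: call the model directly.
        return self.model.forward(batch)


class TransformersFlashCausalLM(FlashCausalLM):
    def _model_forward(self, batch):
        # Transformers-backed path: run the inner model and apply lm_head manually,
        # as in the earlier sketch; the rest of forward() stays a copy.
        ...
```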
What does this PR do?
Initial draft to support transformers as a (more efficient) backend in TGI. huggingface/transformers#35235 added support for a bunch of models already, and more will come progressively. However, I do need some guidance on how to best support multi-GPU setups 🤗
cc @OlivierDehaene @Narsil