
Conversation

@Cyrilvallez
Member

What does this PR do?

Initial draft to support transformers as a (more efficient) backend in TGI. huggingface/transformers#35235 added support for a bunch of models already, and more will come progressively.

However, I do need some guidance on how to best support multi-gpu setups 🤗

cc @OlivierDehaene @Narsil

@Cyrilvallez marked this pull request as ready for review January 15, 2025 18:07
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
Contributor

In general, I like to remove indirections.

Here, transformers_causal_lm_class is not known to the reader, who has to look up where it's defined, which makes following the flow of the code hard.

We know whether models support flex attention or not. We can hardcode the mapping: CausalLM -> TransformersFlashCausalLM.

That removes the need to "guess" and the dependency on the private bit.

Member Author

IMO the dynamic behavior is simpler, as we will roll out support for more and more models in transformers.
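
For illustration, a minimal sketch of what the dynamic selection could look like; the `_supports_flex_attn` attribute and the import paths are assumptions, not the actual TGI dispatch code:

```python
# Sketch only: pick the fallback class at runtime so newly supported
# transformers models work without touching the dispatch code.
# Import paths are illustrative.
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)


def pick_fallback(transformers_model_class):
    # Assumed private attribute on transformers model classes indicating
    # flex-attention support; this is the "private bit" mentioned above.
    if getattr(transformers_model_class, "_supports_flex_attn", False):
        return TransformersFlashCausalLM
    return CausalLM
```

Hardcoding, as suggested above, would instead map each supported architecture to its class explicitly.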

Member Author

But can obviously be changed if this is a blocker on your side 😁

softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs,
Contributor

No kwargs.

Member Author

They are needed here to easily "absorb" whatever is passed internally in Transformers and not used in TGI's attention. I made it more explicit for the arguments we do use, though.

Contributor

Yes, then we can mark kwargs as _kwargs (just trying to explicitly say there might be arguments we do not use)
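
A sketch of what the resulting signature could look like; the function name is illustrative, not necessarily the one used in the PR:

```python
from typing import Optional

import torch


def tgi_flash_attention_forward(  # illustrative name
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **_kwargs,  # absorbs kwargs transformers passes internally but TGI's attention does not use
) -> torch.Tensor:
    # actual attention dispatch elided
    ...
```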

prefill_cache_indices,
lm_head_indices,
):
hidden_states = self.model.model.forward(
Contributor

Is that consistent enough? I thought some models defined self.transformer instead of self.model.

Member Author

As of now, yes: all models supported by our refactor are consistent with that naming. However, I agree that it is quite an ugly workaround, and I'll open a PR ASAP in Transformers to allow logit slicing with a Tensor (for now we only support int slicing with num_logits_to_keep).

Member Author

Solved it with huggingface/transformers#35757 in Transformers for a cleaner and more robust interface.
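
A hypothetical sketch of the cleaner interface that change enables; the keyword name num_logits_to_keep comes from the earlier comment and is an assumption about the final transformers API:

```python
def _model_forward(self, input_ids, position_ids, lm_head_indices):
    # other arguments (kv cache, cu_seqlens, ...) omitted for brevity
    outputs = self.model(
        input_ids=input_ids,
        position_ids=position_ids,
        # let transformers slice the logits instead of reaching into
        # self.model.model and slicing hidden states manually
        num_logits_to_keep=lm_head_indices if lm_head_indices is not None else 0,
    )
    return outputs.logits
```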

Contributor

If it's consistent now, then by all means let's use that. We don't care about legacy transformers versions :)

def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# NOTE: adapter_data: not supported
Contributor

We need a hard fail at least for every config that would otherwise be silently ignored: speculation and adapter data (the checks might already exist).
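
An illustrative sketch of such a guard; the attribute name speculative_ids is an assumption rather than a verified TGI field:

```python
def forward(self, batch, adapter_data):
    # Hard fail instead of silently ignoring unsupported configurations.
    if adapter_data:
        raise NotImplementedError(
            "Adapters are not supported by the transformers backend"
        )
    if getattr(batch, "speculative_ids", None) is not None:
        raise NotImplementedError(
            "Speculative decoding is not supported by the transformers backend"
        )
    # normal forward path continues here
    ...
```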

return logits


def forward(
Contributor

It seems everything from here down is a copy of what was there before.

What are the key differences, if any?

Member Author

The calls to self.model.forward() are replaced by self._model_forward()
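
In other words, roughly this shape (simplified sketch; method bodies are illustrative, not the actual implementation):

```python
class FlashCausalLM:
    def _model_forward(self, **kwargs):
        # default path: TGI-native flash model
        return self.model.forward(**kwargs)

    def forward(self, batch, adapter_data):
        # shared pre/post-processing stays identical; only the model call
        # goes through the overridable indirection
        inputs = self._prepare_inputs(batch)  # hypothetical helper
        return self._model_forward(**inputs)


class TransformersFlashCausalLM(FlashCausalLM):
    def _model_forward(self, **kwargs):
        # transformers backend: call the Hugging Face model instead
        return self.model(**kwargs)
```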
