-
Notifications
You must be signed in to change notification settings - Fork 31.1k
[Docs] Improve PyTorch, Flax generate API #15988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Docs] Improve PyTorch, Flax generate API #15988
Conversation
|
No doc-builder triggered here? 😢 |
…trickvonplaten/transformers into move_generate_to_its_own_page
The docs are not updated on the link if the PR is changed (or it takes too long). Will build the docs locally now, but I think it makes it quite difficult for the community to add/change docs. |
|
The job updates the docs. Are they not up to date here? https://moon-ci-docs.huggingface.co/docs/transformers/pr_15988/en/main_classes/text_generation |
Nope |
sgugger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for expanding the doc!
|
|
||
| # Generation | ||
|
|
||
| The methods for auto-regressive text generation, namely [`~generation_utils.GenerationMixin.generate`] (for the PyTorch models), [`~generation_tf_utils.TFGenerationMixin.generate`] (for the TensorFlow models) and [`~generation_flax_utils.FlaxGenerationMixin.generate`] (for the Flax/JAX models), are implemented in [`~generation_utils.GenerationMixin`], [`~generation_tf_utils.TFGenerationMixin`] and [`~generation_flax_utils.FlaxGenerationMixin`] respectively. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think replacing the first paragraph with the suggestion below makes it easier for the user to map each generate to its GenerationMixin class :)
Each framework has a generate method for auto-regressive text generation implemented in their respective GenerationMixin class:
- PyTorch [
~generation_utils.GenerationMixin.generate] is implemented in [~generation_utils.GenerationMixin]. - TensorFlow [
~generation_tf_utils.TFGenerationMixin.generate] is implemented in [~generation_tf_utils.TFGenerationMixin]. - Flax/JAX [
~generation_flax_utils.FlaxGenerationMixin.generate] is implemented in [~generation_flax_utils.FlaxGenerationMixin].
| The `GenerationMixin` classes are inherited by the corresponding base model classes, *e.g.* [`PreTrainedModel`], [`TFPreTrainedModel`], and [`FlaxPreTrainedModel`] respectively, therefore exposing all | ||
| methods for auto-regressive text generation to every model class. | ||
|
|
||
| ## GenerationMixn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small typo for each of the GenerationMixin classes: GenerationMixin, TFGenerationMixin, FlaxGenerationMixin.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correcting it here: #16133 (comment) . Thanks!
What does this PR do?
This PR is the first step to make
generatea 1st class citizen in the docs. It improves the generate API for PyTorch and Flax generate, improves the examples for PyTorch and adds PyTorch to the example doc tests.Once the TF generate refactor is complete - it's API can also be improved with better examples.
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.