[generate] Run custom generation code from the Hub #36405
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
| > [!TIP] | ||
| > `transformers` can also load custom generation strategies from the Hub! See the `Custom decoding methods` section below for instructions on how to load or share a custom generation strategy. | ||
| ## Basic decoding methods |
Reorganized the generation strategies docs to make it easier to find the docs regarding custom generation methods.
The outer level is now basic decoding methods (moved greedy decoding, sampling, and beam search here), advanced decoding methods (moved the other ones here), and custom decoding methods (new section for the code on the Hub)
Oh cool! I think it could be even more helpful if we add under "Basic decoding methods" a few sentences about when you should use these methods, and the same for "Advanced decoding methods" to help users further differentiate between basic and advanced.
Wow super cool! Feels like we can load more and more custom things from the Hub into Transformers these days 🚀
| ## Custom decoding methods |
The new docs start here
| except OSError: | ||
| logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") | ||
| logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") |
We were throwing an exception and a warning 🤔
Lowered the severity; otherwise we were throwing warnings in most model `from_pretrained` calls (see changes in modeling_utils.py)
| full_submodule_module_file_path = os.path.join(full_submodule, module_file) | ||
| create_dynamic_module(Path(full_submodule_module_file_path).parent) |
Previously: we could only load root level custom modules
With this change: we can load custom modules in any folder
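To illustrate what loading a custom module from a nested folder involves, here is a minimal sketch using only the standard library. It is not the PR's `create_dynamic_module` logic; `load_module_from_path` is a hypothetical helper name, and the sketch assumes the target file has no relative imports.

```python
import importlib.util
from pathlib import Path


def load_module_from_path(root, relative_file):
    """Import a Python file that lives in any subfolder of `root`.

    Hypothetical helper for illustration only; it assumes the file does not
    rely on relative imports or on sibling packages being importable.
    """
    path = Path(root) / relative_file
    # "custom_generate/generate.py" -> "custom_generate.generate"
    module_name = ".".join(Path(relative_file).with_suffix("").parts)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module
```

With a file at `<root>/custom_generate/generate.py`, calling `load_module_from_path(root, "custom_generate/generate.py")` would return a module object exposing whatever that file defines.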
nice!
| return trust_remote_code | ||
| def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs): |
This is the logic for the requirements check. It raises an exception like the one below:
ValueError: Missing requirements for joaogante/test_generate_from_hub_bad_requirements:
foo (installed: None)
bar==0.0.0 (installed: None)
torch>=99.0 (installed: 2.6.0+cu126)
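For readers who want to see how a report like the one above could be assembled, here is a rough sketch. It is not the PR's `check_python_requirements` implementation: `find_missing`, the injectable `lookup` argument, and the simplified numeric version parsing are all assumptions made for illustration, and only plain `name`, `name==X`, `name>=X` style lines are handled.

```python
import operator
import re
from importlib.metadata import PackageNotFoundError, version as installed_version

# Hypothetical helpers for illustration; not the transformers implementation.
_OPS = {"==": operator.eq, "!=": operator.ne, ">=": operator.ge,
        "<=": operator.le, ">": operator.gt, "<": operator.lt}
_REQ = re.compile(r"^([A-Za-z0-9_.-]+)\s*(==|!=|>=|<=|>|<)?\s*([0-9][0-9.]*)?$")


def _release(version_string):
    """Leading numeric release only: '2.6.0+cu126' -> (2, 6, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def _pad(left, right):
    """Right-pad both tuples with zeros so (1, 0) compares equal to (1, 0, 0)."""
    size = max(len(left), len(right))
    return left + (0,) * (size - len(left)), right + (0,) * (size - len(right))


def _default_lookup(name):
    """Installed version of `name`, or None if the package is absent."""
    try:
        return installed_version(name)
    except PackageNotFoundError:
        return None


def find_missing(requirements, lookup=_default_lookup):
    """Return one 'requirement (installed: X)' entry per unmet requirement."""
    missing = []
    for line in requirements:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blanks and comments
        match = _REQ.match(line)
        if match is None or (match.group(2) is None) != (match.group(3) is None):
            raise ValueError(f"Unsupported requirement line: {line}")
        name, op, wanted = match.groups()
        found = lookup(name)
        satisfied = found is not None and (
            op is None or _OPS[op](*_pad(_release(found), _release(wanted)))
        )
        if not satisfied:
            missing.append(f"{line} (installed: {found})")
    return missing
```

A caller could then raise `ValueError` with the joined entries when the returned list is non-empty. A real implementation would more likely rely on a proper version parser than on this numeric-tuple shortcut.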
This looks good to me!
As mentioned offline, let's make sure that it's only runnable with trust_remote_code=True.
Also, this was very recently merged and might make sense for your version comparisons:
transformers/src/transformers/utils/import_utils.py
Lines 2091 to 2146 in 23d79ce
class VersionComparison(Enum):
    EQUAL = operator.eq
    NOT_EQUAL = operator.ne
    GREATER_THAN = operator.gt
    LESS_THAN = operator.lt
    GREATER_THAN_OR_EQUAL = operator.ge
    LESS_THAN_OR_EQUAL = operator.le

    @staticmethod
    def from_string(version_string: str) -> "VersionComparison":
        string_to_operator = {
            "=": VersionComparison.EQUAL.value,
            "==": VersionComparison.EQUAL.value,
            "!=": VersionComparison.NOT_EQUAL.value,
            ">": VersionComparison.GREATER_THAN.value,
            "<": VersionComparison.LESS_THAN.value,
            ">=": VersionComparison.GREATER_THAN_OR_EQUAL.value,
            "<=": VersionComparison.LESS_THAN_OR_EQUAL.value,
        }

        return string_to_operator[version_string]


@lru_cache()
def split_package_version(package_version_str) -> Tuple[str, str, str]:
    pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
    match = re.match(pattern, package_version_str)
    if match:
        return (match.group(1), match.group(2), match.group(3))
    else:
        raise ValueError(f"Invalid package version string: {package_version_str}")


class Backend:
    def __init__(self, backend_requirement: str):
        self.package_name, self.version_comparison, self.version = split_package_version(backend_requirement)

        if self.package_name not in BACKENDS_MAPPING:
            raise ValueError(
                f"Backends should be defined in the BACKENDS_MAPPING. Offending backend: {self.package_name}"
            )

    def is_satisfied(self) -> bool:
        return VersionComparison.from_string(self.version_comparison)(
            version.parse(importlib.metadata.version(self.package_name)), version.parse(self.version)
        )

    def __repr__(self) -> str:
        return f'Backend("{self.package_name}", {VersionComparison[self.version_comparison]}, "{self.version}")'

    @property
    def error_message(self):
        return (
            f"{{0}} requires the {self.package_name} library version {self.version_comparison}{self.version}. That"
            f" library was not found with this version in your environment."
        )
cc @Rocketknight1 in case you want to give a review as the owner of remote code :)
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Manuel de Prada Corral <[email protected]>
Thanks! Hoping to see this enable other devs!
| trust_remote_code = kwargs.pop("trust_remote_code", None) | ||
| # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them: | ||
| # they receive the same inputs as `generate`, only with `model` instead of `self`. They can access to | ||
| # methods from `GenerationMixin` through `model`. | ||
| global_keys_to_exclude = {"self", "kwargs"} | ||
| generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude} | ||
| generate_arguments.update(kwargs) | ||
All of this can be done in `load_custom_generate`, which would return the generate arguments!
It needs locals() from generate to redirect non-keyword arguments :')
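A toy sketch of why `locals()` has to be read inside `generate` itself: only there are positionally passed values bound to their parameter names. `ToyModel` and its signature are made up for illustration, and, unlike in the PR where `custom_generate` is a Hub repo id, the custom function is passed in directly as a callable.

```python
class ToyModel:
    """Stand-in for a model class; not a transformers class."""

    def generate(self, inputs=None, max_new_tokens=None, custom_generate=None, **kwargs):
        # Snapshot the named parameters before any other local is defined, so
        # the snapshot holds the call arguments and nothing else.
        snapshot = dict(locals())
        generate_arguments = {
            key: value for key, value in snapshot.items() if key not in {"self", "kwargs"}
        }
        generate_arguments.update(kwargs)

        if custom_generate is not None:
            # The custom function gets the same inputs, with `model` in place of `self`.
            generate_arguments.pop("custom_generate")
            return custom_generate(model=self, **generate_arguments)
        return "default generation path"


def echo_generate(model, **generate_arguments):
    """Hypothetical custom function that returns what it was given."""
    return generate_arguments
```

Calling `ToyModel().generate([1, 2], 5, custom_generate=echo_generate, temperature=0.7)` forwards `inputs` and `max_new_tokens` by name even though they were passed positionally. A helper called from `generate` would only see them if `generate` collected and passed them along first.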
What does this PR do?
This PR expands our modeling code on the hub support: it enables the definition of custom decoding methods on the hub, sharing the same interface as `generate`.
How does it work?
Assumptions:
📖 custom decoding methods are defined in model repos
✍️ the custom generation code is defined in `custom_generate/generate.py`
🔎 these repos have a `custom_generate` tag so we can quickly find them
The custom generation code defined in the repo `"foo/bar"` can be accessed in two ways:
`model = AutoXXX.from_pretrained("foo/bar")`: `model.generate` will always call the custom generate defined in `"foo/bar"`;
`model`: we can specify the repo to load the custom generate from, i.e. `model.generate(..., custom_generate="foo/bar")`
(for more info, read the added documentation)
Expected benefits
🧠 Creators of new decoding methods:
`transformers` release;
`model.generate` interface, but much higher implementation freedom
🔨 `transformers` team:
Working Example
(generate.py in transformers-community/custom_generate_example)
Sanity checks, copy paste after the script above: