Conversation

@gante gante commented Feb 25, 2025

What does this PR do?

This PR expands our "modeling code on the Hub" support: it enables the definition of custom decoding methods on the Hub, sharing the same interface as `generate`.

How does it work?

Assumptions:
📖 custom decoding methods are defined in model repos
✍️ the custom generation code is defined in custom_generate/generate.py
🔎 these repos have a custom_generate tag so we can quickly find them

The custom generation code defined in the repo "foo/bar" can be accessed in two ways:

  1. If we load the model with model = AutoXXX.from_pretrained("foo/bar"), then model.generate will always call the custom generate defined in "foo/bar";
  2. For any other model, we can specify the repo to load the custom generate from, e.g. model.generate(..., custom_generate="foo/bar").

(for more info, read the added documentation)
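
To make the contract concrete, here is a minimal sketch of what `custom_generate/generate.py` could look like. Per this PR, the custom function receives the same inputs as `generate`, with `model` taking the place of `self`; the greedy loop below is only an illustrative placeholder, not the example repo's actual code:

```
# custom_generate/generate.py -- illustrative sketch
import torch


def generate(model, input_ids, max_new_tokens=20, **kwargs):
    # `model` replaces `self`, so `GenerationMixin` methods remain reachable through it
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids).logits
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```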

Expected benefits

🧠 Creators of new decoding methods:

  • No need to go through a lengthy PR review cycle;
  • Day 0 support of their techniques, no need to wait for a transformers release;
  • Same model.generate interface, but much higher implementation freedom
  • Users don't need to install a new Python repo in their environment
  • Hub features (custom readme, discussion pages, track downloads, ...)

🔨 transformers team:

  • No more gatekeeping (technique A gets added, technique B doesn't get added)
  • More generation methods, lower expected maintenance 🤞
  • Empower our community 💛

Working Example

(generate.py in transformers-community/custom_generate_example)

from transformers import AutoModelForCausalLM, AutoTokenizer

# Qwen/Qwen2.5-0.5B-Instruct copy, but with custom generation code -> `generate` uses the custom code
# note: calling the custom method prints "✨ using a custom generation method ✨"
tokenizer = AutoTokenizer.from_pretrained("transformers-community/custom_generate_example")
model = AutoModelForCausalLM.from_pretrained("transformers-community/custom_generate_example", device_map="auto", trust_remote_code=True)

inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))

# `generate` with `custom_generate` -> `generate` uses custom code
# note: calling the custom method prints "✨ using a custom generation method ✨"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")

inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))

Sanity checks (copy-paste after the script above):

# Sanity check: `transformers-community/custom_generate_example` contains a simplified greedy decoding method
custom_generated_text = tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]
gen_out = model.generate(**inputs, do_sample=False, repetition_penalty=1.0, max_length=20)
greedy_generated_text = tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]
assert custom_generated_text == greedy_generated_text

# Sanity check 2: `trust_remote_code` is required
try:
    model.generate(**inputs, custom_generate="transformers-community/custom_generate_example")
except Exception as e:
    print("got exception 👍")
else:
    raise Exception("This should have failed -> no `trust_remote_code`")

try:
    model = AutoModelForCausalLM.from_pretrained("transformers-community/custom_generate_example", device_map="auto")
except Exception as e:
    print("got exception 👍")
else:
    raise Exception("This should have failed -> no `trust_remote_code`")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante force-pushed the generate_from_hub branch from 67dd305 to 545e1d0 Compare March 27, 2025 15:25
@gante gante marked this pull request as ready for review March 27, 2025 18:30
> [!TIP]
> `transformers` can also load custom generation strategies from the Hub! See the `Custom decoding methods` section below for instructions on how to load or share a custom generation strategy.
## Basic decoding methods
Member Author

Reorganized the generation strategies docs to make it easier to find the docs regarding custom generation methods.

The outer level is now split into basic decoding methods (greedy decoding, sampling, and beam search moved here), advanced decoding methods (the remaining ones moved here), and custom decoding methods (a new section for the code on the Hub)

Member

Oh cool! I think it could be even more helpful if we add under "Basic decoding methods" a few sentences about when you should use these methods, and the same for "Advanced decoding methods" to help users further differentiate between basic and advanced.

@gante gante requested review from ArthurZucker and LysandreJik and removed request for Rocketknight1 and stevhliu March 27, 2025 18:45
@stevhliu stevhliu left a comment (Member)

Wow super cool! Feels like we can load more and more custom things from the Hub into Transformers these days 🚀

## Custom decoding methods
Member Author

The new docs start here


    except OSError:
-       logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+       logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
Member Author

We were throwing an exception and a warning 🤔

Lowered the severity, otherwise we were emitting warnings in most models' from_pretrained calls (see changes in modeling_utils.py)

Comment on lines +422 to +423
full_submodule_module_file_path = os.path.join(full_submodule, module_file)
create_dynamic_module(Path(full_submodule_module_file_path).parent)
Member Author

Previously: we could only load root level custom modules
With this change: we can load custom modules in any folder
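
For illustration, this is the kind of repo layout that becomes loadable (folder and file names follow the assumptions stated in the PR description; previously only root-level custom modules could be resolved):

```
foo/bar/
├── config.json
└── custom_generate/
    ├── generate.py
    └── requirements.txt
```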

Member

nice!

return trust_remote_code


def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
Member Author

Logic for the requirements exception, which produces errors like the one below:

ValueError: Missing requirements for joaogante/test_generate_from_hub_bad_requirements:
foo (installed: None)
bar==0.0.0 (installed: None)
torch>=99.0 (installed: 2.6.0+cu126)
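
A minimal sketch of how such a check could work, using `importlib.metadata` plus `packaging` (the function name and message format come from this thread; the body below is an assumption, not the PR's actual implementation):

```
from importlib.metadata import PackageNotFoundError, version as installed_version

from packaging.requirements import Requirement


def check_python_requirements(path_or_repo_id, requirements):
    # Collect every requirement that is missing or installed at a non-matching version
    failures = []
    for line in requirements:
        req = Requirement(line)
        try:
            installed = installed_version(req.name)
        except PackageNotFoundError:
            installed = None
        if installed is None or not req.specifier.contains(installed, prereleases=True):
            failures.append(f"{line} (installed: {installed})")
    if failures:
        raise ValueError(f"Missing requirements for {path_or_repo_id}:\n" + "\n".join(failures))
```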

@LysandreJik LysandreJik left a comment (Member)

This looks good to me!

As mentioned offline, let's make sure that it's only runnable with trust_remote_code=True.

Also, this was very recently merged and might make sense for your version comparisons:

class VersionComparison(Enum):
    EQUAL = operator.eq
    NOT_EQUAL = operator.ne
    GREATER_THAN = operator.gt
    LESS_THAN = operator.lt
    GREATER_THAN_OR_EQUAL = operator.ge
    LESS_THAN_OR_EQUAL = operator.le

    @staticmethod
    def from_string(version_string: str) -> "VersionComparison":
        string_to_operator = {
            "=": VersionComparison.EQUAL.value,
            "==": VersionComparison.EQUAL.value,
            "!=": VersionComparison.NOT_EQUAL.value,
            ">": VersionComparison.GREATER_THAN.value,
            "<": VersionComparison.LESS_THAN.value,
            ">=": VersionComparison.GREATER_THAN_OR_EQUAL.value,
            "<=": VersionComparison.LESS_THAN_OR_EQUAL.value,
        }
        return string_to_operator[version_string]


@lru_cache()
def split_package_version(package_version_str) -> Tuple[str, str, str]:
    pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
    match = re.match(pattern, package_version_str)
    if match:
        return (match.group(1), match.group(2), match.group(3))
    else:
        raise ValueError(f"Invalid package version string: {package_version_str}")


class Backend:
    def __init__(self, backend_requirement: str):
        self.package_name, self.version_comparison, self.version = split_package_version(backend_requirement)

        if self.package_name not in BACKENDS_MAPPING:
            raise ValueError(
                f"Backends should be defined in the BACKENDS_MAPPING. Offending backend: {self.package_name}"
            )

    def is_satisfied(self) -> bool:
        return VersionComparison.from_string(self.version_comparison)(
            version.parse(importlib.metadata.version(self.package_name)), version.parse(self.version)
        )

    def __repr__(self) -> str:
        return f'Backend("{self.package_name}", {VersionComparison[self.version_comparison]}, "{self.version}")'

    @property
    def error_message(self):
        return (
            f"{{0}} requires the {self.package_name} library version {self.version_comparison}{self.version}. That"
            f" library was not found with this version in your environment."
        )
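
For instance, assuming the package is registered in `BACKENDS_MAPPING`, a requirement string could then be checked like this (hypothetical usage):

```
backend = Backend("torch>=2.6.0")
print(backend.is_satisfied())  # True if the installed torch matches ">=2.6.0"
```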

cc @Rocketknight1 in case you want to give a review as the owner of remote code :)


@gante gante force-pushed the generate_from_hub branch from 6337231 to 88d7b58 Compare May 12, 2025 16:43
@ArthurZucker ArthurZucker left a comment (Collaborator)

Thanks! Hoping to see this enable other devs!

Comment on lines +2332 to +2339
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
# they receive the same inputs as `generate`, only with `model` instead of `self`. They can access
# methods from `GenerationMixin` through `model`.
global_keys_to_exclude = {"self", "kwargs"}
generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
generate_arguments.update(kwargs)

Collaborator

All of this can be done in load_custom_generate, which would return the generate arguments!

Member Author

It needs locals() from generate to redirect non-keyword arguments :')
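
A toy illustration of the constraint (names are hypothetical): the `locals()` call has to run inside `generate` itself, because only there are the named parameters in scope regardless of how the caller passed them.

```
def generate(self, inputs=None, generation_config=None, **kwargs):
    # Evaluated here, `locals()` captures `inputs` and `generation_config` even
    # when they were passed positionally; a helper called from `generate` would
    # only see whatever was explicitly forwarded to it.
    arguments = {k: v for k, v in locals().items() if k not in {"self", "kwargs"}}
    arguments.update(kwargs)
    return arguments
```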

@gante gante merged commit 0e0e5c1 into huggingface:main May 15, 2025
20 checks passed
@gante gante deleted the generate_from_hub branch May 15, 2025 09:35