🚀 Feature request
"Disjunctive Positive Constraint Decoding" Algorithm proposed by a recent paper: Guided Generation of Cause and Effect
Currently the model.generate() method limits the user to the highest-probability outputs. With this feature, the user would be able to force diverse outputs by requiring the model to include specified tokens across multiple generations.
This method is called "Disjunctive Positive Constraint Decoding", and it forces the model.generate() process to produce the highest-probability sequences under the constraint that a set of provided tokens must appear in the output.
This "disjunctive" method is powerful in that it can handle lemmatizing these forced tokens. For instance, when asking the model to autoregressively generate the completion tokens from "Babies cry because" and want to force the generation to include the word "lonely", it can induce the model to generate sequences like "Babies cry because they are lonely", as well as "Babies cry because of their loneliness".
I think that this could be implemented as:
model.generate(force_tokens=[["lonely", "loneliness", ...], ["happy", "happiness", ...]])
where force_tokens is a 2D array in which each inner list contains the different surface forms of a desired token.
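As a rough sketch of what the library could do with that 2D array, the preprocessing step might look like the following, using only the existing tokenizer API (the force_tokens argument itself is hypothetical; each inner word list becomes one disjunctive set of token-ID sequences):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# The nested word lists the user would pass to generate() ...
force_tokens = [["lonely", "loneliness"], ["happy", "happiness"]]

# ... could be resolved internally into one disjunctive set of
# token-ID sequences per inner list (illustrative preprocessing only):
force_token_ids = [
    [tokenizer(word, add_special_tokens=False).input_ids for word in words]
    for words in force_tokens
]
print(force_token_ids)

# Caveat: for BPE tokenizers like GPT-2's, "lonely" and " lonely"
# (with a leading space) map to different token IDs, so a real
# implementation would need to decide how to handle word boundaries.
```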
Otherwise it could be:
model.generate(force_tokens=["lonely", "happy"])
but in this case the transformers library would need a built-in lemmatization engine, which I think is better left for the practitioner to figure out (i.e. go with the first method instead).
Motivation
Diversity in the outputs of LM sequence generation has always been an active problem. Although diversity-inducing methods have usually involved a dual implementation with a VAE or a modified training scheme, this feature would allow the practitioner to induce diversity in a very controllable way.
A large pretrained language model probably has the capacity to generate many different expressions for a given input, but generation is usually limited to the most probable outputs. One obvious solution is sampling instead, but that brings back the problem of controllability. This level of control is extremely useful in model implementations that aim to learn syntactic transformations that must preserve certain entities, or in QA verbalizers where we have pre-existing knowledge of what the answer should be.
Instead of generating many sequences and filtering for the desired ones, this would allow us to force the model to generate an output we want, which a large LM can probably do well; even if it can't, we can simply filter out the low-probability outputs based on a threshold.
Your contribution
I'm happy to submit a PR for the full implementation if there aren't any reasons to object to this feature.
But I do believe I should get some feedback on this idea before proceeding with an implementation, since it's not entirely clear what the best way to introduce this functionality to the library would be.