Add t5 to pipeline(task='summarization') #3413
```diff
@@ -31,6 +31,7 @@
 from .configuration_bart import BartConfig
 from .configuration_distilbert import DistilBertConfig
 from .configuration_roberta import RobertaConfig
+from .configuration_t5 import T5Config
 from .configuration_utils import PretrainedConfig
 from .configuration_xlm import XLMConfig
 from .data import SquadExample, squad_convert_examples_to_features
```
```diff
@@ -60,7 +61,6 @@
     AutoModelForTokenClassification,
     AutoModelWithLMHead,
 )
-from .modeling_bart import BartForConditionalGeneration


 logger = logging.getLogger(__name__)
```
```diff
@@ -336,6 +336,7 @@ def __init__(
         tokenizer: PreTrainedTokenizer,
         modelcard: Optional[ModelCard] = None,
         framework: Optional[str] = None,
+        task: str = "",
         args_parser: ArgumentHandler = None,
         device: int = -1,
         binary_output: bool = False,
```
```diff
@@ -356,6 +357,11 @@ def __init__(
         if self.framework == "pt" and self.device.type == "cuda":
             self.model = self.model.to(self.device)

+        # Update config with task specific parameters
+        task_specific_params = self.model.config.task_specific_params
+        if task_specific_params is not None and task in task_specific_params:
+            self.model.config.update(task_specific_params.get(task))
+
     def save_pretrained(self, save_directory):
         """
         Save the pipeline's model and tokenizer to the specified save_directory
```
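To see what this override does in isolation, here is a minimal runnable sketch. The `SimpleConfig` class is a hypothetical stand-in for the real `PretrainedConfig`; only the last four lines mirror the diff:

```python
# Hypothetical stand-in for PretrainedConfig; real configs also expose
# task_specific_params and an update() method.
class SimpleConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def update(self, params: dict):
        # Same effect as the config.update(...) call in the diff:
        # copy the task-specific values onto the config.
        for key, value in params.items():
            setattr(self, key, value)


config = SimpleConfig(
    max_length=20,
    task_specific_params={"summarization": {"max_length": 200, "prefix": "summarize: "}},
)

task = "summarization"
task_specific_params = config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
    config.update(task_specific_params.get(task))

assert config.max_length == 200 and config.prefix == "summarize: "
```

This is why every pipeline constructor below now threads a `task` string through to the base class: without it, the config cannot know which entry of `task_specific_params` to apply.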
```diff
@@ -420,7 +426,7 @@ def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
         """
         args = ["input_ids", "attention_mask"]

-        if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
+        if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)):
             args += ["token_type_ids"]

         # PR #1548 (CLI) There is an issue with attention_mask
```
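A quick sketch of what this filter changes, with hypothetical stand-in config classes (only the `isinstance` pattern mirrors the diff): T5, like the other listed architectures, takes no `token_type_ids`, so that feature must not be forwarded to the model.

```python
# Hypothetical stand-ins; the real check lists DistilBert/XLM/RoBERTa/Bart/T5 configs.
class T5Config:
    pass


class BertConfig:
    pass


def inputs_for_model_args(config):
    args = ["input_ids", "attention_mask"]
    # Models whose config appears in the tuple take no token_type_ids.
    if not isinstance(config, (T5Config,)):
        args += ["token_type_ids"]
    return args


assert "token_type_ids" not in inputs_for_model_args(T5Config())
assert "token_type_ids" in inputs_for_model_args(BertConfig())
```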
```diff
@@ -432,14 +438,18 @@ def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
         else:
             return {k: [feature[k] for feature in features] for k in args}

-    def _parse_and_tokenize(self, *texts, **kwargs):
+    def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
         """
         Parse arguments and tokenize
         """
         # Parse arguments
         inputs = self._args_parser(*texts, **kwargs)
         inputs = self.tokenizer.batch_encode_plus(
-            inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
+            inputs,
+            add_special_tokens=True,
+            return_tensors=self.framework,
+            max_length=self.tokenizer.max_len,
+            pad_to_max_length=pad_to_max_length,
         )

         # Filter out features not available on specific models
```
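For reference, the new keyword matches `batch_encode_plus` as it existed in transformers of this era (where `tokenizer.max_len` was still the attribute name). A hedged usage sketch of the batch case that motivates the switch:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# A batch of documents must be padded so the encoded ids stack into one tensor;
# a single document needs no padding, hence the pad_to_max_length switch.
batch = ["summarize: first article ...", "summarize: second article ..."]
inputs = tokenizer.batch_encode_plus(
    batch,
    add_special_tokens=True,
    return_tensors="pt",
    max_length=tokenizer.max_len,
    pad_to_max_length=True,
)
print(inputs["input_ids"].shape)  # (2, tokenizer.max_len)
```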
```diff
@@ -520,6 +530,7 @@ def __init__(
         framework: Optional[str] = None,
         args_parser: ArgumentHandler = None,
         device: int = -1,
+        task: str = "",
     ):
         super().__init__(
             model=model,
```
```diff
@@ -529,6 +540,7 @@ def __init__(
             args_parser=args_parser,
             device=device,
             binary_output=True,
+            task=task,
         )

     def __call__(self, *args, **kwargs):
```
```diff
@@ -625,6 +637,7 @@ def __init__(
         args_parser: ArgumentHandler = None,
         device: int = -1,
         topk=5,
+        task: str = "",
     ):
         super().__init__(
             model=model,
```
```diff
@@ -634,6 +647,7 @@ def __init__(
             args_parser=args_parser,
             device=device,
             binary_output=True,
+            task=task,
         )

         self.topk = topk
```
```diff
@@ -725,6 +739,7 @@ def __init__(
         device: int = -1,
         binary_output: bool = False,
         ignore_labels=["O"],
+        task: str = "",
     ):
         super().__init__(
             model=model,
```
```diff
@@ -734,6 +749,7 @@ def __init__(
             args_parser=args_parser,
             device=device,
             binary_output=binary_output,
+            task=task,
         )

         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
```
```diff
@@ -896,6 +912,7 @@ def __init__(
         modelcard: Optional[ModelCard] = None,
         framework: Optional[str] = None,
         device: int = -1,
+        task: str = "",
         **kwargs
     ):
         super().__init__(
```
```diff
@@ -905,6 +922,7 @@ def __init__(
             framework=framework,
             args_parser=QuestionAnsweringArgumentHandler(),
             device=device,
+            task=task,
             **kwargs,
         )
```
```diff
@@ -1111,12 +1129,16 @@ class SummarizationPipeline(Pipeline):

     Usage::

+        # use bart in pytorch
         summarizer = pipeline("summarization")
-        summarizer("Sam Shleifer writes the best docstring examples in the whole world.")
+        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
+
+        # use t5 in tf
+        summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
+        summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

     Supported Models:
-        The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
-        currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``
+        The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is currently '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`' and '`t5-11b`'.

     Arguments:
         model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
```
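The PyTorch T5 variant works the same way as the TF example added to the docstring. A hedged, self-contained example combining the two usage lines (any of the listed T5 checkpoints should be substitutable for `t5-small`):

```python
from transformers import pipeline

# T5 in PyTorch; the framework is inferred from the installed backend.
summarizer = pipeline("summarization", model="t5-small", tokenizer="t5-small")
print(summarizer(
    "Sam Shleifer writes the best docstring examples in the whole world.",
    min_length=5,
    max_length=20,
))
```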
```diff
@@ -1147,28 +1169,15 @@ class SummarizationPipeline(Pipeline):
         on the associated CUDA device id.
     """

     task = "summarization"

     def __call__(
-        self,
-        *documents,
-        return_tensors=False,
-        return_text=True,
-        max_length=142,
-        min_length=21,
-        clean_up_tokenization_spaces=False,
-        **generate_kwargs
+        self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
         r"""
         Args:
             *documents: (list of strings) articles to be summarized
             return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
             return_tensors: (bool, default=False) whether to return the raw "summary_token_ids" to each result
-
-            max_length: (`optional`) int
-                The max length of the sequence to be generated. Does not include tokens in input_ids.
-            min_len: (`optional`) int
-            no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
             clean_up_tokenization_spaces: (`optional`) bool whether to include extra spaces in the output
             **generate_kwargs: extra kwargs passed to `self.model.generate`_
```
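With the hard-coded `max_length=142` / `min_length=21` defaults gone, length controls now come either from the task-specific config or from `**generate_kwargs`. Continuing the pipeline example above, a short sketch of what that looks like at the call site:

```python
# min_length / max_length are no longer pipeline defaults; pass them per call
# and they are forwarded verbatim to self.model.generate(...).
summary = summarizer(
    "A long article ...",
    min_length=5,            # forwarded to generate()
    max_length=20,           # forwarded to generate()
    no_repeat_ngram_size=3,  # any other generate() kwarg travels the same way
)
```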
```diff
@@ -1180,19 +1189,60 @@ def __call__(

         """
         assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
-        if self.framework == "tf":
-            raise NotImplementedError("Tensorflow not supported")
         assert len(documents) > 0, "Please provide a document to summarize"

+        if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
+            raise NotImplementedError(
+                "Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
+            )
+
+        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
+
+        if isinstance(documents[0], list):
+            assert (
+                self.tokenizer.pad_token_id is not None
+            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
+
+            documents = ([prefix + document for document in documents[0]],)
+            pad_to_max_length = True
+
+        elif isinstance(documents[0], str):
+            documents = (prefix + documents[0],)
+            pad_to_max_length = False
+        else:
+            raise ValueError(
+                "`documents[0]`: {} has the wrong format. It should be of type `str` or type `list`".format(
+                    documents[0]
+                )
+            )

         with self.device_placement():
-            inputs = self._parse_and_tokenize(*documents)
-            inputs = self.ensure_tensor_on_device(**inputs)
+            inputs = self._parse_and_tokenize(*documents, pad_to_max_length=pad_to_max_length)
+
+            if self.framework == "pt":
+                inputs = self.ensure_tensor_on_device(**inputs)
+                input_length = inputs["input_ids"].shape[-1]
+            elif self.framework == "tf":
+                input_length = tf.shape(inputs["input_ids"])[-1]
+
+            if input_length < self.model.config.min_length // 2:
+                logger.warning(
+                    "Your min_length is set to {}, but your input_length is only {}. You might consider decreasing min_length in the config or setting it manually".format(
+                        self.model.config.min_length, input_length
+                    )
+                )
+
+            if input_length < self.model.config.max_length:
+                logger.warning(
+                    "Your max_length is set to {}, but your input_length is only {}. You might consider decreasing max_length in the config or setting it manually".format(
+                        self.model.config.max_length, input_length
+                    )
+                )

             summaries = self.model.generate(
-                inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                max_length=max_length,
-                min_length=min_length,
-                do_sample=False,
-                **generate_kwargs,
+                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
             )

             results = []
             for summary in summaries:
                 record = {}
```
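To make the new input dispatch easier to follow, here is a standalone, runnable sketch of the prefix-and-padding logic above, detached from the pipeline class (the function name `prepare_documents` is ours, not from the diff):

```python
from typing import Tuple


def prepare_documents(documents: tuple, prefix: str) -> Tuple[tuple, bool]:
    """Hypothetical standalone version of the prefix/padding dispatch above."""
    if isinstance(documents[0], list):
        # Batch call, e.g. summarizer([["doc a", "doc b"]]): pad so tensors stack.
        return ([prefix + doc for doc in documents[0]],), True
    elif isinstance(documents[0], str):
        # Single call, e.g. summarizer("doc a"): no padding needed.
        return (prefix + documents[0],), False
    raise ValueError(
        "`documents[0]`: {} has the wrong format. It should be of type `str` or type `list`".format(documents[0])
    )


docs, pad = prepare_documents(("a short article",), "summarize: ")
assert docs == ("summarize: a short article",) and pad is False

docs, pad = prepare_documents((["doc a", "doc b"],), "summarize: ")
assert docs == (["summarize: doc a", "summarize: doc b"],) and pad is True
```

The `prefix` is what lets T5 reuse its text-to-text interface: its config supplies a string such as `"summarize: "` that gets prepended to every document, while Bart's config supplies none.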
```diff
@@ -1266,8 +1316,8 @@
         },
         "summarization": {
             "impl": SummarizationPipeline,
-            "pt": BartForConditionalGeneration if is_torch_available() else None,
-            "tf": None,
+            "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+            "pt": AutoModelWithLMHead if is_torch_available() else None,
             "default": {
                 "model": {"pt": "bart-large-cnn", "tf": None},
                 "config": None,
```
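After this change both frameworks resolve to an auto class instead of a hard-coded Bart class. A minimal sketch of how the registry lookup behaves; strings stand in for the real classes so the snippet runs without torch or tensorflow installed, and the dict literal is a trimmed-down illustration, not the full `SUPPORTED_TASKS` table:

```python
# Illustrative registry mirroring the "summarization" entry above.
SUPPORTED_TASKS = {
    "summarization": {
        "impl": "SummarizationPipeline",
        "pt": "AutoModelWithLMHead",
        "tf": "TFAutoModelWithLMHead",
        "default": {"model": {"pt": "bart-large-cnn", "tf": None}},
    }
}

framework = "pt"  # normally chosen by get_framework(model)
targeted_task = SUPPORTED_TASKS["summarization"]
task_class, model_class = targeted_task["impl"], targeted_task[framework]
print(task_class, model_class)  # SummarizationPipeline AutoModelWithLMHead
```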
```diff
@@ -1361,7 +1411,7 @@ def pipeline(
     framework = framework or get_framework(model)

     targeted_task = SUPPORTED_TASKS[task]
-    task, model_class = targeted_task["impl"], targeted_task[framework]
+    task_class, model_class = targeted_task["impl"], targeted_task[framework]

     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
```
```diff
@@ -1422,4 +1472,4 @@ def pipeline(
     )
     model = model_class.from_pretrained(model, config=config, **model_kwargs)

-    return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)
+    return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
```
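The rename matters because the factory previously overwrote the `task` string with the pipeline class; now the string survives and is forwarded as `task=task`, so a T5 pipeline can pick up its task-specific parameters on construction. A hedged end-to-end example (that the published `t5-base` config ships a `task_specific_params["summarization"]` entry with generation defaults and a `"summarize: "` prefix is an assumption about the hosted checkpoints):

```python
from transformers import pipeline

# pipeline() now forwards task="summarization" into Pipeline.__init__, which
# merges config.task_specific_params["summarization"] (assumed to hold the
# prefix and generation defaults) into the model config before the first call.
summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base")
print(summarizer("A long news article ...", min_length=5, max_length=20))
```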
> Ok, note that we can now remove `inputs_for_model` from pipelines since #3116. Let's do it in another PR cleaning up pipelines later.