11 changes: 11 additions & 0 deletions src/transformers/configuration_utils.py
@@ -380,3 +380,14 @@ def to_json_file(self, json_file_path):
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())

def update(self, config_dict: Dict):
"""
Updates attributes of this class with attributes from `config_dict`.

Args:
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
"""
for key, value in config_dict.items():
setattr(self, key, value)
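A minimal usage sketch of the new method (the attribute names below are illustrative, not taken from the PR):

    from transformers import PretrainedConfig

    config = PretrainedConfig()
    config.update({"min_length": 30, "max_length": 200})
    # update() simply setattr's each key onto the config instance
    assert config.min_length == 30 and config.max_length == 200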
4 changes: 3 additions & 1 deletion src/transformers/modeling_tf_utils.py
@@ -999,10 +999,12 @@ def _generate_beam_search(
# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
# create eos_token_id boolean mask
num_batch_hypotheses = batch_size * num_beams

is_token_logit_eos_token = tf.convert_to_tensor(
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])

scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
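Why the shape fix matters: during beam search, `scores` has one row per hypothesis (batch_size * num_beams), so the mask must broadcast to that many rows. A standalone sketch with illustrative sizes (note it uses `==` for the token comparison):

    import tensorflow as tf

    batch_size, num_beams, vocab_size, eos_token_id = 2, 3, 5, 4
    num_batch_hypotheses = batch_size * num_beams
    is_token_logit_eos_token = tf.convert_to_tensor(
        [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
    )
    mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
    print(mask.shape)  # (6, 5): one row per beam hypothesis, not (2, 5)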

120 changes: 85 additions & 35 deletions src/transformers/pipelines.py
@@ -31,6 +31,7 @@
from .configuration_bart import BartConfig
from .configuration_distilbert import DistilBertConfig
from .configuration_roberta import RobertaConfig
from .configuration_t5 import T5Config
from .configuration_utils import PretrainedConfig
from .configuration_xlm import XLMConfig
from .data import SquadExample, squad_convert_examples_to_features
@@ -60,7 +61,6 @@
AutoModelForTokenClassification,
AutoModelWithLMHead,
)
from .modeling_bart import BartForConditionalGeneration


logger = logging.getLogger(__name__)
@@ -336,6 +336,7 @@ def __init__(
tokenizer: PreTrainedTokenizer,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
task: str = "",
args_parser: ArgumentHandler = None,
device: int = -1,
binary_output: bool = False,
@@ -356,6 +357,11 @@ def __init__(
if self.framework == "pt" and self.device.type == "cuda":
self.model = self.model.to(self.device)

# Update config with task specific parameters
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))

def save_pretrained(self, save_directory):
"""
Save the pipeline's model and tokenizer to the specified save_directory
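For illustration, here is an assumed shape of `task_specific_params` and how the new hook applies it; the exact keys shipped in, e.g., the T5 config may differ:

    from transformers import PretrainedConfig

    config = PretrainedConfig()
    config.task_specific_params = {
        "summarization": {"min_length": 30, "max_length": 200, "prefix": "summarize: "}
    }

    task = "summarization"
    task_specific_params = config.task_specific_params
    if task_specific_params is not None and task in task_specific_params:
        # copies e.g. min_length/max_length/prefix onto the config itself
        config.update(task_specific_params.get(task))
    assert config.prefix == "summarize: "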
@@ -420,7 +426,7 @@ def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
"""
args = ["input_ids", "attention_mask"]

if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)):
Review comment (Member): Ok, note that we can now remove the inputs_for_model from pipelines since #3116. Let's do it in another PR cleaning up pipelines later.

args += ["token_type_ids"]

# PR #1548 (CLI) There is an issue with attention_mask
@@ -432,14 +438,18 @@ def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
else:
return {k: [feature[k] for feature in features] for k in args}

def _parse_and_tokenize(self, *texts, **kwargs):
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
"""
Parse arguments and tokenize
"""
# Parse arguments
inputs = self._args_parser(*texts, **kwargs)
inputs = self.tokenizer.batch_encode_plus(
inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
inputs,
add_special_tokens=True,
return_tensors=self.framework,
max_length=self.tokenizer.max_len,
pad_to_max_length=pad_to_max_length,
Review comment (Contributor, Author): Is there a reason why we would not optionally use pad_to_max_length here? Was not sure if I can add it, but for summarization when using a batched input, this is necessary. @LysandreJik @thomwolf @mfuntowicz

)

# Filter out features not available on specific models
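For context, a standalone sketch of what the new pad_to_max_length flag does during batch encoding (the model choice is illustrative, and the "pt" framework is assumed):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    batch = tokenizer.batch_encode_plus(
        ["a short document", "a noticeably longer document to summarize"],
        add_special_tokens=True,
        return_tensors="pt",
        max_length=tokenizer.max_len,
        pad_to_max_length=True,  # without padding, rows of unequal length cannot be stacked into one tensor
    )
    print(batch["input_ids"].shape)  # (2, padded_length)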
@@ -520,6 +530,7 @@ def __init__(
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
task: str = "",
):
super().__init__(
model=model,
@@ -529,6 +540,7 @@
args_parser=args_parser,
device=device,
binary_output=True,
task=task,
)

def __call__(self, *args, **kwargs):
@@ -625,6 +637,7 @@ def __init__(
args_parser: ArgumentHandler = None,
device: int = -1,
topk=5,
task: str = "",
):
super().__init__(
model=model,
@@ -634,6 +647,7 @@
args_parser=args_parser,
device=device,
binary_output=True,
task=task,
)

self.topk = topk
@@ -725,6 +739,7 @@ def __init__(
device: int = -1,
binary_output: bool = False,
ignore_labels=["O"],
task: str = "",
):
super().__init__(
model=model,
@@ -734,6 +749,7 @@
args_parser=args_parser,
device=device,
binary_output=binary_output,
task=task,
)

self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
@@ -896,6 +912,7 @@ def __init__(
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
device: int = -1,
task: str = "",
**kwargs
):
super().__init__(
Expand All @@ -905,6 +922,7 @@ def __init__(
framework=framework,
args_parser=QuestionAnsweringArgumentHandler(),
device=device,
task=task,
**kwargs,
)

@@ -1111,12 +1129,16 @@ class SummarizationPipeline(Pipeline):

Usage::

# use bart in pytorch
summarizer = pipeline("summarization")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

# use t5 in tf
summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)

Supported Models:
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``
The models that this pipeline can use are models that have been fine-tuned on a summarization task: currently '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`' and '`t5-11b`'.

Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
@@ -1147,28 +1169,15 @@ class SummarizationPipeline(Pipeline):
on the associated CUDA device id.
"""

task = "summarization"

def __call__(
self,
*documents,
return_tensors=False,
return_text=True,
max_length=142,
min_length=21,
clean_up_tokenization_spaces=False,
**generate_kwargs
self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):
r"""
Args:
*documents: (list of strings) articles to be summarized
return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
return_tensors: (bool, default=False) whether to return the raw "summary_token_ids" to each result

max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
min_len: (`optional`) int
no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
clean_up_tokenization_spaces: (`optional`) bool whether to include extra spaces in the output
**generate_kwargs: extra kwargs passed to `self.model.generate`_

@@ -1180,19 +1189,60 @@ def __call__(

"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
if self.framework == "tf":
raise NotImplementedError("Tensorflow not supported")
assert len(documents) > 0, "Please provide a document to summarize"

if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
raise NotImplementedError(
"Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
)

prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

if isinstance(documents[0], list):
assert (
self.tokenizer.pad_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"

documents = ([prefix + document for document in documents[0]],)
pad_to_max_length = True

elif isinstance(documents[0], str):
documents = (prefix + documents[0],)
pad_to_max_length = False
else:
raise ValueError(
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
documents[0]
)
)

with self.device_placement():
inputs = self._parse_and_tokenize(*documents)
inputs = self.ensure_tensor_on_device(**inputs)
inputs = self._parse_and_tokenize(*documents, pad_to_max_length=pad_to_max_length)

if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1]
Review comment on lines +1222 to +1226 (Member): Ok for now, but we'll probably refactor this in the future to have framework-agnostic ensure_tensor_on_device and get_tensor_length methods (maybe have a base class and framework-specific derived class, for instance).


if input_length < self.model.config.min_length // 2:
logger.warning(
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length in config and insert config manually".format(
self.model.config.min_length, input_length
)
)

if input_length < self.model.config.max_length:
logger.warning(
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length in config and insert config manually".format(
self.model.config.max_length, input_length
)
)

summaries = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length,
min_length=min_length,
do_sample=False,
**generate_kwargs,
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
)

results = []
for summary in summaries:
record = {}
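Usage sketch for the new batched path (the model choice is illustrative): a list input takes the batch branch above, so the T5 prefix is prepended to every document and the inputs are padded to a common length before generate().

    from transformers import pipeline

    summarizer = pipeline("summarization", model="t5-small", tokenizer="t5-small")
    outputs = summarizer(
        ["First article to summarize ...", "Second, much longer article ..."],
        min_length=5,
        max_length=20,
    )
    # min_length/max_length travel through **generate_kwargs to model.generate
    print([out["summary_text"] for out in outputs])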
Expand Down Expand Up @@ -1266,8 +1316,8 @@ def __call__(
},
"summarization": {
"impl": SummarizationPipeline,
"pt": BartForConditionalGeneration if is_torch_available() else None,
"tf": None,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "bart-large-cnn", "tf": None},
"config": None,
@@ -1361,7 +1411,7 @@ def pipeline(
framework = framework or get_framework(model)

targeted_task = SUPPORTED_TASKS[task]
task, model_class = targeted_task["impl"], targeted_task[framework]
task_class, model_class = targeted_task["impl"], targeted_task[framework]

# Use default model/config/tokenizer for the task if no model is provided
if model is None:
@@ -1422,4 +1472,4 @@ def pipeline(
)
model = model_class.from_pretrained(model, config=config, **model_kwargs)

return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
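A self-contained sketch of the name shadowing this rename fixes (the dict below is a stand-in for SUPPORTED_TASKS, for illustration only):

    SUPPORTED = {"summarization": {"impl": "SummarizationPipeline"}}

    # before: rebinding `task` to the implementation lost the task string,
    # so it could not be forwarded to the pipeline constructor
    task = "summarization"
    task = SUPPORTED[task]["impl"]

    # after: two names, so the string survives and can be passed as task=task
    task = "summarization"
    task_class = SUPPORTED[task]["impl"]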
23 changes: 19 additions & 4 deletions tests/test_pipelines.py
@@ -78,6 +78,9 @@
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]

SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}


class MonoColumnInputTestCase(unittest.TestCase):
def _test_mono_column_pipeline(
@@ -252,10 +255,22 @@ def test_summarization(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
nlp = pipeline(task="summarization")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
for model, tokenizer in SUMMARIZATION_FINETUNED_MODELS:
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)

@require_tf
def test_tf_summarization(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
for model, tokenizer in TF_SUMMARIZATION_FINETUNED_MODELS:
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)


class MultiColumnInputTestCase(unittest.TestCase):