Merged
30 commits
4853a36  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
bf4d84b  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
5162c60  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
43c5327  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
c0f72b6  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
c43866f  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
d02ef0b  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
e11fdb0  remove unused patterns (patrickvonplaten, Mar 26, 2020)
661e32b  solve conflicts (patrickvonplaten, Mar 26, 2020)
aeb5483  add t5 summarization example (patrickvonplaten, Mar 24, 2020)
8f82962  fix conflicts (patrickvonplaten, Mar 25, 2020)
ca2ad3f  change parameters for t5 summarization (patrickvonplaten, Mar 24, 2020)
82b6833  make style (patrickvonplaten, Mar 24, 2020)
35b7348  fix conflicts (patrickvonplaten, Mar 25, 2020)
c0692ba  only add prefixes (patrickvonplaten, Mar 25, 2020)
c560985  add prefix patterns (patrickvonplaten, Mar 25, 2020)
e2c139b  make style (patrickvonplaten, Mar 25, 2020)
8710134  fix conflicts (patrickvonplaten, Mar 25, 2020)
8b15478  renaming (patrickvonplaten, Mar 25, 2020)
a03a368  add first code snippet for translation (patrickvonplaten, Mar 25, 2020)
2566f62  fix merge conflicts (patrickvonplaten, Mar 26, 2020)
c86921f  remove translation example (patrickvonplaten, Mar 26, 2020)
f7fefaa  remove summarization example (patrickvonplaten, Mar 26, 2020)
c0c840c  make sure tensors are in numpy for float comparsion (patrickvonplaten, Mar 26, 2020)
bfe770e  re-add t5 config (patrickvonplaten, Mar 26, 2020)
46bd060  fix t5 import config typo (patrickvonplaten, Mar 26, 2020)
0956a76  make style (patrickvonplaten, Mar 26, 2020)
43dd428  remove unused numpy statements (patrickvonplaten, Mar 26, 2020)
bc8287e  update doctstring (patrickvonplaten, Mar 26, 2020)
a5160a6  import translation pipeline (patrickvonplaten, Mar 26, 2020)
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -116,6 +116,7 @@
SummarizationPipeline,
TextClassificationPipeline,
TokenClassificationPipeline,
+ TranslationPipeline,
pipeline,
)
from .tokenization_albert import AlbertTokenizer
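This single added export makes TranslationPipeline part of the library's public namespace. A minimal sketch of what it enables, assuming the translation defaults registered in the pipelines.py diff below:

from transformers import TranslationPipeline, pipeline

# The class is normally constructed through the pipeline() factory
# rather than instantiated directly.
translator = pipeline("translation_en_to_de")
assert isinstance(translator, TranslationPipeline)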
168 changes: 158 additions & 10 deletions src/transformers/pipelines.py
@@ -130,7 +130,9 @@ class PipelineDataFormat:

SUPPORTED_FORMATS = ["json", "csv", "pipe"]

- def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+ def __init__(
+     self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
+ ):
self.output_path = output_path
self.input_path = input_path
self.column = column.split(",") if column is not None else [""]
@@ -176,7 +178,7 @@ def save_binary(self, data: Union[dict, List[dict]]) -> str:

@staticmethod
def from_str(
- format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False
+ format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
):
if format == "json":
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
@@ -189,7 +191,9 @@ def from_str(


class CsvPipelineDataFormat(PipelineDataFormat):
- def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+ def __init__(
+     self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
+ ):
super().__init__(output_path, input_path, column, overwrite=overwrite)

def __iter__(self):
@@ -210,7 +214,9 @@ def save(self, data: List[dict]):


class JsonPipelineDataFormat(PipelineDataFormat):
- def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+ def __init__(
+     self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
+ ):
super().__init__(output_path, input_path, column, overwrite=overwrite)

with open(input_path, "r") as f:
@@ -1120,7 +1126,11 @@ def span_to_answer(self, text: str, start: int, end: int):
chars_idx += len(word) + 1

# Join text with spaces
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
return {
"answer": " ".join(words),
"start": max(0, char_start_idx),
"end": min(len(text), char_end_idx),
}


class SummarizationPipeline(Pipeline):
@@ -1223,18 +1233,18 @@ def __call__(
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
- input_length = tf.shape(inputs["input_ids"])[-1]
+ input_length = tf.shape(inputs["input_ids"])[-1].numpy()

if input_length < self.model.config.min_length // 2:
logger.warning(
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length in config and insert config manually".format(
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
self.model.config.min_length, input_length
)
)

if input_length < self.model.config.max_length:
logger.warning(
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length in config and insert config manually".format(
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
self.model.config.max_length, input_length
)
)
@@ -1250,7 +1260,115 @@ def __call__(
record["summary_token_ids"] = summary
if return_text:
record["summary_text"] = self.tokenizer.decode(
- summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
+ summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
results.append(record)
return results
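The reworded warnings above steer users toward call-time overrides instead of edits to the model config. A small sketch of that usage, assuming the default summarization checkpoint; min_length and max_length travel through **generate_kwargs into self.model.generate:

from transformers import pipeline

summarizer = pipeline("summarization")  # bart-large-cnn by default

article = "Text of the article to summarize goes here."
# Tune generation lengths per call, as the new warning messages suggest,
# instead of modifying the config.
summary = summarizer(article, min_length=10, max_length=50)
print(summary[0]["summary_text"])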


class TranslationPipeline(Pipeline):
"""
Translates from one language to another.

Usage::
en_fr_translator = pipeline("translation_en_to_fr")
en_fr_translator("How old are you?")

Supported Models: "t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"

Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""

def __call__(
self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):
r"""
Args:
*texts: (list of strings) articles to be summarized
[Review comment, Member] texts to be translated instead of summarized?
[Reply, Contributor Author] good catch!

return_text: (bool, default=True) whether to add a decoded "translation_text" to each result
return_tensors: (bool, default=False) whether to return the raw "translation_token_ids" to each result

**generate_kwargs: extra kwargs passed to `self.model.generate`_

Returns:
list of dicts with 'translation_text' and/or 'translation_token_ids' for each text_to_translate
.. _`self.model.generate`:
https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate
"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

if isinstance(texts[0], list):
assert (
self.tokenizer.pad_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
texts = ([prefix + text for text in texts[0]],)
pad_to_max_length = True

elif isinstance(texts[0], str):
texts = (prefix + texts[0],)
pad_to_max_length = False
else:
raise ValueError(
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
texts[0]
)
)

with self.device_placement():
inputs = self._parse_and_tokenize(*texts, pad_to_max_length=pad_to_max_length)

if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]

elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()

if input_length > 0.9 * self.model.config.max_length:
logger.warning(
"Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
input_length, self.model.config.max_length
)
)

translations = self.model.generate(
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
)
results = []
for translation in translations:
record = {}
if return_tensors:
record["translation_token_ids"] = translation
if return_text:
record["translation_text"] = self.tokenizer.decode(
translation,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
results.append(record)
return results
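Taken together, __call__ prepends the model's task prefix, tokenizes (padding only for batched input), warns when the input length approaches 0.9 * max_length, generates, and optionally decodes. A usage sketch mirroring the class docstring; the decoded output is illustrative:

from transformers import pipeline

en_fr_translator = pipeline("translation_en_to_fr")  # t5-base by default

# Single string: no padding needed.
print(en_fr_translator("How old are you?"))
# e.g. [{'translation_text': 'Quel âge avez-vous ?'}]

# Batched input requires a tokenizer with a pad_token_id (T5's has one).
print(en_fr_translator(["How old are you?", "Where do you live?"]))

# Raw token ids instead of decoded text:
output = en_fr_translator("How old are you?", return_text=False, return_tensors=True)
print(output[0]["translation_token_ids"])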
@@ -1324,6 +1442,36 @@ def __call__(
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
},
},
"translation_en_to_fr": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "t5-base", "tf": "t5-base"},
"config": None,
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
"translation_en_to_de": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "t5-base", "tf": "t5-base"},
"config": None,
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
"translation_en_to_ro": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "t5-base", "tf": "t5-base"},
"config": None,
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
}
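All three new task keys resolve to the same t5-base checkpoint; what distinguishes them is the prefix the model config supplies, which __call__ prepends to every input. The exact prefix string below is an assumption about t5-base's configuration (via its task-specific parameters), not something this diff defines:

from transformers import pipeline

# The task name determines which prefix the model applies.
translator = pipeline("translation_en_to_de")

# Assumed prefix: "translate English to German: ", so the model effectively
# receives "translate English to German: I like apples."
print(translator("I like apples."))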


@@ -1472,4 +1620,4 @@ def pipeline(
)
model = model_class.from_pretrained(model, config=config, **model_kwargs)

- return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
+ return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs,)
28 changes: 28 additions & 0 deletions tests/test_pipelines.py
@@ -81,6 +81,12 @@
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}

TRANSLATION_FINETUNED_MODELS = {
("t5-small", "t5-small", "translation_en_to_de"),
("t5-small", "t5-small", "translation_en_to_ro"),
}
TF_TRANSLATION_FINETUNED_MODELS = {("t5-small", "t5-small", "translation_en_to_fr")}


class MonoColumnInputTestCase(unittest.TestCase):
def _test_mono_column_pipeline(
@@ -272,6 +278,28 @@ def test_tf_summarization(self):
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)

@require_torch
def test_translation(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["translation_text"]
for model, tokenizer, task in TRANSLATION_FINETUNED_MODELS:
nlp = pipeline(task=task, model=model, tokenizer=tokenizer)
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)

@require_tf
def test_tf_translation(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["translation_text"]
for model, tokenizer, task in TF_TRANSLATION_FINETUNED_MODELS:
nlp = pipeline(task=task, model=model, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
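Both tests delegate to _test_mono_column_pipeline, defined earlier in this test case. As a rough sketch of the contract those calls rely on (an assumption about the helper's behavior, not its actual body): each valid input must yield a list of dicts carrying the mandatory keys, and invalid inputs must raise:

def _sketched_mono_column_check(nlp, valid_inputs, invalid_inputs, mandatory_keys):
    # Valid inputs produce a list of dicts containing the mandatory keys.
    for valid_input in valid_inputs:
        results = nlp(valid_input)
        assert isinstance(results, list)
        assert all(key in result for result in results for key in mandatory_keys)
    # Invalid inputs are rejected with an exception.
    for invalid_input in invalid_inputs:
        raised = False
        try:
            nlp(invalid_input)
        except Exception:
            raised = True
        assert raised, "invalid input should raise"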


class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):