Environment info
- transformers version: 4.17.0
- Platform: linux-64
- Python version: 3.8
- PyTorch version (GPU?): 1.10.2 (CPU-only)
- Tensorflow version (GPU?): none
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: yes
Who can help
@patrickvonplaten
@SaulLu
@Narsil
Information
Model I am using (Bert, XLNet ...): Marian (Helsinki-NLP/opus-mt-it-en)
The problem arises when using:
- my own modified scripts (see details below)
I'm trying to parallelize inference through Apache Spark. While this works for other models/pipelines (e.g. https://huggingface.co/joeddav/xlm-roberta-large-xnli), it doesn't work for Marian (e.g. https://huggingface.co/Helsinki-NLP/opus-mt-it-en). Spark has to serialize the tokenizer/model/pipeline object and broadcast it to the worker nodes, so the serialized object must carry all the data it needs. For the Marian tokenizer, however, when the object is deserialized and __setstate__ is called, it tries to reload the tokenizer files (source.spm, target.spm, etc.) from the filesystem (see https://github.com/huggingface/transformers/blob/master/src/transformers/models/marian/tokenization_marian.py#L330). Those files are no longer available on the worker nodes, so it fails. The __setstate__ method shouldn't need to access the filesystem at all.
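For context, here is roughly what the pickling hooks do today (paraphrased from the linked tokenization_marian.py, so details may differ slightly from the actual source): __getstate__ drops the loaded sentencepiece processors, and __setstate__ rebuilds them by re-reading the file paths stored in self.spm_files, which only exist on the driver:

def __getstate__(self):
    state = self.__dict__.copy()
    # The loaded sentencepiece processors are dropped before pickling
    state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]})
    return state

def __setstate__(self, d):
    self.__dict__ = d
    # Problematic part: self.spm_files holds driver-local paths, so this
    # re-read fails on any machine that doesn't have the files.
    self.spm_source, self.spm_target = (load_spm(f, self.sp_model_kwargs) for f in self.spm_files)
    self.current_spm = self.spm_source
    self._setup_normalizer()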
The task I am working on is:
- my own task/dataset: translation (see details below)
To reproduce
Steps to reproduce the behavior:
from transformers import pipeline

# model_dir points at a local download of Helsinki-NLP/opus-mt-it-en
translator = pipeline("translation", model=model_dir)
broadcasted_translator = spark_session.sparkContext.broadcast(translator)

def compute_values(iterator):
    for df in iterator:
        batch_size = 32
        sequences = df["text"].to_list()
        res = []
        # Translate in batches using the broadcast pipeline on the worker
        for i in range(0, len(sequences), batch_size):
            res += broadcasted_translator.value(sequences[i:i + batch_size])
        df["translation"] = [item["translation_text"] for item in res]
        yield df

schema = "text STRING, translation STRING"
sdf = spark_dataframe.mapInPandas(compute_values, schema=schema)
I get the following error:
File "/tmp/conda-78ffd793-e3a4-4b56-a869-cedd86c5eeaa/real/envs/conda-env/lib/python3.8/site-packages/pyspark/broadcast.py", line 129, in load
return pickle.load(file)
File "/tmp/conda-78ffd793-e3a4-4b56-a869-cedd86c5eeaa/real/envs/conda-env/lib/python3.8/site-packages/transformers/models/marian/tokenization_marian.py", line 330, in __setstate__
self.spm_source, self.spm_target = (load_spm(f, self.sp_model_kwargs) for f in self.spm_files)
File "/tmp/conda-78ffd793-e3a4-4b56-a869-cedd86c5eeaa/real/envs/conda-env/lib/python3.8/site-packages/transformers/models/marian/tokenization_marian.py", line 330, in <genexpr>
self.spm_source, self.spm_target = (load_spm(f, self.sp_model_kwargs) for f in self.spm_files)
File "/tmp/conda-78ffd793-e3a4-4b56-a869-cedd86c5eeaa/real/envs/conda-env/lib/python3.8/site-packages/transformers/models/marian/tokenization_marian.py", line 357, in load_spm
spm.Load(path)
File "/tmp/conda-78ffd793-e3a4-4b56-a869-cedd86c5eeaa/real/envs/conda-env/lib/python3.8/site-packages/sentencepiece/__init__.py", line 367, in Load
return self.LoadFromFile(model_file)
File "/tmp/conda-78ffd793-e3a4-4b56-a869-cedd86c5eeaa/real/envs/conda-env/lib/python3.8/site-packages/sentencepiece/__init__.py", line 171, in LoadFromFile
return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
OSError: Not found: "/container_e165_1645611551581_313304_01_000001/tmp/model_dir/source.spm": No such file or directory Error #2
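Note that Spark isn't actually required to trigger this: broadcast just pickles the object, and a plain pickle round-trip fails the same way once the original files are gone. A minimal sketch (assuming model_dir points at a local download of Helsinki-NLP/opus-mt-it-en):

import pickle
import shutil

from transformers import pipeline

translator = pipeline("translation", model=model_dir)
payload = pickle.dumps(translator)  # __getstate__ drops the loaded spm processors
shutil.rmtree(model_dir)            # simulate a worker that doesn't have the files
translator = pickle.loads(payload)  # __setstate__ re-reads source.spm -> OSError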
Expected behavior
As explained in the "Information" section, I should be able to serialize, broadcast, deserialize and then apply the tokenizer/model/pipeline on the worker nodes. Instead it fails, because __setstate__ is called on deserialization and tries to reload the tokenizer files (source.spm, target.spm, etc.) from a filesystem the worker nodes can't see. The __setstate__ method shouldn't need to access the filesystem.
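One possible direction (just a sketch of the idea, not the current API; load_spm_from_proto is a hypothetical helper): pickle the raw sentencepiece model protos instead of relying on file paths, using serialized_model_proto() / LoadFromSerializedProto() from the sentencepiece package, so that deserialization never touches the filesystem:

import sentencepiece

def __getstate__(self):
    state = self.__dict__.copy()
    # Carry the serialized model protos across the pickle boundary
    state["spm_protos"] = [
        spm.serialized_model_proto() for spm in (self.spm_source, self.spm_target)
    ]
    state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]})
    return state

def __setstate__(self, d):
    self.__dict__ = d
    protos = self.__dict__.pop("spm_protos")
    # Rebuild the processors from in-memory bytes, no filesystem access
    self.spm_source, self.spm_target = (
        load_spm_from_proto(p, self.sp_model_kwargs) for p in protos
    )
    self.current_spm = self.spm_source
    self._setup_normalizer()

def load_spm_from_proto(proto, sp_model_kwargs):
    # Hypothetical counterpart to load_spm() that reads from bytes
    spm = sentencepiece.SentencePieceProcessor(**sp_model_kwargs)
    spm.LoadFromSerializedProto(proto)
    return spm

This would keep the pickled tokenizer self-contained, at the cost of a somewhat larger payload (the .spm models are embedded in the pickle instead of being re-read from disk).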