Skip to content

Commit a722703

Browse files
authored
use sacremoses normalizer, ensure pretranslate.src.json and pretranslate.trg.json use same directory (#96)
* use sacremoses normalizer, ensure pretranslate.src.json and pretranslate.trg.json use same directory * restore launch.json to commit in main branch * address efficiency issues * refactor to have separate uri and folder for shared_file, only normalize with sacremoses for NLLB
1 parent 17f82a8 commit a722703

File tree

5 files changed

+87
-24
lines changed

5 files changed

+87
-24
lines changed

machine/jobs/clearml_shared_file_service.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
class ClearMLSharedFileService(SharedFileService):
1414
def _download_file(self, path: str, cache: bool = False) -> Path:
15-
uri = f"{self._shared_file_uri}/{path}"
15+
uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
1616
local_folder: Optional[str] = None
1717
if not cache:
1818
local_folder = str(self._data_dir)
@@ -22,7 +22,7 @@ def _download_file(self, path: str, cache: bool = False) -> Path:
2222
return Path(file_path)
2323

2424
def _download_folder(self, path: str, cache: bool = False) -> Path:
25-
uri = f"{self._shared_file_uri}/{path}"
25+
uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
2626
local_folder: Optional[str] = None
2727
if not cache:
2828
local_folder = str(self._data_dir)
@@ -32,22 +32,36 @@ def _download_folder(self, path: str, cache: bool = False) -> Path:
3232
return Path(folder_path) / path
3333

3434
def _exists_file(self, path: str) -> bool:
35-
uri = f"{self._shared_file_uri}/{path}"
35+
uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
3636
return try_n_times(lambda: StorageManager.exists_file(uri)) # type: ignore
3737

3838
def _upload_file(self, path: str, local_file_path: Path) -> None:
3939
final_destination = try_n_times(
40-
lambda: StorageManager.upload_file(str(local_file_path), f"{self._shared_file_uri}/{path}")
40+
lambda: StorageManager.upload_file(
41+
str(local_file_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
42+
)
4143
)
4244
if final_destination is None:
43-
logger.error(f"Failed to upload file {str(local_file_path)} to {self._shared_file_uri}/{path}.")
45+
logger.error(
46+
(
47+
f"Failed to upload file {str(local_file_path)} "
48+
f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}."
49+
)
50+
)
4451

4552
def _upload_folder(self, path: str, local_folder_path: Path) -> None:
4653
final_destination = try_n_times(
47-
lambda: StorageManager.upload_folder(str(local_folder_path), f"{self._shared_file_uri}/{path}")
54+
lambda: StorageManager.upload_folder(
55+
str(local_folder_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
56+
)
4857
)
4958
if final_destination is None:
50-
logger.error(f"Failed to upload folder {str(local_folder_path)} to {self._shared_file_uri}/{path}.")
59+
logger.error(
60+
(
61+
f"Failed to upload folder {str(local_folder_path)} "
62+
f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}."
63+
)
64+
)
5165

5266

5367
def try_n_times(func: Callable, n=10):

machine/jobs/settings.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
default:
22
model_type: huggingface
33
data_dir: ~/machine
4+
shared_file_uri: s3://aqua-ml-data/
5+
shared_file_folder: production
46
pretranslation_batch_size: 1024
57
huggingface:
68
parent_model_name: facebook/nllb-200-distilled-1.3B
@@ -25,12 +27,13 @@ default:
2527
add_unk_src_tokens: true
2628
add_unk_trg_tokens: true
2729
development:
28-
shared_file_uri: s3://aqua-ml-data/dev/
30+
shared_file_folder: dev
2931
huggingface:
3032
parent_model_name: facebook/nllb-200-distilled-600M
3133
generate_params:
3234
num_beams: 1
3335
staging:
36+
shared_file_folder: ext-qa
3437
huggingface:
3538
parent_model_name: hf-internal-testing/tiny-random-nllb
3639
train_params:

machine/jobs/shared_file_service.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def generator() -> Generator[PretranslationInfo, None, None]:
6464
@contextmanager
6565
def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]:
6666
build_id: str = self._config.build_id
67-
build_dir = self._data_dir / "builds" / build_id
67+
build_dir = self._data_dir / self._shared_file_folder / "builds" / build_id
6868
build_dir.mkdir(parents=True, exist_ok=True)
6969
target_pretranslate_path = build_dir / "pretranslate.trg.json"
7070
with target_pretranslate_path.open("w", encoding="utf-8", newline="\n") as file:
@@ -96,6 +96,11 @@ def _shared_file_uri(self) -> str:
9696
shared_file_uri: str = self._config.shared_file_uri
9797
return shared_file_uri.rstrip("/")
9898

99+
@property
100+
def _shared_file_folder(self) -> str:
101+
shared_file_folder: str = self._config.shared_file_folder
102+
return shared_file_folder.rstrip("/")
103+
99104
@abstractmethod
100105
def _download_file(self, path: str, cache: bool = False) -> Path:
101106
...

machine/translation/huggingface/hugging_face_nmt_engine.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,23 @@
22

33
import gc
44
import logging
5+
import re
56
from math import exp, prod
6-
from typing import Any, Iterable, List, Sequence, Tuple, Union, cast
7+
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union, cast
78

89
import torch # pyright: ignore[reportMissingImports]
9-
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, TranslationPipeline
10+
from sacremoses import MosesPunctNormalizer
11+
from transformers import (
12+
AutoConfig,
13+
AutoModelForSeq2SeqLM,
14+
AutoTokenizer,
15+
NllbTokenizer,
16+
NllbTokenizerFast,
17+
PreTrainedModel,
18+
PreTrainedTokenizer,
19+
PreTrainedTokenizerFast,
20+
TranslationPipeline,
21+
)
1022
from transformers.generation import BeamSearchEncoderDecoderOutput, GreedySearchEncoderDecoderOutput
1123
from transformers.tokenization_utils import BatchEncoding, TruncationStrategy
1224

@@ -38,6 +50,11 @@ def __init__(
3850
PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
3951
)
4052
self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
53+
if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)):
54+
self._mpn = MosesPunctNormalizer()
55+
self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]
56+
else:
57+
self._mpn = None
4158

4259
src_lang = self._pipeline_kwargs.get("src_lang")
4360
tgt_lang = self._pipeline_kwargs.get("tgt_lang")
@@ -71,6 +88,7 @@ def __init__(
7188
self._pipeline = _TranslationPipeline(
7289
model=self._model,
7390
tokenizer=self._tokenizer,
91+
mpn=self._mpn,
7492
batch_size=self._batch_size,
7593
**self._pipeline_kwargs,
7694
)
@@ -149,15 +167,34 @@ def close(self) -> None:
149167

150168

151169
class _TranslationPipeline(TranslationPipeline):
170+
def __init__(
171+
self,
172+
model: Union[PreTrainedModel, StrPath, str],
173+
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
174+
batch_size: int,
175+
mpn: Optional[MosesPunctNormalizer] = None,
176+
**kwargs,
177+
) -> None:
178+
super().__init__(model=model, tokenizer=tokenizer, batch_size=batch_size, **kwargs)
179+
self._mpn = mpn
180+
152181
def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None):
153182
if self.tokenizer is None:
154183
raise RuntimeError("No tokenizer is specified.")
155-
sentences = [
156-
s
157-
if isinstance(s, str)
158-
else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
159-
for s in args
160-
]
184+
if self._mpn:
185+
sentences = [
186+
self._mpn.normalize(s)
187+
if isinstance(s, str)
188+
else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
189+
for s in args
190+
]
191+
else:
192+
sentences = [
193+
s
194+
if isinstance(s, str)
195+
else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
196+
for s in args
197+
]
161198
inputs = cast(
162199
BatchEncoding, super().preprocess(*sentences, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang)
163200
)

machine/translation/huggingface/hugging_face_nmt_model_trainer.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def __init__(
9696
self.max_target_length = max_target_length
9797
self._add_unk_src_tokens = add_unk_src_tokens
9898
self._add_unk_trg_tokens = add_unk_trg_tokens
99+
self._mpn = MosesPunctNormalizer()
100+
self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]
99101

100102
@property
101103
def stats(self) -> TrainStats:
@@ -169,9 +171,8 @@ def find_missing_characters(tokenizer: Any, train_dataset: Dataset, lang_codes:
169171
for lang_code in lang_codes:
170172
for ex in train_dataset["translation"]:
171173
charset = charset | set(ex[lang_code])
172-
mpn = MosesPunctNormalizer()
173-
mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
174-
charset = {mpn.normalize(char) for char in charset}
174+
if isinstance(tokenizer, (NllbTokenizerFast)):
175+
charset = {self._mpn.normalize(char) for char in charset}
175176
charset = {tokenizer.backend_tokenizer.normalizer.normalize_str(char) for char in charset}
176177
charset = set(filter(None, {char.strip() for char in charset}))
177178
missing_characters = sorted(list(charset - vocab))
@@ -302,11 +303,14 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
302303
)
303304

304305
def preprocess_function(examples):
305-
inputs = [ex[src_lang] for ex in examples["translation"]]
306-
targets = [ex[tgt_lang] for ex in examples["translation"]]
307-
inputs = [prefix + inp for inp in inputs]
308-
model_inputs = tokenizer(inputs, max_length=max_source_length, truncation=True)
306+
if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
307+
inputs = [self._mpn.normalize(prefix + ex[src_lang]) for ex in examples["translation"]]
308+
targets = [self._mpn.normalize(ex[tgt_lang]) for ex in examples["translation"]]
309+
else:
310+
inputs = [prefix + ex[src_lang] for ex in examples["translation"]]
311+
targets = [ex[tgt_lang] for ex in examples["translation"]]
309312

313+
model_inputs = tokenizer(inputs, max_length=max_source_length, truncation=True)
310314
# Tokenize targets with the `text_target` keyword argument
311315
labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
312316

0 commit comments

Comments
 (0)