diff --git a/machine/jobs/build_nmt_engine.py b/machine/jobs/build_nmt_engine.py index a047a7ba..93841d9b 100644 --- a/machine/jobs/build_nmt_engine.py +++ b/machine/jobs/build_nmt_engine.py @@ -55,6 +55,8 @@ def clearml_progress(status: ProgressStatus) -> None: except TypeError as e: raise TypeError(f"Build options could not be parsed: {e}") from e SETTINGS.update({model_type: build_options}) + if "align_pretranslations" in build_options: + SETTINGS.update({"align_pretranslations": build_options["align_pretranslations"]}) SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir)) logger.info(f"Config: {SETTINGS.as_dict()}") diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index 5bb120a8..b8b74594 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -157,8 +157,8 @@ def _align( check_canceled() for i in range(len(pretranslations)): - pretranslations[i]["source_tokens"] = list(src_tokenized[i]) - pretranslations[i]["translation_tokens"] = list(trg_tokenized[i]) + pretranslations[i]["sourceTokens"] = list(src_tokenized[i]) + pretranslations[i]["translationTokens"] = list(trg_tokenized[i]) pretranslations[i]["alignment"] = alignments[i] return pretranslations diff --git a/machine/jobs/translation_file_service.py b/machine/jobs/translation_file_service.py index f29e1005..e1b27947 100644 --- a/machine/jobs/translation_file_service.py +++ b/machine/jobs/translation_file_service.py @@ -16,8 +16,8 @@ class PretranslationInfo(TypedDict): textId: str # noqa: N815 refs: List[str] translation: str - source_tokens: List[str] - translation_tokens: List[str] + sourceTokens: List[str] # noqa: N815 + translationTokens: List[str] # noqa: N815 alignment: str @@ -65,9 +65,9 @@ def generator() -> Generator[PretranslationInfo, None, None]: textId=pi["textId"], refs=list(pi["refs"]), translation=pi["translation"], - source_tokens=list(pi["source_tokens"]), - translation_tokens=list(pi["translation_tokens"]), - alignment=pi["alignment"], + sourceTokens=list(), + translationTokens=list(), + alignment="", ) return ContextManagedGenerator(generator()) diff --git a/tests/jobs/test_nmt_engine_build_job.py b/tests/jobs/test_nmt_engine_build_job.py index 014fc743..efbfe385 100644 --- a/tests/jobs/test_nmt_engine_build_job.py +++ b/tests/jobs/test_nmt_engine_build_job.py @@ -38,7 +38,7 @@ def test_run(decoy: Decoy) -> None: assert len(pretranslations) == 1 assert pretranslations[0]["translation"] == "Please, I have booked a room." if is_eflomal_available(): - assert pretranslations[0]["source_tokens"] == [ + assert pretranslations[0]["sourceTokens"] == [ "Por", "favor", ",", @@ -48,11 +48,11 @@ def test_run(decoy: Decoy) -> None: "habitación", ".", ] - assert pretranslations[0]["translation_tokens"] == ["Please", ",", "I", "have", "booked", "a", "room", "."] + assert pretranslations[0]["translationTokens"] == ["Please", ",", "I", "have", "booked", "a", "room", "."] assert len(pretranslations[0]["alignment"]) > 0 else: - assert pretranslations[0]["source_tokens"] == [] - assert pretranslations[0]["translation_tokens"] == [] + assert pretranslations[0]["sourceTokens"] == [] + assert pretranslations[0]["translationTokens"] == [] assert len(pretranslations[0]["alignment"]) == 0 decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1) @@ -131,8 +131,8 @@ def __init__(self, decoy: Decoy) -> None: textId="text1", refs=["ref1"], translation="Por favor, tengo reservada una habitación.", - source_tokens=[], - translation_tokens=[], + sourceTokens=[], + translationTokens=[], alignment="", ) ] diff --git a/tests/jobs/test_smt_engine_build_job.py b/tests/jobs/test_smt_engine_build_job.py index eff4649f..7c8ddb06 100644 --- a/tests/jobs/test_smt_engine_build_job.py +++ b/tests/jobs/test_smt_engine_build_job.py @@ -137,8 +137,8 @@ def __init__(self, decoy: Decoy) -> None: textId="text1", refs=["ref1"], translation="Por favor, tengo reservada una habitación.", - source_tokens=[], - translation_tokens=[], + sourceTokens=[], + translationTokens=[], alignment="", ) ]