diff --git a/poetry.lock b/poetry.lock
index cb2b3b91..afc3a87e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -5174,13 +5174,13 @@ type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12
 
 [[package]]
 name = "sil-machine"
-version = "1.4.0"
+version = "1.7.1"
 description = "A natural language processing library that is focused on providing tools for resource-poor languages."
 optional = false
 python-versions = "<3.13,>=3.9"
 files = [
-    {file = "sil_machine-1.4.0-py3-none-any.whl", hash = "sha256:ceea8358c426d5bd129e6a8dff779aa19430ac69087b99a7177660fd24f26eaf"},
-    {file = "sil_machine-1.4.0.tar.gz", hash = "sha256:7f66567313e29654b8cf44d4c9d1156700a879313450a400e018bf06a1818e44"},
+    {file = "sil_machine-1.7.1-py3-none-any.whl", hash = "sha256:346c02ac083f2f571b219e19d18b0dbc12f1c47d3c8d5a911ea1281a2c77d6f9"},
+    {file = "sil_machine-1.7.1.tar.gz", hash = "sha256:00955c0c11ef755f666662817e1f92d59ba316c50b780ce9a918d15947c9a5f5"},
 ]
 
 [package.dependencies]
@@ -5190,15 +5190,15 @@ charset-normalizer = ">=2.1.1,<3.0.0"
 networkx = ">=3,<4"
 numpy = ">=1.24.4,<2.0.0"
 regex = ">=2021.7.6"
-sil-thot = {version = ">=3.4.4,<4.0.0", optional = true, markers = "extra == \"thot\""}
+sil-thot = {version = ">=3.4.6,<4.0.0", optional = true, markers = "extra == \"thot\""}
 sortedcontainers = ">=2.4.0,<3.0.0"
 urllib3 = "<2"
 
 [package.extras]
 huggingface = ["datasets (>=2.4.0,<3.0.0)", "sacremoses (>=0.0.53,<0.0.54)", "transformers (>=4.38.0,<4.46)"]
-jobs = ["clearml[s3] (>=1.13.1,<2.0.0)", "dynaconf (>=3.2.5,<4.0.0)", "json-stream (>=1.3.0,<2.0.0)"]
+jobs = ["clearml[s3] (>=1.13.1,<2.0.0)", "dynaconf (>=3.2.5,<4.0.0)", "eflomal (>=2.0.0,<3.0.0)", "json-stream (>=1.3.0,<2.0.0)"]
 sentencepiece = ["sentencepiece (>=0.2.0,<0.3.0)"]
-thot = ["sil-thot (>=3.4.4,<4.0.0)"]
+thot = ["sil-thot (>=3.4.6,<4.0.0)"]
 
 [[package]]
 name = "sil-thot"
@@ -6217,4 +6217,4 @@ eflomal = ["eflomal"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.11"
-content-hash = "54d8af4e3f58502aa4dd5ed36edec8ceb0f4ea759b220d292252de9c28bb73e5"
+content-hash = "ba728bfe6d200ee3f20f9c430df4773d255c0cf2aa8016267b2c14156062887a"
diff --git a/pyproject.toml b/pyproject.toml
index b17a12f0..b4408008 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,7 +71,7 @@ s3path = "0.3.4"
 sacrebleu = "^2.3.1"
 ctranslate2 = "^3.5.1"
 libclang = "14.0.6"
-sil-machine = {extras = ["thot"], version = "1.4.0"}
+sil-machine = {extras = ["thot"], version = "1.7.1"}
 datasets = "^2.7.1"
 torch = {version = "^2.4", source = "torch"}
 sacremoses = "^0.0.53"
diff --git a/silnlp/common/compare_usfm_structure.py b/silnlp/common/compare_usfm_structure.py
index 9883ec9d..3142bd43 100644
--- a/silnlp/common/compare_usfm_structure.py
+++ b/silnlp/common/compare_usfm_structure.py
@@ -26,25 +26,32 @@ def _is_whitespace(self, c: str) -> bool:
 
 # Filter out embeds and ignored markers from sentence and create list of all remaining markers
 def filter_markers(
-    sent: str, stylesheet: UsfmStylesheet = UsfmStylesheet("usfm.sty"), to_ignore: List[str] = []
+    sent: str,
+    stylesheet: UsfmStylesheet = UsfmStylesheet("usfm.sty"),
+    only_paragraph: bool = False,
+    only_style: bool = False,
+    to_ignore: List[str] = [],
 ) -> Tuple[str, List[str]]:
     markers = []
     usfm_tokenizer = UsfmTokenizer(stylesheet)
     curr_embed = None
     filtered_sent = ""
     for tok in usfm_tokenizer.tokenize(sent):
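+        # Strip "+" and "*" so nested ("\+w") and closing ("\w*") marker forms
+        # compare equal to their base form ("w")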
+        base_marker = tok.marker.strip("+*") if tok.marker is not None else None
         if curr_embed is not None:
-            if tok.type == UsfmTokenType.END and tok.marker[:-1] == curr_embed.marker:
+            if tok.type == UsfmTokenType.END and base_marker == curr_embed:
                 curr_embed = None
             elif tok.type == UsfmTokenType.TEXT:
                 filtered_sent += tok.to_usfm()
         elif tok.marker is None:
             continue
-        elif tok.type == UsfmTokenType.NOTE or tok.marker in CHARACTER_TYPE_EMBEDS:
+        elif tok.type == UsfmTokenType.NOTE or base_marker in CHARACTER_TYPE_EMBEDS:
             if tok.end_marker is not None:
-                curr_embed = tok
-        elif tok.type in [UsfmTokenType.PARAGRAPH, UsfmTokenType.CHARACTER, UsfmTokenType.END]:
-            if tok.marker not in to_ignore:
+                curr_embed = base_marker
+        elif (tok.type == UsfmTokenType.PARAGRAPH and not only_style) or (
+            tok.type in [UsfmTokenType.CHARACTER, UsfmTokenType.END] and not only_paragraph
+        ):
+            if base_marker not in to_ignore:
                 filtered_sent += tok.to_usfm()
                 markers.append(tok.marker)
@@ -53,8 +60,13 @@ def filter_markers(
 
 # Assumes that the files have identical USFM structure
 def evaluate_usfm_marker_placement(
-    gold_book_path: Path, pred_book_path: Path, book: Optional[str] = None, to_ignore: Optional[List[str]] = []
-) -> Tuple[float, float]:
+    gold_book_path: Path,
+    pred_book_path: Path,
+    book: Optional[str] = None,
+    only_paragraph: bool = False,
+    only_style: bool = False,
+    to_ignore: List[str] = [],
+) -> Optional[Tuple[float, float]]:
     try:
         settings = FileParatextProjectSettingsParser(gold_book_path.parent).parse()
         stylesheet = settings.stylesheet
@@ -63,7 +75,6 @@ def evaluate_usfm_marker_placement(
             settings.encoding,
             settings.get_book_id(gold_book_path.name),
             gold_book_path,
-            settings.versification,
             include_markers=True,
             include_all_text=True,
         )
@@ -72,7 +83,6 @@ def evaluate_usfm_marker_placement(
             settings.encoding,
             settings.get_book_id(gold_book_path.name),
             pred_book_path,
-            settings.versification,
             include_markers=True,
             include_all_text=True,
         )
@@ -97,8 +107,8 @@ def evaluate_usfm_marker_placement(
         if len(gs.ref.path) > 0 and gs.ref.path[-1].name in PARAGRAPH_TYPE_EMBEDS:
             continue
 
-        gs_text, gold_markers = filter_markers(gs.text, stylesheet, to_ignore)
-        ps_text, pred_markers = filter_markers(ps.text, stylesheet, to_ignore)
+        gs_text, gold_markers = filter_markers(gs.text, stylesheet, only_paragraph, only_style, to_ignore)
+        ps_text, pred_markers = filter_markers(ps.text, stylesheet, only_paragraph, only_style, to_ignore)
 
         if len(gold_markers) == 0:
             continue
@@ -112,6 +122,10 @@ def evaluate_usfm_marker_placement(
         gold_sent_toks.append(list(tokenizer.tokenize(gs_text)))
         pred_sent_toks.append(list(tokenizer.tokenize(ps_text)) + gold_markers)
 
+    # No verses with markers that should be evaluated
+    if len(gold_sent_toks) == 0:
+        return None
+
     jaro_scores = []
     dists_per_marker = []
     for gs, ps, n in zip(gold_sent_toks, pred_sent_toks, num_markers):
@@ -156,10 +170,42 @@ def main() -> None:
 
     to_ignore = args.ignored_markers[0].split(";") if len(args.ignored_markers) == 1 else args.ignored_markers
 
-    avg_jaro, avg_dist = evaluate_usfm_marker_placement(Path(args.gold), Path(args.pred), args.book, to_ignore)
+    # Evaluate all marker placement
+    scores = evaluate_usfm_marker_placement(Path(args.gold), Path(args.pred), args.book, to_ignore=to_ignore)
+    if scores is None:
+        LOGGER.info("No verses with markers found.")
+        exit()
+    LOGGER.info(f"Average (scaled) Jaro similarity of verses with placed markers: {scores[0]}")
+    LOGGER.info(f"Average Levenshtein distance per marker of verses with placed markers: {scores[1]}")
+
+    # Evaluate paragraph marker placement
+    scores_para = evaluate_usfm_marker_placement(
+        Path(args.gold), Path(args.pred), args.book, only_paragraph=True, to_ignore=to_ignore
+    )
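+    # evaluate_usfm_marker_placement returns None when no verses contain the requested marker type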
+    if scores_para is None:
+        LOGGER.info("No verses with paragraph markers found.")
+        exit()
+
+    # Evaluate style marker placement
+    scores_style = evaluate_usfm_marker_placement(
+        Path(args.gold), Path(args.pred), args.book, only_style=True, to_ignore=to_ignore
+    )
+    if scores_style is None:
+        LOGGER.info("No verses with style markers found.")
+        exit()
 
-    LOGGER.info(f"Average (scaled) Jaro similarity of verses with placed markers: {avg_jaro}")
-    LOGGER.info(f"Average Levenshtein distance per marker of verses with placed markers: {avg_dist}")
+    LOGGER.info(
+        f"Average (scaled) Jaro similarity of verses with placed markers (only paragraph markers): {scores_para[0]}"
+    )
+    LOGGER.info(
+        f"Average Levenshtein distance per marker of verses with placed markers (only paragraph markers): {scores_para[1]}"
+    )
+    LOGGER.info(
+        f"Average (scaled) Jaro similarity of verses with placed markers (only style markers): {scores_style[0]}"
+    )
+    LOGGER.info(
+        f"Average Levenshtein distance per marker of verses with placed markers (only style markers): {scores_style[1]}"
+    )
 
 
 if __name__ == "__main__":
diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py
index e8478ef6..c9ab503e 100644
--- a/silnlp/common/translator.py
+++ b/silnlp/common/translator.py
@@ -10,63 +10,34 @@ import docx
 import nltk
 from iso639 import Lang
-from machine.corpora import (  # UpdateUsfmMarkerBehavior, UpdateUsfmTextBehavior, FileParatextProjectTextUpdater,
+from machine.corpora import (
     FileParatextProjectSettingsParser,
-    UpdateUsfmBehavior,
+    FileParatextProjectTextUpdater,
+    UpdateUsfmMarkerBehavior,
     UpdateUsfmParserHandler,
+    UpdateUsfmTextBehavior,
     UsfmFileText,
-    UsfmParserState,
     UsfmStylesheet,
-    UsfmTokenType,
+    UsfmTextType,
     parse_usfm,
 )
 from machine.scripture import VerseRef
 
 from .corpus import load_corpus, write_corpus
 from .paratext import get_book_path, get_iso, get_project_dir
-from .usfm_preservation import StatisticalUsfmPreserver
+from .usfm_preservation import PARAGRAPH_TYPE_EMBEDS, construct_place_markers_handler
 
 LOGGER = logging.getLogger(__package__ + ".translate")
 
 nltk.download("punkt")
 
 
-class ParagraphUpdateUsfmParserHandler(UpdateUsfmParserHandler):
-    def _collect_tokens(self, state: UsfmParserState) -> None:
-        self._tokens.extend(self._new_tokens)
-        self._new_tokens.clear()
-        while self._token_index <= state.index + state.special_token_count:
-            if (
-                state.tokens[self._token_index].type == UsfmTokenType.PARAGRAPH
-                and state.tokens[self._token_index].marker != "rem"
-            ):
-                # Because paragraph marker tokens are passed at the end of the verse,
-                # it is necessary to calculate how many positions each one needs to move up.
-                # This logic inserts them between the earliest instance of consecutive text tokens in the verse
-                num_text = 0
-                rem_offset = 0
-                for i in range(len(self._tokens) - 1, -1, -1):
-                    if self._tokens[i].type == UsfmTokenType.TEXT:
-                        num_text += 1
-                    elif self._tokens[i].type == UsfmTokenType.PARAGRAPH and self._tokens[i].marker == "rem":
-                        rem_offset += num_text + 1
-                        num_text = 0
-                    else:
-                        break
-                if num_text >= 2:
-                    self._tokens.insert(-(rem_offset + num_text - 1), state.tokens[self._token_index])
-                    self._token_index += 1
-                    break  # should this be continue instead? is there just no difference bc only 1 paragraph marker is added at a time?
-            self._tokens.append(state.tokens[self._token_index])
-            self._token_index += 1
-
-
 def insert_draft_remark(
     usfm: str,
     book: str,
     description: str,
     experiment_ckpt_str: str,
 ) -> str:
-    remark = f"\\rem This draft of {book} was machine translated on {date.today()} from {description} using model {experiment_ckpt_str}. It should be reviewed and edited carefully.\n"
+    remark = f"\\rem This draft of {book} was machine translated on {date.today()} from {description} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."
 
     lines = usfm.split("\n")
     insert_idx = (
@@ -216,14 +187,12 @@ def translate_usfm(
                 src_settings.get_book_id(src_file_path.name),
                 src_file_path,
                 src_settings.versification,
-                include_markers=True,
                 include_all_text=True,
                 project=src_settings.name,
             )
         else:
-            src_file_text = UsfmFileText(
-                "usfm.sty", "utf-8-sig", "", src_file_path, include_markers=True, include_all_text=True
-            )
+            src_file_text = UsfmFileText("usfm.sty", "utf-8-sig", "", src_file_path, include_all_text=True)
+        stylesheet = src_settings.stylesheet if src_from_project else UsfmStylesheet("usfm.sty")
 
         sentences = [re.sub(" +", " ", s.text.strip()) for s in src_file_text]
         vrefs = [s.ref for s in src_file_text]
@@ -231,23 +200,14 @@ def translate_usfm(
 
         # Filter sentences
         for i in reversed(range(len(sentences))):
-            if len(chapters) > 0 and vrefs[i].chapter_num not in chapters:
+            marker = vrefs[i].path[-1].name if len(vrefs[i].path) > 0 else ""
+            if (
+                (len(chapters) > 0 and vrefs[i].chapter_num not in chapters)
+                or marker in PARAGRAPH_TYPE_EMBEDS
+                or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT
+            ):
                 sentences.pop(i)
                 vrefs.pop(i)
-
-        usfm_preserver = StatisticalUsfmPreserver(
-            sentences,
-            vrefs,
-            src_settings.stylesheet if src_from_project else UsfmStylesheet("usfm.sty"),
-            include_paragraph_markers,
-            include_style_markers,
-            include_embeds,
-            "eflomal",
-        )
-        sentences = usfm_preserver.src_sents
-        vrefs = usfm_preserver.vrefs
-
-        # Don't translate empty sentences
         empty_sents = []
         for i in reversed(range(len(sentences))):
             if len(sentences[i].strip()) == 0:
@@ -266,61 +226,57 @@ def translate_usfm(
             vrefs.insert(idx, vref)
             output.insert(idx, [None, None, None, None])
 
+        # Update behaviors
+        text_behavior = (
+            UpdateUsfmTextBehavior.PREFER_NEW if trg_project is not None else UpdateUsfmTextBehavior.STRIP_EXISTING
+        )
+        paragraph_behavior = (
+            UpdateUsfmMarkerBehavior.PRESERVE if include_paragraph_markers else UpdateUsfmMarkerBehavior.STRIP
+        )
+        style_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_style_markers else UpdateUsfmMarkerBehavior.STRIP
+        embed_behavior = UpdateUsfmMarkerBehavior.PRESERVE if include_embeds else UpdateUsfmMarkerBehavior.STRIP
+
         draft_set: DraftGroup = DraftGroup(translations)
         for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1):
-            rows = usfm_preserver.construct_rows(translated_draft)
+            rows = [([ref], translation) for ref, translation in zip(vrefs, translated_draft)]
+
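+            # The PlaceMarkers handler uses word alignments between the source and the
+            # draft to decide where paragraph and style markers belong in the translation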
+            update_block_handlers = []
+            if include_paragraph_markers or include_style_markers or include_embeds:
+                update_block_handlers.append(construct_place_markers_handler(vrefs, sentences, translated_draft))
 
             # Insert translation into the USFM structure of an existing project
             # If the target project is not the same as the translated file's original project,
             # no verses outside of the ones translated will be overwritten
-            with open(src_file_path, encoding=src_settings.encoding if src_from_project else "utf-8-sig") as f:
-                usfm = f.read()
-            handler = ParagraphUpdateUsfmParserHandler(
-                rows=rows,
-                id_text=vrefs[0].book,
-                behavior=UpdateUsfmBehavior.STRIP_EXISTING,
-            )
-            stylesheet = src_settings.stylesheet if src_from_project else UsfmStylesheet("usfm.sty")
-            parse_usfm(usfm, handler, stylesheet, src_settings.versification if src_from_project else None)
-            usfm_out = handler.get_usfm(stylesheet)
-
-            # NOTE: Above is a temporary use of the USFM updater,
-            # and below is the version compatible with the most current version of sil-machine/machine.py (1.6.2)
-
-            # if trg_project is not None or src_from_project:
-            #     dest_updater = FileParatextProjectTextUpdater(
-            #         get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
-            #     )
-            #     usfm_out = dest_updater.update_usfm(
-            #         book_id=src_file_text.id,
-            #         rows=rows,
-            #         text_behavior=(
-            #             UpdateUsfmTextBehavior.PREFER_NEW
-            #             if trg_project is not None
-            #             else UpdateUsfmTextBehavior.STRIP_EXISTING
-            #         ),
-            #         paragraph_behavior=UpdateUsfmMarkerBehavior.STRIP,
-            #         embed_behavior=UpdateUsfmMarkerBehavior.STRIP,
-            #         style_behavior=UpdateUsfmMarkerBehavior.STRIP,
-            #         preserve_paragraph_styles=[],
-            #     )
-
-            #     if usfm_out is None:
-            #         raise FileNotFoundError(f"Book {src_file_text.id} does not exist in target project {trg_project}")
-            # else:  # Slightly more manual version for updating an individual file
-            #     with open(src_file_path, encoding="utf-8-sig") as f:
-            #         usfm = f.read()
-            #     handler = UpdateUsfmParserHandler(
-            #         rows=rows,
-            #         id_text=vrefs[0].book,
-            #         text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
-            #         paragraph_behavior=UpdateUsfmMarkerBehavior.STRIP,
-            #         embed_behavior=UpdateUsfmMarkerBehavior.STRIP,
-            #         style_behavior=UpdateUsfmMarkerBehavior.STRIP,
-            #         preserve_paragraph_styles=[],
-            #     )
-            #     parse_usfm(usfm, handler)
-            #     usfm_out = handler.get_usfm()
+            if trg_project is not None or src_from_project:
+                dest_updater = FileParatextProjectTextUpdater(
+                    get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
+                )
+                usfm_out = dest_updater.update_usfm(
+                    book_id=src_file_text.id,
+                    rows=rows,
+                    text_behavior=text_behavior,
+                    paragraph_behavior=paragraph_behavior,
+                    embed_behavior=embed_behavior,
+                    style_behavior=style_behavior,
+                    update_block_handlers=update_block_handlers,
+                )
+
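+                # update_usfm returns None when the book is not found in the target project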
+                if usfm_out is None:
+                    raise FileNotFoundError(f"Book {src_file_text.id} does not exist in target project {trg_project}")
+            else:  # Slightly more manual version for updating an individual file
+                with open(src_file_path, encoding="utf-8-sig") as f:
+                    usfm = f.read()
+                handler = UpdateUsfmParserHandler(
+                    rows=rows,
+                    id_text=vrefs[0].book,
+                    text_behavior=text_behavior,
+                    paragraph_behavior=paragraph_behavior,
+                    embed_behavior=embed_behavior,
+                    style_behavior=style_behavior,
+                    update_block_handlers=update_block_handlers,
+                )
+                parse_usfm(usfm, handler)
+                usfm_out = handler.get_usfm()
 
             # Insert draft remark and write to output path
             description = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
diff --git a/silnlp/common/usfm_preservation.py b/silnlp/common/usfm_preservation.py
index 59c1dc50..70d67e04 100644
--- a/silnlp/common/usfm_preservation.py
+++ b/silnlp/common/usfm_preservation.py
@@ -1,10 +1,9 @@
 from abc import abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import List, Tuple
+from typing import List
 
-from machine.annotations import Range
-from machine.corpora import ScriptureRef, UsfmStylesheet, UsfmToken, UsfmTokenizer, UsfmTokenType
+from machine.corpora import PlaceMarkersAlignmentInfo, PlaceMarkersUsfmUpdateBlockHandler, ScriptureRef
 from machine.tokenization import LatinWordTokenizer
 from machine.translation import WordAlignmentMatrix
 
@@ -19,366 +18,31 @@
 NON_NOTE_TYPE_EMBEDS = CHARACTER_TYPE_EMBEDS + PARAGRAPH_TYPE_EMBEDS
 
-
-class UsfmPreserver:
-    _src_sents: List[str]
-    _vrefs: List[ScriptureRef]
-
-    # (sent_idx, start idx in text_only_sent, is_paragraph_marker, tok (inc. \ and spaces))
-    _markers: List[Tuple[int, int, str]]
-    # (sent_idx, embed)
-    _char_embeds: List[Tuple[int, str]]
-    # (sent_idx, ref, embed contents)
-    _para_embeds: List[Tuple[int, ScriptureRef, str]]
-
-    def __init__(
-        self,
-        src_sents: List[str],
-        vrefs: List[ScriptureRef],
-        stylesheet: UsfmStylesheet,
-        include_paragraph_markers: bool = False,
-        include_style_markers: bool = False,
-        include_embeds: bool = False,
-    ):
-        # Remove sentences that are paragraph-type embeds
-        # NOTE: when only dealing with inserting back into the same USFM structure, i.e. translate_usfm's trg_project is None,
-        # paragraph-type embeds can be handled more simply with the updater's preserve_paragraph_styles argument,
-        # but because this approach is necessary when updating a project different from the source, we use it for both cases
-        src_sents, self._vrefs = self._remove_para_embeds(src_sents, vrefs, include_embeds)
-
-        usfm_tokenizer = UsfmTokenizer(stylesheet)
-        sentence_toks = []
-        for sent in src_sents:
-            sentence_toks.append(usfm_tokenizer.tokenize(sent))
-
-        # Take markers and character-type embeds out of sentences
-        self._src_sents = self._extract_markers(
-            sentence_toks, include_paragraph_markers, include_style_markers, include_embeds
-        )
-        self._src_tok_ranges = self._tokenize_sents(self._src_sents)
-
-    # Source sentences without USFM markers or embeds, to be used as input to an MT model
-    @property
-    def src_sents(self) -> List[str]:
-        return self._src_sents
-
-    @property
-    def vrefs(self) -> List[ScriptureRef]:
-        return self._vrefs
-
-    def _remove_para_embeds(
-        self, sents: List[str], vrefs: List[ScriptureRef], include_embeds: bool
-    ) -> List[ScriptureRef]:
-        para_embeds = []
-        for i, (sent, ref) in reversed(list(enumerate(zip(sents, vrefs)))):
-            if (ref.path[-1].name if len(ref.path) > 0 else "") in PARAGRAPH_TYPE_EMBEDS:
-                para_embeds.append((i, ref, sent if include_embeds else ""))
-                sents.pop(i)
-                vrefs.pop(i)
-
-        self._para_embeds = list(reversed(para_embeds))
-        return sents, vrefs
-
-    def _extract_markers(
-        self,
-        sentence_toks: List[List[UsfmToken]],
-        include_paragraph_markers: bool,
-        include_style_markers: bool,
-        include_embeds: bool,
-    ) -> List[str]:
-        markers = []
-        char_embeds = []
-        text_only_sents = ["" for _ in sentence_toks]
-        for i, toks in enumerate(sentence_toks):
-            embed_usfm = ""
-            curr_embed = None
-            for tok in toks:
-                if curr_embed is not None:
-                    embed_usfm += tok.to_usfm()
-                    if tok.type == UsfmTokenType.END and tok.marker[:-1] == curr_embed.marker:
-                        if include_embeds:
-                            char_embeds.append((i, embed_usfm))
-                        embed_usfm = ""
-                        curr_embed = None
-                elif tok.type == UsfmTokenType.NOTE or tok.marker in CHARACTER_TYPE_EMBEDS:
-                    embed_usfm += tok.to_usfm()
-                    curr_embed = tok
-                elif tok.type == UsfmTokenType.PARAGRAPH and include_paragraph_markers:
-                    markers.append((i, len(text_only_sents[i]), True, tok.to_usfm()))
-                elif tok.type in [UsfmTokenType.CHARACTER, UsfmTokenType.END] and include_style_markers:
-                    markers.append((i, len(text_only_sents[i]), False, tok.to_usfm()))
-                elif tok.type == UsfmTokenType.TEXT:
-                    text_only_sents[i] += tok.text
-
-        self._markers = markers
-        self._char_embeds = char_embeds
-        return text_only_sents
-
-    def construct_rows(self, translations: List[str]) -> List[Tuple[List[ScriptureRef], str]]:
-        # Map each token to a character range in the original strings
-        trg_tok_ranges = self._tokenize_sents(translations)
-
-        # Get index of the text token immediately following each marker and predict the corresponding token on the target side
-        adj_src_toks = []
-        for sent_idx, start_idx, _, _ in self._markers:
-            for i, tok_range in reversed(list(enumerate(self._src_tok_ranges[sent_idx]))):
-                if tok_range.start < start_idx:
-                    adj_src_toks.append(i + 1)
-                    break
-                if i == 0:
-                    adj_src_toks.append(i)
-        adj_trg_toks = self._predict_marker_locations(adj_src_toks, translations, trg_tok_ranges)
-
-        # Collect the markers to be inserted
-        to_insert = [[] for _ in trg_tok_ranges]
-        for i, ((sent_idx, _, is_para_marker, marker), adj_trg_tok) in enumerate(zip(self._markers, adj_trg_toks)):
-            trg_str_idx = (
-                trg_tok_ranges[sent_idx][adj_trg_tok].start
-                if adj_trg_tok < len(trg_tok_ranges[sent_idx])
-                else len(translations[sent_idx])
-            )
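+# Aligns each source/translation sentence pair with a statistical word aligner
+# (eflomal by default) and returns one WordAlignmentMatrix per pair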
+def get_alignment_matrices(
+    src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal"
+) -> List[WordAlignmentMatrix]:
+    with TemporaryDirectory() as td:
+        align_path = Path(td, "sym-align.txt")
+        write_corpus(Path(td, "src_align.txt"), src_sents)
+        write_corpus(Path(td, "trg_align.txt"), trg_sents)
+        compute_alignment_scores(Path(td, "src_align.txt"), Path(td, "trg_align.txt"), aligner, align_path)
+
+        return [to_word_alignment_matrix(line) for line in load_corpus(align_path)]
+
+
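+# Builds the update block handler that places paragraph and style markers in the
+# translated verse text based on the computed word alignments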
+def construct_place_markers_handler(
+    refs: List[ScriptureRef], source: List[str], translation: List[str], aligner: str = "eflomal"
+) -> PlaceMarkersUsfmUpdateBlockHandler:
+    align_info = []
+    tokenizer = LatinWordTokenizer()
+    alignments = get_alignment_matrices(source, translation, aligner)
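+    # Pair each verse's tokenized source and draft with its alignment matrix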
+    for ref, s, t, alignment in zip(refs, source, translation, alignments):
+        align_info.append(
+            PlaceMarkersAlignmentInfo(
+                refs=[str(ref)],
+                source_tokens=list(tokenizer.tokenize(s)),
+                translation_tokens=list(tokenizer.tokenize(t)),
+                alignment=alignment,
+            )
+        )
-
-            # Determine the order of the markers in the sentence to handle ambiguity for directly adjacent markers
-            insert_pos = 0
-            while insert_pos < len(to_insert[sent_idx]) and to_insert[sent_idx][insert_pos][0] <= trg_str_idx:
-                insert_pos += 1
-            to_insert[sent_idx].insert(insert_pos, (trg_str_idx, is_para_marker, marker))
-
-        # Construct rows for the USFM file
-        embed_idx = 0
-        para_embed_idx = 0
-        rows = []
-
-        # Add any paragraph-style embeds that come before the main sentences
-        while para_embed_idx < len(self._para_embeds) and self._para_embeds[para_embed_idx][0] == para_embed_idx:
-            rows.append(([self._para_embeds[para_embed_idx][1]], self._para_embeds[para_embed_idx][2]))
-            para_embed_idx += 1
-
-        for i, (ref, translation, inserts) in enumerate(zip(self._vrefs, translations, to_insert)):
-            # row_text = translation[: inserts[0][0]] if len(inserts) > 0 else translation
-            row_texts = [translation[: inserts[0][0]]] if len(inserts) > 0 else [translation]
-
-            for j, (insert_idx, is_para_marker, marker) in enumerate(inserts):
-                if is_para_marker:
-                    row_texts.append("")
-
-                # row_text += (
-                row_texts[-1] += (
-                    # ("\n" if is_para_marker else "")
-                    # + marker
-                    (marker if not is_para_marker else "")
-                    + (
-                        " "  # Extra space if inserting an end marker before a non-punctuation character
-                        if "*" in marker and insert_idx < len(translation) and translation[insert_idx].isalpha()
-                        else ""
-                    )
-                    + (
-                        translation[insert_idx : inserts[j + 1][0]]
-                        if j + 1 < len(inserts)
-                        else translation[insert_idx:]
-                    )
-                )
-                # Prevent spaces before end markers
-                # if j + 1 < len(inserts) and "*" in inserts[j + 1][2] and len(row_text) > 0 and row_text[-1] == " ":
-                #     row_text = row_text[:-1]
-                if (
-                    j + 1 < len(inserts)
-                    and "*" in inserts[j + 1][2]
-                    and len(row_texts[-1]) > 0
-                    and row_texts[-1][-1] == " "
-                ):
-                    row_texts[-1] = row_texts[-1][:-1]
-
-            # Append any transferred embeds that match the current ScriptureRef
-            while embed_idx < len(self._char_embeds) and self._char_embeds[embed_idx][0] == i:
-                # row_text += self._char_embeds[embed_idx][1]
-                row_texts[-1] += self._char_embeds[embed_idx][1]
-                embed_idx += 1
-
-            # rows.append(([ref], row_text))
-            for row_text in row_texts:
-                rows.append(([ref], row_text))
-
-            # (sent_idx, ref, embed contents)
-            # sent_idx == orig idx, in order
-            while (
-                para_embed_idx < len(self._para_embeds)
-                and self._para_embeds[para_embed_idx][0] == i + 1 + para_embed_idx
-            ):
-                rows.append(([self._para_embeds[para_embed_idx][1]], self._para_embeds[para_embed_idx][2]))
-                para_embed_idx += 1
-
-        # # Add transferred paragraph-type embeds
-        # for sent_idx, ref, sent in self._para_embeds:
-        #     rows.insert(sent_idx, ([ref], sent))
-
-        return rows
-
-    @abstractmethod
-    def _tokenize_sents(self, sents: List[str]) -> List[List[Range[int]]]: ...
-
-    @abstractmethod
-    def _get_alignment_matrices(self, trg_sents: List[str]) -> List[WordAlignmentMatrix]: ...
-
-    def _predict_marker_locations(
-        self, adj_src_toks: List[int], trg_sents: List[str], trg_tok_ranges: List[List[Range[int]]]
-    ) -> List[int]:
-        if len(adj_src_toks) == 0:
-            return []
-
-        alignment_matrices: List[WordAlignmentMatrix] = self._get_alignment_matrices(trg_sents)
-
-        # Gets the number of alignment pairs that "cross the line" between
-        # the src marker position and the potential trg marker position, (src_idx - .5) and (trg_idx - .5)
-        def num_align_crossings(sent_idx: int, src_idx: int, trg_idx: int) -> int:
-            crossings = 0
-            alignment = alignment_matrices[sent_idx]
-            for i in range(alignment.row_count):
-                for j in range(alignment.column_count):
-                    if alignment[i, j] and ((i < src_idx and j >= trg_idx) or (i >= src_idx and j < trg_idx)):
-                        crossings += 1
-            return crossings
-
-        adj_trg_toks = []
-        for (sent_idx, _, _, _), adj_src_tok in zip(self._markers, adj_src_toks):
-            # If the token on either side of a potential target location is punctuation,
-            # use it as the basis for deciding the target marker location
-            trg_hyp = -1
-            punct_hyps = [-1, 0]
-            for punct_hyp in punct_hyps:
-                src_hyp = adj_src_tok + punct_hyp
-                if src_hyp < 0 or src_hyp >= len(self._src_tok_ranges[sent_idx]):
-                    continue
-                # only accept aligned pairs where both the src and trg token are punct
-                src_hyp_range = self._src_tok_ranges[sent_idx][src_hyp]
-                if (
-                    src_hyp_range.length > 0
-                    and not any(self._src_sents[sent_idx][char_idx].isalpha() for char_idx in src_hyp_range)
-                    and src_hyp < alignment_matrices[sent_idx].row_count
-                ):
-                    aligned_trg_toks = list(alignment_matrices[sent_idx].get_row_aligned_indices(src_hyp))
-                    # if aligning to a token that precedes that marker,
-                    # the trg token predicted to be closest to the marker is the last token aligned to the src rather than the first
-                    for trg_tok in reversed(aligned_trg_toks) if punct_hyp < 0 else aligned_trg_toks:
-                        trg_tok_range = trg_tok_ranges[sent_idx][trg_tok]
-                        if not any(trg_sents[sent_idx][char_idx].isalpha() for char_idx in trg_tok_range):
-                            trg_hyp = trg_tok
-                            break
-                    if trg_hyp != -1:
-                        # since adj_trg_toks points to the token after the marker,
-                        # adjust the index when aligning to punctuation that precedes the token
-                        adj_trg_toks.append(trg_hyp - punct_hyp)
-                        break
-            if trg_hyp != -1:
-                continue
-
-            hyps = [0, 1, 2]
-            best_hyp = -1
-            best_num_crossings = 200**2  # mostly meaningless, a big number
-            checked = set()
-            for hyp in hyps:
-                src_hyp = adj_src_tok + hyp
-                if src_hyp in checked:
-                    continue
-                trg_hyp = -1
-                while trg_hyp == -1 and src_hyp >= 0 and src_hyp < alignment_matrices[sent_idx].row_count:
-                    checked.add(src_hyp)
-                    aligned_trg_toks = list(alignment_matrices[sent_idx].get_row_aligned_indices(src_hyp))
-                    if len(aligned_trg_toks) > 0:
-                        # if aligning with a source token that precedes the marker,
-                        # the target token predicted to be closest to the marker is the last aligned token rather than the first
-                        trg_hyp = aligned_trg_toks[0 if hyp >= 0 else -1]
-                    else:  # continue the search outwards
-                        src_hyp += -1 if hyp < 0 else 1
-                if trg_hyp != -1:
-                    num_crossings = num_align_crossings(sent_idx, src_hyp, trg_hyp)
-                    if num_crossings < best_num_crossings:
-                        best_hyp = trg_hyp
-                        best_num_crossings = num_crossings
-
-            # if no alignments found, insert at the end of the sentence
-            if best_hyp == -1:
-                adj_trg_toks.append(len(trg_tok_ranges[sent_idx]))
-                continue
-
-            adj_trg_toks.append(best_hyp)
-
-        return adj_trg_toks
-
-
-class StatisticalUsfmPreserver(UsfmPreserver):
-    def __init__(
-        self,
-        src_sentences,
-        vrefs,
-        stylesheet,
-        include_paragraph_markers,
-        include_style_markers,
-        include_embeds,
-        aligner="eflomal",
-    ):
-        self._aligner = aligner
-        super().__init__(
-            src_sentences, vrefs, stylesheet, include_paragraph_markers, include_style_markers, include_embeds
-        )
-
-    def _tokenize_sents(self, sents: List[str]) -> List[List[Range[int]]]:
-        tokenizer = LatinWordTokenizer()
-        return [list(tokenizer.tokenize_as_ranges(sent)) for sent in sents]
-
-    def _get_alignment_matrices(self, trg_sents: List[str]) -> List[WordAlignmentMatrix]:
-        with TemporaryDirectory() as td:
-            align_path = Path(td, "sym-align.txt")
-            write_corpus(Path(td, "src_align.txt"), self._src_sents)
-            write_corpus(Path(td, "trg_align.txt"), trg_sents)
-            compute_alignment_scores(Path(td, "src_align.txt"), Path(td, "trg_align.txt"), self._aligner, align_path)
-
-            return [to_word_alignment_matrix(line) for line in load_corpus(align_path)]
-
-
-"""
-Necessary changes to use AttentionUsfmPreserver:
-* Use machine.py's _TranslationPipeline (in translation.huggingface.hugging_face_nmt_engine)
-  to output attentions along with the translations
-* Build TranslationResults based on the outputs of the pipeline, using the logic from
-  HuggingFaceNmtEngine._try_translate_n_batch (also in translation.huggingface.hugging_face_nmt_engine)
-"""
-
-
-class AttentionUsfmPreserver(UsfmPreserver):
-    def __init__(self, src_sents, vrefs, stylesheet, include_paragraph_markers, include_style_markers, include_embeds):
-        raise NotImplementedError(
-            "AttentionUsfmPreserver is not a supported class. See class definition for more information about the work needed to use."
-        )
-
-    """
-    def construct_rows(self, translation_results: List[TranslationResult]) -> List[Tuple[List[ScriptureRef], str]]:
-        self._translation_results = translation_results
-        # TODO: do source token ranges need to be reconstructed for each draft?
-        self._src_tok_ranges = self._construct_tok_ranges([tr.source_tokens for tr in translation_results])
-
-        super().construct_rows(translation_results)
-
-    # NOTE: only used for target side
-    # NOTE: _tokenize_sents is called for the source side in UsfmPreserver.__init__, but it will get overwritten
-    # with the correct tokens when construct_rows is called
-    def _tokenize_sents(self, sents: List[str]) -> Tuple[List[List[Range[int]]], List[List[str]]]:
-        return self._construct_tok_ranges([tr.target_tokens for tr in self._translation_results])
-
-    # NOTE: the "▁" characters in this function are from the NllbTokenizer and are not the same character as the typical underscore
-    def _construct_tok_ranges(self, toks: List[str]) -> List[List[Range[int]]]:
-        tok_ranges = []
-        for sent_toks in toks:
-            sent_tok_ranges = [Range.create(0, len(sent_toks[0]) - 1)]
-            for tok in sent_toks[1:]:
-                sent_tok_ranges.append(
-                    Range.create(
-                        sent_tok_ranges[-1].end + (1 if tok[0] == "▁" else 0),
-                        sent_tok_ranges[-1].end + len(tok),
-                    )
-                )
-            tok_ranges.append(sent_tok_ranges)
-        return tok_ranges
-
-    def _get_alignment_matrices(self) -> List[WordAlignmentMatrix]:
-        return [tr.alignment for tr in self._translation_results]
-    """
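+    # align_info holds one entry per verse; the handler uses it to place markers during the update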
+    return PlaceMarkersUsfmUpdateBlockHandler(align_info)