|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import Iterable, List, TypedDict |
| 4 | + |
| 5 | +from ..translation.word_alignment_matrix import WordAlignmentMatrix |
| 6 | +from .usfm_token import UsfmToken, UsfmTokenType |
| 7 | +from .usfm_update_block import UsfmUpdateBlock |
| 8 | +from .usfm_update_block_element import UsfmUpdateBlockElement, UsfmUpdateBlockElementType |
| 9 | +from .usfm_update_block_handler import UsfmUpdateBlockHandler |
| 10 | + |
| 11 | + |
class PlaceMarkersAlignmentInfo(TypedDict):
    """Word-alignment data for one translated segment.

    ``PlaceMarkersUsfmUpdateBlockHandler`` looks an entry up by the first
    item of ``refs`` and uses the tokens and the alignment matrix to decide
    where the markers of a block belong in the translated text.
    """

    # References covered by the segment; only refs[0] is used as lookup key.
    refs: List[str]
    # Tokens of the source text, in order (rows of ``alignment``).
    source_tokens: List[str]
    # Tokens of the translated text, in order (columns of ``alignment``).
    translation_tokens: List[str]
    # alignment[i, j] is truthy when source token i is aligned to translation token j.
    alignment: WordAlignmentMatrix
| 17 | + |
| 18 | + |
class PlaceMarkersUsfmUpdateBlockHandler(UsfmUpdateBlockHandler):
    """Re-positions paragraph and style markers inside translated text.

    For every update block that has alignment information, the text that is
    kept (the translation) is rebuilt as a sequence of text elements with the
    block's paragraph and style markers inserted at the positions predicted
    from the word alignment between the removed (source) text and the
    translation.
    """

    def __init__(self, align_info: Iterable[PlaceMarkersAlignmentInfo]) -> None:
        """Index the alignment entries by their first reference.

        Args:
            align_info: Alignment entries. When several entries share the same
                first reference, the last one wins.

        Raises:
            IndexError: If an entry has an empty ``refs`` list.
        """
        self._align_info = {info["refs"][0]: info for info in align_info}

    def process_block(self, block: UsfmUpdateBlock) -> UsfmUpdateBlock:
        """Place the markers of ``block`` into its translated text.

        Args:
            block: The update block to process. Its first reference is used to
                look up the alignment information.

        Returns:
            The same ``block`` object. It is returned untouched when there is
            nothing to place (no elements, no alignment for the reference, an
            empty alignment matrix, or no live paragraph/style marker other
            than headers); otherwise its elements are replaced by the
            re-ordered elements followed by the ignored ones.

        Raises:
            ValueError: If a translation token cannot be found, in order, in
                the text of the block that is not marked for removal.
        """
        ref = str(block.refs[0])
        elements = list(block.elements)
        placeable_types = (UsfmUpdateBlockElementType.PARAGRAPH, UsfmUpdateBlockElementType.STYLE)

        # Nothing to do if there are no markers to place or no alignment to use
        if not elements or ref not in self._align_info:
            return block
        info = self._align_info[ref]
        alignment = info["alignment"]
        if alignment.row_count == 0 or alignment.column_count == 0:
            return block
        if not any(e.type in placeable_types and not e.marked_for_removal for e in elements):
            return block

        # Walk the block backwards and take out:
        #  * paragraph markers at the very end of the block, which stay there;
        #  * headers (paragraph elements with more than one token), which are
        #    re-inserted later in the same position relative to the other
        #    paragraph markers. Each header remembers how many plain paragraph
        #    markers follow it.
        end_elements = []
        header_elements = []
        at_block_end = True
        para_markers_left = 0
        for i in range(len(elements) - 1, -1, -1):
            element = elements[i]
            if element.type == UsfmUpdateBlockElementType.PARAGRAPH and not element.marked_for_removal:
                if len(element.tokens) > 1:
                    header_elements.insert(0, (para_markers_left, element))
                    elements.pop(i)
                else:
                    para_markers_left += 1
                    if at_block_end:
                        end_elements.insert(0, element)
                        elements.pop(i)
            elif not (
                element.type == UsfmUpdateBlockElementType.EMBED
                or (
                    element.type == UsfmUpdateBlockElementType.TEXT
                    and len(element.tokens[0].to_usfm().strip()) == 0
                )
            ):
                # Anything but embeds and blank text ends the trailing run of paragraph markers
                at_block_end = False

        src_toks = info["source_tokens"]
        trg_toks = info["translation_tokens"]
        src_tok_idx = 0

        trg_sent = ""
        to_place = []
        adj_src_toks = []
        ignored_elements = []

        # A leading OTHER element keeps its place at the start of the block.
        # The list can be empty here when every element was taken out above.
        placed_elements = []
        if elements and elements[0].type == UsfmUpdateBlockElementType.OTHER:
            placed_elements.append(elements.pop(0))

        # NOTE(review): an OTHER element that is not first and not marked for
        # removal is neither placed nor ignored below, so it is dropped from
        # the block - confirm this is intended.
        for element in elements:
            if element.type == UsfmUpdateBlockElementType.TEXT:
                if element.marked_for_removal:
                    # Removed text is the source text: advance the source token
                    # cursor past every token found, in order, in this element
                    text = element.tokens[0].to_usfm()
                    while src_tok_idx < len(src_toks):
                        tok = src_toks[src_tok_idx]
                        pos = text.find(tok)
                        if pos == -1:
                            break
                        text = text[pos + len(tok) :]
                        src_tok_idx += 1
                    # Handle tokens split across text elements
                    if len(text.strip()) > 0:
                        src_tok_idx += 1
                else:
                    trg_sent += element.tokens[0].to_usfm()

            if element.marked_for_removal or element.type == UsfmUpdateBlockElementType.EMBED:
                ignored_elements.append(element)
            elif element.type in placeable_types:
                # The marker sits right before source token `src_tok_idx`
                to_place.append(element)
                adj_src_toks.append(src_tok_idx)

        # Only headers were live markers: there is nothing to position, so
        # leave the block as it is instead of failing on an empty list below.
        if not to_place and not end_elements:
            return block

        trg_tok_starts = self._token_start_offsets(trg_sent, trg_toks)

        # Predict marker placements and get insertion order
        to_insert = []
        for element, adj_src_tok in zip(to_place, adj_src_toks):
            adj_trg_tok = self._predict_marker_location(alignment, adj_src_tok, src_toks, trg_toks)
            if adj_trg_tok < len(trg_tok_starts):
                trg_str_idx = trg_tok_starts[adj_trg_tok]
            else:
                trg_str_idx = len(trg_sent)
            to_insert.append((trg_str_idx, element))
        # Stable sort: markers predicted at the same offset keep their original order
        to_insert.sort(key=lambda x: x[0])
        to_insert += [(len(trg_sent), element) for element in end_elements]

        # Construct new text tokens to put between markers
        # and reincorporate headers and empty end-of-verse paragraph markers
        if to_insert[0][0] > 0:
            placed_elements.append(self._text_element(trg_sent[: to_insert[0][0]]))
        for j, (insert_idx, element) in enumerate(to_insert):
            if element.type == UsfmUpdateBlockElementType.PARAGRAPH:
                # Emit the headers that originally preceded this paragraph marker
                while len(header_elements) > 0 and header_elements[0][0] == para_markers_left:
                    placed_elements.append(header_elements.pop(0)[1])
                para_markers_left -= 1

            placed_elements.append(element)
            # Text up to the next marker (or to the end of the sentence);
            # nothing is emitted when two markers share the same offset
            next_idx = to_insert[j + 1][0] if j + 1 < len(to_insert) else len(trg_sent)
            if insert_idx < len(trg_sent) and insert_idx < next_idx:
                placed_elements.append(self._text_element(trg_sent[insert_idx:next_idx]))
        while len(header_elements) > 0:
            placed_elements.append(header_elements.pop(0)[1])

        block._elements = placed_elements + ignored_elements
        return block

    @staticmethod
    def _token_start_offsets(sentence: str, tokens: List[str]) -> List[int]:
        """Return the character offset in ``sentence`` at which each token starts.

        Tokens are located left to right; the search for a token starts right
        after the end of the previous one so that a short token is never
        matched inside the previous, longer token.

        Raises:
            ValueError: If a token does not occur at or after the end of the
                previous token.
        """
        starts = []
        search_from = 0
        for tok in tokens:
            start = sentence.index(tok, search_from)
            starts.append(start)
            search_from = start + len(tok)
        return starts

    @staticmethod
    def _text_element(text: str) -> UsfmUpdateBlockElement:
        """Wrap ``text`` in a TEXT element holding a single text token."""
        return UsfmUpdateBlockElement(
            UsfmUpdateBlockElementType.TEXT, [UsfmToken(UsfmTokenType.TEXT, text=text)]
        )

    @staticmethod
    def _is_punctuation(tok: str) -> bool:
        """Return True for a non-empty token without any alphabetic character.

        Note that this also accepts tokens made only of digits.
        """
        return len(tok) > 0 and not any(c.isalpha() for c in tok)

    def _predict_marker_location(
        self,
        alignment: WordAlignmentMatrix,
        adj_src_tok: int,
        src_toks: List[str],
        trg_toks: List[str],
    ) -> int:
        """Predict the translation token a marker should be placed before.

        Args:
            alignment: Matrix whose rows are source tokens and whose columns
                are translation tokens.
            adj_src_tok: Index of the source token right after the marker.
            src_toks: Source tokens.
            trg_toks: Translation tokens.

        Returns:
            The index of the translation token that should follow the marker,
            or ``len(trg_toks)`` when no aligned token is found (insert at the
            end of the sentence).
        """

        # Gets the number of alignment pairs that "cross the line" between
        # the src marker position and the potential trg marker position, (src_idx - .5) and (trg_idx - .5)
        def num_align_crossings(src_idx: int, trg_idx: int) -> int:
            return sum(
                1
                for i in range(alignment.row_count)
                for j in range(alignment.column_count)
                # A pair crosses when its two ends are on different sides of the line
                if alignment[i, j] and (i < src_idx) != (j < trg_idx)
            )

        # If the token on either side of a potential target location is punctuation,
        # use it as the basis for deciding the target marker location
        for punct_hyp in (-1, 0):
            src_hyp = adj_src_tok + punct_hyp
            if src_hyp < 0 or src_hyp >= len(src_toks) or src_hyp >= alignment.row_count:
                continue
            # Only accept aligned pairs where both the src and trg token are punctuation
            if not self._is_punctuation(src_toks[src_hyp]):
                continue
            aligned_trg_toks = list(alignment.get_row_aligned_indices(src_hyp))
            # If aligning to a token that precedes that marker,
            # the trg token predicted to be closest to the marker
            # is the last token aligned to the src rather than the first
            if punct_hyp < 0:
                aligned_trg_toks.reverse()
            for trg_idx in aligned_trg_toks:
                # Skip columns that have no token, mirroring the row check above
                if trg_idx < len(trg_toks) and self._is_punctuation(trg_toks[trg_idx]):
                    # Since the marker location is represented by the token after the marker,
                    # adjust the index when aligning to punctuation that precedes the token
                    return trg_idx + (1 if punct_hyp == -1 else 0)

        # Otherwise try the source tokens at and after the marker and keep the
        # aligned translation token that causes the fewest alignment crossings.
        # The offsets are currently all non-negative, so the `hyp < 0` branches
        # below are not exercised; they describe how a backward search behaves.
        best_hyp = -1
        best_num_crossings = None  # no candidate evaluated yet
        checked = set()
        for hyp in (0, 1, 2):
            src_hyp = adj_src_tok + hyp
            if src_hyp in checked:
                continue
            trg_hyp = -1
            while trg_hyp == -1 and 0 <= src_hyp < alignment.row_count:
                checked.add(src_hyp)
                aligned_trg_toks = list(alignment.get_row_aligned_indices(src_hyp))
                if len(aligned_trg_toks) > 0:
                    # If aligning with a source token that precedes the marker,
                    # the target token predicted to be closest to the marker
                    # is the last aligned token rather than the first
                    trg_hyp = aligned_trg_toks[-1 if hyp < 0 else 0]
                else:  # continue the search outwards
                    src_hyp += -1 if hyp < 0 else 1
            if trg_hyp != -1:
                num_crossings = num_align_crossings(adj_src_tok, trg_hyp)
                if best_num_crossings is None or num_crossings < best_num_crossings:
                    best_hyp = trg_hyp
                    best_num_crossings = num_crossings
                if num_crossings == 0:
                    break

        # If no alignments found, insert at the end of the sentence
        return best_hyp if best_hyp != -1 else len(trg_toks)
0 commit comments