Skip to content

Commit e27d410

Browse files
isaac091Ben King
authored andcommitted
Marker placement update block handler (#175)
* Marker placement update block handler * Refactor marker placement handler, small bug fixes * Extend and clean up tests, more code cleanup * Fix imports, use separate AlignmentInfo type * Adjust (PlaceMarkers)AlignmentInfo type * 'toks' --> 'tokens'
1 parent bb3e47b commit e27d410

File tree

7 files changed

+695
-14
lines changed

7 files changed

+695
-14
lines changed

machine/corpora/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .paratext_project_terms_parser_base import ParatextProjectTermsParserBase
2424
from .paratext_project_text_updater_base import ParatextProjectTextUpdaterBase
2525
from .paratext_text_corpus import ParatextTextCorpus
26+
from .place_markers_usfm_update_block_handler import PlaceMarkersAlignmentInfo, PlaceMarkersUsfmUpdateBlockHandler
2627
from .quotation_denormalization_scripture_update_block_handler import (
2728
QuotationDenormalizationScriptureUpdateBlockHandler,
2829
)
@@ -115,6 +116,8 @@
115116
"ParatextProjectTermsParserBase",
116117
"ParatextProjectTextUpdaterBase",
117118
"ParatextTextCorpus",
119+
"PlaceMarkersAlignmentInfo",
120+
"PlaceMarkersUsfmUpdateBlockHandler",
118121
"parse_usfm",
119122
"QuotationDenormalizationScriptureUpdateBlockHandler",
120123
"RtlReferenceOrder",
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
from __future__ import annotations
2+
3+
from typing import Iterable, List, TypedDict
4+
5+
from ..translation.word_alignment_matrix import WordAlignmentMatrix
6+
from .usfm_token import UsfmToken, UsfmTokenType
7+
from .usfm_update_block import UsfmUpdateBlock
8+
from .usfm_update_block_element import UsfmUpdateBlockElement, UsfmUpdateBlockElementType
9+
from .usfm_update_block_handler import UsfmUpdateBlockHandler
10+
11+
12+
class PlaceMarkersAlignmentInfo(TypedDict):
13+
refs: List[str]
14+
source_tokens: List[str]
15+
translation_tokens: List[str]
16+
alignment: WordAlignmentMatrix
17+
18+
19+
class PlaceMarkersUsfmUpdateBlockHandler(UsfmUpdateBlockHandler):
20+
21+
def __init__(self, align_info: Iterable[PlaceMarkersAlignmentInfo]) -> None:
22+
self._align_info = {info["refs"][0]: info for info in align_info}
23+
24+
def process_block(self, block: UsfmUpdateBlock) -> UsfmUpdateBlock:
25+
ref = str(block.refs[0])
26+
elements = list(block.elements)
27+
28+
# Nothing to do if there are no markers to place or no alignment to use
29+
if (
30+
len(elements) == 0
31+
or ref not in self._align_info.keys()
32+
or self._align_info[ref]["alignment"].row_count == 0
33+
or self._align_info[ref]["alignment"].column_count == 0
34+
or not any(
35+
(
36+
e.type in [UsfmUpdateBlockElementType.PARAGRAPH, UsfmUpdateBlockElementType.STYLE]
37+
and not e.marked_for_removal
38+
)
39+
for e in elements
40+
)
41+
):
42+
return block
43+
44+
# Paragraph markers at the end of the block should stay there
45+
# Section headers should be ignored but re-inserted in the same position relative to other paragraph markers
46+
end_elements = []
47+
eob_empty_paras = True
48+
header_elements = []
49+
para_markers_left = 0
50+
for i, element in reversed(list(enumerate(elements))):
51+
if element.type == UsfmUpdateBlockElementType.PARAGRAPH and not element.marked_for_removal:
52+
if len(element.tokens) > 1:
53+
header_elements.insert(0, (para_markers_left, element))
54+
elements.pop(i)
55+
else:
56+
para_markers_left += 1
57+
58+
if eob_empty_paras:
59+
end_elements.insert(0, element)
60+
elements.pop(i)
61+
elif not (
62+
element.type == UsfmUpdateBlockElementType.EMBED
63+
or (element.type == UsfmUpdateBlockElementType.TEXT and len(element.tokens[0].to_usfm().strip()) == 0)
64+
):
65+
eob_empty_paras = False
66+
67+
src_toks = self._align_info[ref]["source_tokens"]
68+
trg_toks = self._align_info[ref]["translation_tokens"]
69+
src_tok_idx = 0
70+
71+
src_sent = ""
72+
trg_sent = ""
73+
to_place = []
74+
adj_src_toks = []
75+
placed_elements = [elements.pop(0)] if elements[0].type == UsfmUpdateBlockElementType.OTHER else []
76+
ignored_elements = []
77+
for element in elements:
78+
if element.type == UsfmUpdateBlockElementType.TEXT:
79+
if element.marked_for_removal:
80+
text = element.tokens[0].to_usfm()
81+
src_sent += text
82+
83+
# Track seen tokens
84+
while src_tok_idx < len(src_toks) and src_toks[src_tok_idx] in text:
85+
text = text[text.index(src_toks[src_tok_idx]) + len(src_toks[src_tok_idx]) :]
86+
src_tok_idx += 1
87+
# Handle tokens split across text elements
88+
if len(text.strip()) > 0:
89+
src_tok_idx += 1
90+
else:
91+
trg_sent += element.tokens[0].to_usfm()
92+
93+
if element.marked_for_removal or element.type == UsfmUpdateBlockElementType.EMBED:
94+
ignored_elements.append(element)
95+
elif element.type in [UsfmUpdateBlockElementType.PARAGRAPH, UsfmUpdateBlockElementType.STYLE]:
96+
to_place.append(element)
97+
adj_src_toks.append(src_tok_idx)
98+
99+
trg_tok_starts = []
100+
for tok in trg_toks:
101+
trg_tok_starts.append(trg_sent.index(tok, trg_tok_starts[-1] + 1 if len(trg_tok_starts) > 0 else 0))
102+
103+
# Predict marker placements and get insertion order
104+
to_insert = []
105+
for element, adj_src_tok in zip(to_place, adj_src_toks):
106+
adj_trg_tok = self._predict_marker_location(
107+
self._align_info[ref]["alignment"], adj_src_tok, src_toks, trg_toks
108+
)
109+
trg_str_idx = trg_tok_starts[adj_trg_tok] if adj_trg_tok < len(trg_tok_starts) else len(trg_sent)
110+
111+
to_insert.append((trg_str_idx, element))
112+
to_insert.sort(key=lambda x: x[0])
113+
to_insert += [(len(trg_sent), element) for element in end_elements]
114+
115+
# Construct new text tokens to put between markers
116+
# and reincorporate headers and empty end-of-verse paragraph markers
117+
if to_insert[0][0] > 0:
118+
placed_elements.append(
119+
UsfmUpdateBlockElement(
120+
UsfmUpdateBlockElementType.TEXT, [UsfmToken(UsfmTokenType.TEXT, text=trg_sent[: to_insert[0][0]])]
121+
)
122+
)
123+
for j, (insert_idx, element) in enumerate(to_insert):
124+
if element.type == UsfmUpdateBlockElementType.PARAGRAPH:
125+
while len(header_elements) > 0 and header_elements[0][0] == para_markers_left:
126+
placed_elements.append(header_elements.pop(0)[1])
127+
para_markers_left -= 1
128+
129+
placed_elements.append(element)
130+
if insert_idx < len(trg_sent) and (j + 1 == len(to_insert) or insert_idx < to_insert[j + 1][0]):
131+
if j + 1 < len(to_insert):
132+
text_token = UsfmToken(UsfmTokenType.TEXT, text=(trg_sent[insert_idx : to_insert[j + 1][0]]))
133+
else:
134+
text_token = UsfmToken(UsfmTokenType.TEXT, text=(trg_sent[insert_idx:]))
135+
placed_elements.append(UsfmUpdateBlockElement(UsfmUpdateBlockElementType.TEXT, [text_token]))
136+
while len(header_elements) > 0:
137+
placed_elements.append(header_elements.pop(0)[1])
138+
139+
block._elements = placed_elements + ignored_elements
140+
return block
141+
142+
def _predict_marker_location(
143+
self,
144+
alignment: WordAlignmentMatrix,
145+
adj_src_tok: int,
146+
src_toks: List[str],
147+
trg_toks: List[str],
148+
) -> int:
149+
# Gets the number of alignment pairs that "cross the line" between
150+
# the src marker position and the potential trg marker position, (src_idx - .5) and (trg_idx - .5)
151+
def num_align_crossings(src_idx: int, trg_idx: int) -> int:
152+
crossings = 0
153+
for i in range(alignment.row_count):
154+
for j in range(alignment.column_count):
155+
if alignment[i, j] and ((i < src_idx and j >= trg_idx) or (i >= src_idx and j < trg_idx)):
156+
crossings += 1
157+
return crossings
158+
159+
# If the token on either side of a potential target location is punctuation,
160+
# use it as the basis for deciding the target marker location
161+
trg_hyp = -1
162+
punct_hyps = [-1, 0]
163+
for punct_hyp in punct_hyps:
164+
src_hyp = adj_src_tok + punct_hyp
165+
if src_hyp < 0 or src_hyp >= len(src_toks):
166+
continue
167+
# Only accept aligned pairs where both the src and trg token are punctuation
168+
hyp_tok = src_toks[src_hyp]
169+
if len(hyp_tok) > 0 and not any(c.isalpha() for c in hyp_tok) and src_hyp < alignment.row_count:
170+
aligned_trg_toks = list(alignment.get_row_aligned_indices(src_hyp))
171+
# If aligning to a token that precedes that marker,
172+
# the trg token predicted to be closest to the marker
173+
# is the last token aligned to the src rather than the first
174+
for trg_idx in reversed(aligned_trg_toks) if punct_hyp < 0 else aligned_trg_toks:
175+
trg_tok = trg_toks[trg_idx]
176+
if len(trg_tok) > 0 and not any(c.isalpha() for c in trg_tok):
177+
trg_hyp = trg_idx
178+
break
179+
if trg_hyp != -1:
180+
# Since the marker location is represented by the token after the marker,
181+
# adjust the index when aligning to punctuation that precedes the token
182+
return trg_hyp + (1 if punct_hyp == -1 else 0)
183+
184+
hyps = [0, 1, 2]
185+
best_hyp = -1
186+
best_num_crossings = 200**2 # mostly meaningless, a big number
187+
checked = set()
188+
for hyp in hyps:
189+
src_hyp = adj_src_tok + hyp
190+
if src_hyp in checked:
191+
continue
192+
trg_hyp = -1
193+
while trg_hyp == -1 and src_hyp >= 0 and src_hyp < alignment.row_count:
194+
checked.add(src_hyp)
195+
aligned_trg_toks = list(alignment.get_row_aligned_indices(src_hyp))
196+
if len(aligned_trg_toks) > 0:
197+
# If aligning with a source token that precedes the marker,
198+
# the target token predicted to be closest to the marker is the last aligned token rather than the first
199+
trg_hyp = aligned_trg_toks[-1 if hyp < 0 else 0]
200+
else: # continue the search outwards
201+
src_hyp += -1 if hyp < 0 else 1
202+
if trg_hyp != -1:
203+
num_crossings = num_align_crossings(adj_src_tok, trg_hyp)
204+
if num_crossings < best_num_crossings:
205+
best_hyp = trg_hyp
206+
best_num_crossings = num_crossings
207+
if num_crossings == 0:
208+
break
209+
210+
# If no alignments found, insert at the end of the sentence
211+
return best_hyp if best_hyp != -1 else len(trg_toks)

machine/jobs/nmt_engine_build_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ def _align(
157157
check_canceled()
158158

159159
for i in range(len(pretranslations)):
160-
pretranslations[i]["source_toks"] = list(src_tokenized[i])
161-
pretranslations[i]["translation_toks"] = list(trg_tokenized[i])
160+
pretranslations[i]["source_tokens"] = list(src_tokenized[i])
161+
pretranslations[i]["translation_tokens"] = list(trg_tokenized[i])
162162
pretranslations[i]["alignment"] = alignments[i]
163163

164164
return pretranslations

machine/jobs/translation_file_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class PretranslationInfo(TypedDict):
1616
textId: str # noqa: N815
1717
refs: List[str]
1818
translation: str
19-
source_toks: List[str]
20-
translation_toks: List[str]
19+
source_tokens: List[str]
20+
translation_tokens: List[str]
2121
alignment: str
2222

2323

@@ -65,8 +65,8 @@ def generator() -> Generator[PretranslationInfo, None, None]:
6565
textId=pi["textId"],
6666
refs=list(pi["refs"]),
6767
translation=pi["translation"],
68-
source_toks=list(pi["source_toks"]),
69-
translation_toks=list(pi["translation_toks"]),
68+
source_tokens=list(pi["source_tokens"]),
69+
translation_tokens=list(pi["translation_tokens"]),
7070
alignment=pi["alignment"],
7171
)
7272

0 commit comments

Comments
 (0)