
Commit 74d9b1c

ArthurZucker and gojiteji authored and committed
[LlamaTokenizerFast] nit update post_processor on the fly (huggingface#23855)
* Update the processor when changing add_eos and add_bos
* fixup
* update
* add a test
* fix failing tests
* fixup
1 parent 7386bc3 commit 74d9b1c

File tree

2 files changed: +77 −0 lines


src/transformers/models/llama/tokenization_llama_fast.py

Lines changed: 44 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@
 from shutil import copyfile
 from typing import Optional, Tuple
 
+from tokenizers import processors
+
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import is_sentencepiece_available, logging
 from ...utils.versions import require_version
@@ -84,6 +86,8 @@ def __init__(
         unk_token="<unk>",
         bos_token="<s>",
         eos_token="</s>",
+        add_bos_token=True,
+        add_eos_token=False,
         **kwargs,
     ):
         super().__init__(
@@ -95,10 +99,50 @@ def __init__(
             eos_token=eos_token,
             **kwargs,
         )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
 
         self.vocab_file = vocab_file
         self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
+    def update_post_processor(self):
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+
+        single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
+        pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not self.can_save_slow_tokenizer:
             raise ValueError(
```

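The `single` and `pair` templates above are built with boolean multiplication, so each BOS/EOS fragment is included only when the corresponding flag is set. A minimal standalone sketch (not part of the commit) of what they expand to under the defaults `add_bos_token=True`, `add_eos_token=False`:

```python
# Sketch only: mirrors the f-strings in update_post_processor with the default
# flags. Multiplying a string by a bool keeps it (True == 1) or drops it (False == 0).
bos, eos = "<s>", "</s>"
add_bos_token, add_eos_token = True, False

single = f"{(bos + ':0 ') * add_bos_token}$A:0{(' ' + eos + ':0') * add_eos_token}"
pair = f"{single}{(' ' + bos + ':1') * add_bos_token} $B:1{(' ' + eos + ':1') * add_eos_token}"

print(single)  # "<s>:0 $A:0"            -> BOS prepended to a single sequence
print(pair)    # "<s>:0 $A:0 <s>:1 $B:1" -> BOS also prepended to the second sequence
```

These template strings are what `processors.TemplateProcessing` consumes; the `:0`/`:1` suffixes assign token type ids to the first and second sequence.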
tests/models/llama/test_tokenization_llama.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -315,6 +315,39 @@ def integration_tests(self):
             },
         )
 
+    def test_fast_special_tokens(self):
+        slow_tokenizer = self.tokenizer
+        fast_tokenizer = self.rust_tokenizer
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243]
+
+        fast_tokenizer.add_eos_token = False
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243]
+
+        fast_tokenizer.add_eos_token = True
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243, 2]
+
+        slow_tokenizer.add_eos_token = True
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243, 2]
+
+        fast_tokenizer = LlamaTokenizerFast.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [319, 4559, 1243, 2]
+
+        slow_tokenizer = LlamaTokenizer.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [319, 4559, 1243, 2]
+
+        self.tokenizer.add_eos_token = False
+        self.rust_tokenizer.add_eos_token = False
+
     @slow
     def test_conversion(self):
         # This is excruciatingly slow since it has to recreate the entire merge
```
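For context, a hedged usage sketch of the behaviour this test exercises: flipping `add_eos_token` on an already-instantiated fast tokenizer now rebuilds the post-processor, so `encode` output changes immediately. The checkpoint name and the expected ids are taken from the test itself.

```python
from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
print(tok.encode("A sample test"))  # [1, 319, 4559, 1243]     -> BOS only (defaults)

tok.add_eos_token = True            # the new setter calls update_post_processor()
print(tok.encode("A sample test"))  # [1, 319, 4559, 1243, 2]  -> BOS + EOS
```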
