From 811abae289e3e4112bcd5e16176bcb361770c1ed Mon Sep 17 00:00:00 2001
From: "arthur.zucker@gmail.com" <arthur.zucker@gmail.com>
Date: Tue, 30 May 2023 08:24:40 +0000
Subject: [PATCH 1/6] Update the processor when changing add_eos and add_bos

---
 .../models/llama/tokenization_llama_fast.py | 43 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py
index bb2737075ea2..b542363010fd 100644
--- a/src/transformers/models/llama/tokenization_llama_fast.py
+++ b/src/transformers/models/llama/tokenization_llama_fast.py
@@ -15,12 +15,12 @@
 import os
 from shutil import copyfile
 from typing import Optional, Tuple
+from tokenizers import processors
 
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import is_sentencepiece_available, logging
 from ...utils.versions import require_version
 
-
 require_version("tokenizers>=0.13.3")
 
 if is_sentencepiece_available():
@@ -84,6 +84,8 @@ def __init__(
         unk_token="<unk>",
         bos_token="<s>",
         eos_token="</s>",
+        add_bos_token=True,
+        add_eos_token=False,
         **kwargs,
     ):
         super().__init__(
@@ -95,10 +97,49 @@ def __init__(
             eos_token=eos_token,
             **kwargs,
         )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self._tokenizer.post_processor = self.update_post_processor()
+
         self.vocab_file = vocab_file
         self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
+    def update_post_processor(self):
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+
+        single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
+        pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not self.can_save_slow_tokenizer:
             raise ValueError(
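Note on PATCH 1/6: update_post_processor builds the tokenizers template strings by multiplying string fragments by the boolean flags (bool is an int subclass, so `s * True == s` and `s * False == ""`). The following standalone sketch, not part of the patch itself, shows what the templates evaluate to for the new defaults; the `<s>`/`</s>` tokens and ids 1/2 are the LLaMA values that the tests later in this series assert:

    from tokenizers import processors

    add_bos_token, add_eos_token = True, False  # the defaults added in this patch
    bos, eos = "<s>", "</s>"
    bos_token_id, eos_token_id = 1, 2

    # Same f-strings as the patch; ':0'/':1' are the type ids that
    # TemplateProcessing assigns to the first and second sequence of a pair.
    single = f"{(bos + ':0 ') * add_bos_token}$A:0{(' ' + eos + ':0') * add_eos_token}"
    pair = f"{single}{(' ' + bos + ':1') * add_bos_token} $B:1{(' ' + eos + ':1') * add_eos_token}"

    assert single == "<s>:0 $A:0"
    assert pair == "<s>:0 $A:0 <s>:1 $B:1"

    special_tokens = [(bos, bos_token_id)]  # eos is skipped while add_eos_token is False
    post_processor = processors.TemplateProcessing(
        single=single, pair=pair, special_tokens=special_tokens
    )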
From 6c03a9c2b623c4927e7ea8f4866fb365ffb91b3c Mon Sep 17 00:00:00 2001
From: "arthur.zucker@gmail.com" <arthur.zucker@gmail.com>
Date: Tue, 30 May 2023 08:29:02 +0000
Subject: [PATCH 2/6] fixup

---
 .../models/llama/tokenization_llama_fast.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py
index b542363010fd..0b943db59e41 100644
--- a/src/transformers/models/llama/tokenization_llama_fast.py
+++ b/src/transformers/models/llama/tokenization_llama_fast.py
@@ -15,12 +15,14 @@
 import os
 from shutil import copyfile
 from typing import Optional, Tuple
+
 from tokenizers import processors
 
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import is_sentencepiece_available, logging
 from ...utils.versions import require_version
 
+
 require_version("tokenizers>=0.13.3")
 
 if is_sentencepiece_available():
@@ -100,7 +102,6 @@ def __init__(
         self._add_bos_token = add_bos_token
         self._add_eos_token = add_eos_token
         self._tokenizer.post_processor = self.update_post_processor()
-
         self.vocab_file = vocab_file
         self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
@@ -120,26 +121,28 @@ def update_post_processor(self):
             special_tokens.append((bos, bos_token_id))
         if self.add_eos_token:
             special_tokens.append((eos, eos_token_id))
-        self._tokenizer.post_processor = processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)
-
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
     @property
     def add_eos_token(self):
         return self._add_eos_token
-
+
     @property
     def add_bos_token(self):
         return self._add_bos_token
-
+
     @add_eos_token.setter
     def add_eos_token(self, value):
         self._add_eos_token = value
         self.update_post_processor()
-
+
     @add_bos_token.setter
     def add_bos_token(self, value):
         self._add_bos_token = value
         self.update_post_processor()
-
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not self.can_save_slow_tokenizer:
             raise ValueError(

From 2b4382ef496d8096728fc1d1e93f7b7c017175c0 Mon Sep 17 00:00:00 2001
From: "arthur.zucker@gmail.com" <arthur.zucker@gmail.com>
Date: Tue, 30 May 2023 08:34:26 +0000
Subject: [PATCH 3/6] update

---
 src/transformers/models/llama/tokenization_llama_fast.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py
index 0b943db59e41..c3946d83b0e0 100644
--- a/src/transformers/models/llama/tokenization_llama_fast.py
+++ b/src/transformers/models/llama/tokenization_llama_fast.py
@@ -101,7 +101,7 @@ def __init__(
         )
         self._add_bos_token = add_bos_token
         self._add_eos_token = add_eos_token
-        self._tokenizer.post_processor = self.update_post_processor()
+        self.update_post_processor()
         self.vocab_file = vocab_file
         self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
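Note on PATCH 3/6: update_post_processor installs the processor as a side effect and returns None, so the line it replaces first set self._tokenizer.post_processor correctly (while evaluating the right-hand side) and then immediately overwrote it with None. A hypothetical minimal reproduction of that pitfall, with toy classes standing in for the real tokenizer objects:

    class _Backend:
        post_processor = None  # stands in for the Rust tokenizer's processor slot

    class _Wrapper:
        def __init__(self):
            self._tokenizer = _Backend()

        def update_post_processor(self):
            self._tokenizer.post_processor = "TemplateProcessing(...)"  # side effect
            # falls off the end of the function, so it returns None

    w = _Wrapper()
    w._tokenizer.post_processor = w.update_post_processor()  # the buggy pattern
    assert w._tokenizer.post_processor is None  # side effect clobbered by None

    w.update_post_processor()  # the fix: call for the side effect only
    assert w._tokenizer.post_processor == "TemplateProcessing(...)"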
From 7785eecdbf8a4866d2b1764afaceebc2bb53c572 Mon Sep 17 00:00:00 2001
From: "arthur.zucker@gmail.com" <arthur.zucker@gmail.com>
Date: Tue, 30 May 2023 09:01:21 +0000
Subject: [PATCH 4/6] add a test

---
 tests/models/llama/test_tokenization_llama.py | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index 6ce1bb44c03d..31e3ad9a7267 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -315,6 +315,36 @@ def integration_tests(self):
             },
         )
 
+    def test_fast_special_tokens(self):
+        slow_tokenizer = self.tokenizer
+        fast_tokenzier = self.rust_tokenizer
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243]
+
+        fast_tokenzier.add_eos_token = False
+        fast = fast_tokenzier.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243]
+
+        fast_tokenzier.add_eos_token = True
+        fast = fast_tokenzier.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243, 2]
+
+        slow_tokenizer.add_eos_token = True
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243, 2]
+
+        fast_tokenzier = LlamaTokenizerFast.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        fast = fast_tokenzier.encode("A sample test", add_special_tokens=True)
+        assert fast == [319, 4559, 1243, 2]
+
+        slow_tokenzier = LlamaTokenizer.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        slow = slow_tokenzier.encode("A sample test", add_special_tokens=True)
+        assert slow == [319, 4559, 1243, 2]
+
     @slow
     def test_conversion(self):
         # This is excruciatingly slow since it has to recreate the entire merge

From 4ff55bfac1c909d0c46a8802adb4c60b7e989b80 Mon Sep 17 00:00:00 2001
From: "arthur.zucker@gmail.com" <arthur.zucker@gmail.com>
Date: Tue, 30 May 2023 14:31:17 +0000
Subject: [PATCH 5/6] fix failing tests

---
 tests/models/llama/test_tokenization_llama.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index 31e3ad9a7267..43b2e0c3a2e7 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -317,26 +317,26 @@ def integration_tests(self):
 
     def test_fast_special_tokens(self):
         slow_tokenizer = self.tokenizer
-        fast_tokenzier = self.rust_tokenizer
+        fast_tokenizer = self.rust_tokenizer
         slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
         assert slow == [1, 319, 4559, 1243]
 
-        fast_tokenzier.add_eos_token = False
-        fast = fast_tokenzier.encode("A sample test", add_special_tokens=True)
+        fast_tokenizer.add_eos_token = False
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
         assert fast == [1, 319, 4559, 1243]
 
-        fast_tokenzier.add_eos_token = True
-        fast = fast_tokenzier.encode("A sample test", add_special_tokens=True)
+        fast_tokenizer.add_eos_token = True
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
         assert fast == [1, 319, 4559, 1243, 2]
 
         slow_tokenizer.add_eos_token = True
         slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
         assert slow == [1, 319, 4559, 1243, 2]
 
-        fast_tokenzier = LlamaTokenizerFast.from_pretrained(
+        fast_tokenizer = LlamaTokenizerFast.from_pretrained(
             "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
         )
-        fast = fast_tokenzier.encode("A sample test", add_special_tokens=True)
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
         assert fast == [319, 4559, 1243, 2]
 
         slow_tokenzier = LlamaTokenizer.from_pretrained(
@@ -344,6 +344,10 @@ def test_fast_special_tokens(self):
         )
         slow = slow_tokenzier.encode("A sample test", add_special_tokens=True)
         assert slow == [319, 4559, 1243, 2]
+
+        self.tokenizer.add_eos_token = False
+        self.rust_tokenizer.add_eos_token = False
+
 
     @slow
     def test_conversion(self):

From d33b8999eb36f765564dfed16677993c1c9fb411 Mon Sep 17 00:00:00 2001
From: "arthur.zucker@gmail.com" <arthur.zucker@gmail.com>
Date: Tue, 30 May 2023 14:31:32 +0000
Subject: [PATCH 6/6] fixup

---
 tests/models/llama/test_tokenization_llama.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index 43b2e0c3a2e7..3a1ec2be93bf 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -344,10 +344,9 @@ def test_fast_special_tokens(self):
         )
         slow = slow_tokenzier.encode("A sample test", add_special_tokens=True)
         assert slow == [319, 4559, 1243, 2]
-
+
         self.tokenizer.add_eos_token = False
         self.rust_tokenizer.add_eos_token = False
-
 
     @slow
     def test_conversion(self):
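With the series applied, the bos/eos flags can be flipped on a live fast tokenizer and the post processor follows. A usage sketch, not part of the patches, assuming the same public checkpoint the tests use and the token ids asserted above (bos = 1, eos = 2):

    from transformers import LlamaTokenizerFast

    tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
    assert tok.encode("A sample test") == [1, 319, 4559, 1243]  # bos only (defaults)

    tok.add_eos_token = True  # setter rebuilds the post processor
    assert tok.encode("A sample test") == [1, 319, 4559, 1243, 2]  # bos + eos

    tok.add_bos_token = False
    assert tok.encode("A sample test") == [319, 4559, 1243, 2]  # eos only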