13
13
14
14
15
15
class ThaiTextAugmenter :
16
- def __init__ (self ,) -> None :
17
- from transformers import (AutoTokenizer ,
18
- AutoModelForMaskedLM ,
19
- pipeline ,)
16
+ def __init__ (self ) -> None :
17
+ from transformers import (
18
+ AutoTokenizer ,
19
+ AutoModelForMaskedLM ,
20
+ pipeline ,
21
+ )
22
+
20
23
self .tokenizer = AutoTokenizer .from_pretrained (_MODEL_NAME )
21
- self .model_for_masked_lm = AutoModelForMaskedLM .from_pretrained (_MODEL_NAME )
22
- self .model = pipeline ("fill-mask" , tokenizer = self .tokenizer , model = self .model_for_masked_lm )
24
+ self .model_for_masked_lm = AutoModelForMaskedLM .from_pretrained (
25
+ _MODEL_NAME
26
+ )
27
+ self .model = pipeline (
28
+ "fill-mask" ,
29
+ tokenizer = self .tokenizer ,
30
+ model = self .model_for_masked_lm ,
31
+ )
23
32
self .processor = ThaiTextProcessor ()
24
33
25
- def generate (self ,
26
- sample_text : str ,
27
- word_rank : int ,
28
- max_length : int = 3 ,
29
- sample : bool = False
30
- ) -> str :
34
+ def generate (
35
+ self ,
36
+ sample_text : str ,
37
+ word_rank : int ,
38
+ max_length : int = 3 ,
39
+ sample : bool = False ,
40
+ ) -> str :
31
41
sample_txt = sample_text
32
42
final_text = ""
33
43
@@ -45,11 +55,9 @@ def generate(self,
45
55
46
56
return gen_txt
47
57
48
- def augment (self ,
49
- text : str ,
50
- num_augs : int = 3 ,
51
- sample : bool = False
52
- ) -> List [str ]:
58
+ def augment (
59
+ self , text : str , num_augs : int = 3 , sample : bool = False
60
+ ) -> List [str ]:
53
61
"""
54
62
Text augmentation from PhayaThaiBERT
55
63
@@ -84,11 +92,14 @@ def augment(self,
84
92
if num_augs <= MAX_NUM_AUGS :
85
93
for rank in range (num_augs ):
86
94
gen_text = self .generate (text , rank , sample = sample )
87
- processed_text = re .sub ("<_>" , " " , self .processor .preprocess (gen_text ))
95
+ processed_text = re .sub (
96
+ "<_>" , " " , self .processor .preprocess (gen_text )
97
+ )
88
98
augment_list .append (processed_text )
99
+ else :
100
+ raise ValueError (
101
+ f"augmentation of more than { num_augs } is exceeded "
102
+ f"the default limit: { MAX_NUM_AUGS } "
103
+ )
89
104
90
- return augment_list
91
-
92
- raise ValueError (
93
- f"augmentation of more than { num_augs } is exceeded the default limit: { MAX_NUM_AUGS } "
94
- )
105
+ return augment_list
0 commit comments