Skip to content

Commit 3b6daf0

Browse files
authored
Merge pull request #889 from bact/dev
phayathaibert, khavee, parse: Code clean up
2 parents ff74b39 + de1a1bc commit 3b6daf0

File tree

13 files changed

+727 additions, -474 deletions (lines changed)

13 files changed

+727 additions, -474 deletions (lines changed)

pythainlp/augment/lm/phayathaibert.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,31 @@
1313

1414

1515
class ThaiTextAugmenter:
16-
def __init__(self,) -> None:
17-
from transformers import (AutoTokenizer,
18-
AutoModelForMaskedLM,
19-
pipeline,)
16+
def __init__(self) -> None:
17+
from transformers import (
18+
AutoTokenizer,
19+
AutoModelForMaskedLM,
20+
pipeline,
21+
)
22+
2023
self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
21-
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME)
22-
self.model = pipeline("fill-mask", tokenizer=self.tokenizer, model=self.model_for_masked_lm)
24+
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(
25+
_MODEL_NAME
26+
)
27+
self.model = pipeline(
28+
"fill-mask",
29+
tokenizer=self.tokenizer,
30+
model=self.model_for_masked_lm,
31+
)
2332
self.processor = ThaiTextProcessor()
2433

25-
def generate(self,
26-
sample_text: str,
27-
word_rank: int,
28-
max_length: int = 3,
29-
sample: bool = False
30-
) -> str:
34+
def generate(
35+
self,
36+
sample_text: str,
37+
word_rank: int,
38+
max_length: int = 3,
39+
sample: bool = False,
40+
) -> str:
3141
sample_txt = sample_text
3242
final_text = ""
3343

@@ -45,11 +55,9 @@ def generate(self,
4555

4656
return gen_txt
4757

48-
def augment(self,
49-
text: str,
50-
num_augs: int = 3,
51-
sample: bool = False
52-
) -> List[str]:
58+
def augment(
59+
self, text: str, num_augs: int = 3, sample: bool = False
60+
) -> List[str]:
5361
"""
5462
Text augmentation from PhayaThaiBERT
5563
@@ -84,11 +92,14 @@ def augment(self,
8492
if num_augs <= MAX_NUM_AUGS:
8593
for rank in range(num_augs):
8694
gen_text = self.generate(text, rank, sample=sample)
87-
processed_text = re.sub("<_>", " ", self.processor.preprocess(gen_text))
95+
processed_text = re.sub(
96+
"<_>", " ", self.processor.preprocess(gen_text)
97+
)
8898
augment_list.append(processed_text)
99+
else:
100+
raise ValueError(
101+
f"augmentation of more than {num_augs} is exceeded \
102+
the default limit: {MAX_NUM_AUGS}"
103+
)
89104

90-
return augment_list
91-
92-
raise ValueError(
93-
f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}"
94-
)
105+
return augment_list

pythainlp/augment/lm/wangchanberta.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# -*- coding: utf-8 -*-
22
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
33
# SPDX-License-Identifier: Apache-2.0
4+
45
from typing import List
6+
57
from transformers import (
68
CamembertTokenizer,
79
pipeline,
@@ -51,9 +53,9 @@ def generate(self, sentence: str, num_replace_tokens: int = 3):
5153

5254
def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:
5355
"""
54-
Text Augment from wangchanberta
56+
Text augmentation from WangchanBERTa
5557
56-
:param str sentence: thai sentence
58+
:param str sentence: Thai sentence
5759
:param int num_replace_tokens: number replace tokens
5860
5961
:return: list of text augment
@@ -64,7 +66,7 @@ def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:
6466
6567
from pythainlp.augment.lm import Thai2transformersAug
6668
67-
aug=Thai2transformersAug()
69+
aug = Thai2transformersAug()
6870
6971
aug.augment("ช้างมีทั้งหมด 50 ตัว บน")
7072
# output: ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้',

pythainlp/augment/wordnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from collections import OrderedDict
1313
import itertools
1414
from typing import List
15+
1516
from nltk.corpus import wordnet as wn
1617
from pythainlp.corpus import wordnet
1718
from pythainlp.tokenize import word_tokenize

pythainlp/khavee/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
33
# SPDX-License-Identifier: Apache-2.0
4+
45
__all__ = ["KhaveeVerifier"]
56

67
from pythainlp.khavee.core import KhaveeVerifier

0 commit comments

Comments
 (0)