Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions pythainlp/augment/lm/phayathaibert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,31 @@


class ThaiTextAugmenter:
def __init__(self,) -> None:
from transformers import (AutoTokenizer,
AutoModelForMaskedLM,
pipeline,)
def __init__(self) -> None:
from transformers import (
AutoTokenizer,
AutoModelForMaskedLM,
pipeline,
)

self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME)
self.model = pipeline("fill-mask", tokenizer=self.tokenizer, model=self.model_for_masked_lm)
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(
_MODEL_NAME
)
self.model = pipeline(
"fill-mask",
tokenizer=self.tokenizer,
model=self.model_for_masked_lm,
)
self.processor = ThaiTextProcessor()

def generate(self,
sample_text: str,
word_rank: int,
max_length: int = 3,
sample: bool = False
) -> str:
def generate(
self,
sample_text: str,
word_rank: int,
max_length: int = 3,
sample: bool = False,
) -> str:
sample_txt = sample_text
final_text = ""

Expand All @@ -45,11 +55,9 @@ def generate(self,

return gen_txt

def augment(self,
text: str,
num_augs: int = 3,
sample: bool = False
) -> List[str]:
def augment(
self, text: str, num_augs: int = 3, sample: bool = False
) -> List[str]:
"""
Text augmentation from PhayaThaiBERT

Expand Down Expand Up @@ -84,11 +92,14 @@ def augment(self,
if num_augs <= MAX_NUM_AUGS:
for rank in range(num_augs):
gen_text = self.generate(text, rank, sample=sample)
processed_text = re.sub("<_>", " ", self.processor.preprocess(gen_text))
processed_text = re.sub(
"<_>", " ", self.processor.preprocess(gen_text)
)
augment_list.append(processed_text)
else:
raise ValueError(
f"augmentation of more than {num_augs} is exceeded \
the default limit: {MAX_NUM_AUGS}"
)

return augment_list

raise ValueError(
f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}"
)
return augment_list
8 changes: 5 additions & 3 deletions pythainlp/augment/lm/wangchanberta.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0

from typing import List

from transformers import (
CamembertTokenizer,
pipeline,
Expand Down Expand Up @@ -51,9 +53,9 @@ def generate(self, sentence: str, num_replace_tokens: int = 3):

def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:
"""
Text Augment from wangchanberta
Text augmentation from WangchanBERTa

:param str sentence: thai sentence
:param str sentence: Thai sentence
:param int num_replace_tokens: number replace tokens

:return: list of text augment
Expand All @@ -64,7 +66,7 @@ def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:

from pythainlp.augment.lm import Thai2transformersAug

aug=Thai2transformersAug()
aug = Thai2transformersAug()

aug.augment("ช้างมีทั้งหมด 50 ตัว บน")
# output: ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้',
Expand Down
1 change: 1 addition & 0 deletions pythainlp/augment/wordnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import OrderedDict
import itertools
from typing import List

from nltk.corpus import wordnet as wn
from pythainlp.corpus import wordnet
from pythainlp.tokenize import word_tokenize
Expand Down
1 change: 1 addition & 0 deletions pythainlp/khavee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0

__all__ = ["KhaveeVerifier"]

from pythainlp.khavee.core import KhaveeVerifier
Loading