7 changes: 7 additions & 0 deletions pythainlp/spell/core.py
@@ -97,6 +97,7 @@ def correct(word: str, engine: str = "pn") -> str:
        * *pn* - Peter Norvig's algorithm [#norvig_spellchecker]_ (default)
        * *phunspell* - A spell checker utilizing spylls, a port of Hunspell.
        * *symspellpy* - symspellpy is a Python port of SymSpell v6.5.
        * *wanchanberta_thai_grammarly* - WanchanBERTa Thai Grammarly
    :return: the corrected word
    :rtype: str

@@ -128,6 +129,11 @@ def correct(word: str, engine: str = "pn") -> str:
        from pythainlp.spell.symspellpy import correct as SPELL_CHECKER

        text_correct = SPELL_CHECKER(word)
    elif engine == "wanchanberta_thai_grammarly":
        from pythainlp.spell.wanchanberta_thai_grammarly import correct as SPELL_CHECKER

        text_correct = SPELL_CHECKER(word)
    else:
        text_correct = DEFAULT_SPELL_CHECKER.correct(word)

@@ -181,6 +187,7 @@ def correct_sent(list_words: List[str], engine: str = "pn") -> List[str]:
        * *pn* - Peter Norvig's algorithm [#norvig_spellchecker]_ (default)
        * *phunspell* - A spell checker utilizing spylls, a port of Hunspell.
        * *symspellpy* - symspellpy is a Python port of SymSpell v6.5.
        * *wanchanberta_thai_grammarly* - WanchanBERTa Thai Grammarly
    :return: the corrected list of words
    :rtype: List[str]

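For reference, a minimal usage sketch of the new engine through the public API (a sketch, assuming the Hugging Face checkpoints can be downloaded on first use):

    from pythainlp.spell import correct, correct_sent

    # "ทดสอง" is the misspelled input used in the tests below; the expected
    # output is a corrected Thai string such as "ทดสอบ".
    print(correct("ทดสอง", engine="wanchanberta_thai_grammarly"))
    print(correct_sent(["ทดสอง"], engine="wanchanberta_thai_grammarly"))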
120 changes: 120 additions & 0 deletions pythainlp/spell/wanchanberta_thai_grammarly.py
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2016-2023 PyThaiNLP Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Two-stage Thai Misspelling Correction based on Pre-trained Language Models

:See Also:
* Paper: \
https://ieeexplore.ieee.org/abstract/document/10202006
* GitHub: \
https://github.com/bookpanda/Two-stage-Thai-Misspelling-Correction-Based-on-Pre-trained-Language-Models
"""
import torch
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertForTokenClassification,
)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
tokenizer = AutoTokenizer.from_pretrained(
    "airesearch/wangchanberta-base-att-spm-uncased"
)

class BertModel(torch.nn.Module):
    """Wrapper around the fine-tuned token-classification checkpoint."""

    def __init__(self):
        super().__init__()
        self.bert = BertForTokenClassification.from_pretrained(
            "bookpanda/wangchanberta-base-att-spm-uncased-tagging"
        )

    def forward(self, input_id, mask, label):
        output = self.bert(
            input_ids=input_id,
            attention_mask=mask,
            labels=label,
            return_dict=False,
        )
        return output


tagging_model = BertModel()
if use_cuda:
    tagging_model = tagging_model.to(device=device)

# Tag semantics: 'f' = token is fine, 'i' = token is incorrect.
ids_to_labels = {0: "f", 1: "i"}

def align_word_ids(texts):
    """Mark which positions carry a real token (label 2) as opposed to
    special tokens or padding (label -100, ignored downstream)."""
    tokenized_inputs = tokenizer(
        texts, padding="max_length", max_length=512, truncation=True
    )
    word_ids = tokenized_inputs.word_ids()
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            # Special tokens and padding get the ignore index.
            label_ids.append(-100)
        else:
            label_ids.append(2)
    return label_ids

def evaluate_one_text(model, sentence):
    """Tag each subword of ``sentence`` as 'f' (fine) or 'i' (incorrect)."""
    text = tokenizer(
        sentence,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    mask = text["attention_mask"][0].unsqueeze(0).to(device)
    input_id = text["input_ids"][0].unsqueeze(0).to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)

    logits = model(input_id, mask, None)
    # Keep only logits at positions that correspond to real tokens.
    logits_clean = logits[0][label_ids != -100]

    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels[i] for i in predictions]
    return prediction_label
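# For intuition (illustrative values, not a stored fixture):
#   evaluate_one_text(tagging_model, "ทดสอง") -> ['f', 'i']
# i.e. one tag per real subword; every 'i' position is masked and
# re-predicted by the MLM stage in correct() below.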


mlm_model = AutoModelForMaskedLM.from_pretrained(
    "bookpanda/wangchanberta-base-att-spm-uncased-masking"
)
if use_cuda:
    mlm_model = mlm_model.to(device=device)

def correct(text: str) -> str:
    """Correct misspellings in ``text`` with the two-stage pipeline."""
    ans = []
    # Stage 1: tag each subword as fine ('f') or incorrect ('i').
    i_f = evaluate_one_text(tagging_model, text)
    a = tokenizer(text)
    for j in range(len(i_f)):
        if i_f[j] == "i":
            # Stage 2: mask the incorrect subword (index j + 1 skips the
            # leading <s>) and let the masked LM predict a replacement.
            ph = a["input_ids"][j + 1]
            a["input_ids"][j + 1] = tokenizer.mask_token_id
            b = {
                "input_ids": torch.Tensor([a["input_ids"]])
                .type(torch.int64)
                .to(device),
                "attention_mask": torch.Tensor([a["attention_mask"]])
                .type(torch.int64)
                .to(device),
            }
            token_logits = mlm_model(**b).logits
            mask_token_index = torch.where(
                b["input_ids"] == tokenizer.mask_token_id
            )[1]
            mask_token_logits = token_logits[0, mask_token_index, :]
            top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
            ans.append((j, top_5_tokens[0]))
            # Restore the original token before masking the next position.
            a["input_ids"][j + 1] = ph
    # Apply all top-1 replacements at once.
    for x, y in ans:
        a["input_ids"][x + 1] = y
    final_output = tokenizer.convert_ids_to_tokens(a["input_ids"])
    # Strip special tokens and SentencePiece artifacts.
    if "<s>" in final_output:
        final_output.remove("<s>")
    if "</s>" in final_output:
        final_output.remove("</s>")
    if "" in final_output:
        final_output.remove("")
    if final_output[0] == "▁":
        final_output.pop(0)
    final_output = "".join(final_output)
    final_output = final_output.replace("▁", " ")
    return final_output
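To make the stage-2 loop in correct() concrete, here is a self-contained sketch of a single mask-and-fill step against the same checkpoints (the input string and variable names are illustrative):

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "airesearch/wangchanberta-base-att-spm-uncased"
    )
    mlm = AutoModelForMaskedLM.from_pretrained(
        "bookpanda/wangchanberta-base-att-spm-uncased-masking"
    )

    # Mask one subword and ask the MLM for its top-1 replacement.
    enc = tok("ทด" + tok.mask_token, return_tensors="pt")
    logits = mlm(**enc).logits
    mask_pos = (enc["input_ids"] == tok.mask_token_id).nonzero(as_tuple=True)[1]
    best_id = logits[0, mask_pos].argmax(dim=-1)
    print(tok.convert_ids_to_tokens(best_id.tolist()))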
7 changes: 7 additions & 0 deletions tests/test_spell.py
@@ -69,6 +69,10 @@ def test_word_correct(self):
        self.assertIsInstance(result, str)
        self.assertNotEqual(result, "")

        result = correct("ทดสอง", engine="wanchanberta_thai_grammarly")
        self.assertIsInstance(result, str)
        self.assertNotEqual(result, "")

    def test_norvig_spell_checker(self):
        checker = NorvigSpellChecker(dict_filter=None)
        self.assertTrue(len(checker.dictionary()) > 0)
@@ -132,6 +136,9 @@ def test_correct_sent(self):
        self.assertIsNotNone(
            correct_sent(self.spell_sent, engine="symspellpy")
        )
        self.assertIsNotNone(
            correct_sent(self.spell_sent, engine="wanchanberta_thai_grammarly")
        )
        self.assertIsNotNone(
            symspellpy.correct_sent(self.spell_sent)
        )