diff --git a/pythainlp/spell/core.py b/pythainlp/spell/core.py index b93059c5c..a054506bc 100644 --- a/pythainlp/spell/core.py +++ b/pythainlp/spell/core.py @@ -97,6 +97,7 @@ def correct(word: str, engine: str = "pn") -> str: * *pn* - Peter Norvig's algorithm [#norvig_spellchecker]_ (default) * *phunspell* - A spell checker utilizing spylls a port of Hunspell. * *symspellpy* - symspellpy is a Python port of SymSpell v6.5. + * *wanchanberta_thai_grammarly* - WanchanBERTa Thai Grammarly :return: the corrected word :rtype: str @@ -128,6 +129,11 @@ def correct(word: str, engine: str = "pn") -> str: from pythainlp.spell.symspellpy import correct as SPELL_CHECKER text_correct = SPELL_CHECKER(word) + elif engine == "wanchanberta_thai_grammarly": + from pythainlp.spell.wanchanberta_thai_grammarly import correct as SPELL_CHECKER + + text_correct = SPELL_CHECKER(word) + else: text_correct = DEFAULT_SPELL_CHECKER.correct(word) @@ -181,6 +187,7 @@ def correct_sent(list_words: List[str], engine: str = "pn") -> List[str]: * *pn* - Peter Norvig's algorithm [#norvig_spellchecker]_ (default) * *phunspell* - A spell checker utilizing spylls a port of Hunspell. * *symspellpy* - symspellpy is a Python port of SymSpell v6.5. + * *wanchanberta_thai_grammarly* - WanchanBERTa Thai Grammarly :return: the corrected list sentences of word :rtype: List[str] diff --git a/pythainlp/spell/wanchanberta_thai_grammarly.py b/pythainlp/spell/wanchanberta_thai_grammarly.py new file mode 100644 index 000000000..9004d8838 --- /dev/null +++ b/pythainlp/spell/wanchanberta_thai_grammarly.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2016-2023 PyThaiNLP Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Two-stage Thai Misspelling Correction based on Pre-trained Language Models + +:See Also: + * Paper: \ + https://ieeexplore.ieee.org/abstract/document/10202006 + * GitHub: \ + https://github.com/bookpanda/Two-stage-Thai-Misspelling-Correction-Based-on-Pre-trained-Language-Models +""" +from transformers import AutoModelForMaskedLM +from transformers import AutoTokenizer, BertForTokenClassification +import torch + +use_cuda = torch.cuda.is_available() +device = torch.device("cuda" if use_cuda else "cpu") +tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased") + +class BertModel(torch.nn.Module): + def __init__(self): + super(BertModel, self).__init__() + self.bert = BertForTokenClassification.from_pretrained('bookpanda/wangchanberta-base-att-spm-uncased-tagging') + + def forward(self, input_id, mask, label): + output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False) + return output + +tagging_model = BertModel() +if use_cuda: + tagging_model = tagging_model.to(device=device) +ids_to_labels = {0: 'f', 1: 'i'} + +def align_word_ids(texts): + tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True) + c = tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids) + word_ids = tokenized_inputs.word_ids() + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + + if word_idx is None: + label_ids.append(-100) + else: + try: + label_ids.append(2) + except: + label_ids.append(-100) + + previous_word_idx = word_idx + return label_ids + +def evaluate_one_text(model, sentence): + text = tokenizer(sentence, padding='max_length', max_length = 512, truncation=True, return_tensors="pt") + mask = text['attention_mask'][0].unsqueeze(0).to(device) + input_id = text['input_ids'][0].unsqueeze(0).to(device) + label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device) + # print(f"input_ids: {input_id}") + # print(f"attnetion_mask: {mask}") + # print(f"label_ids: {label_ids}") + + logits = tagging_model(input_id, mask, None) + logits_clean = logits[0][label_ids != -100] + # print(f"logits_clean: {logits_clean}") + + predictions = logits_clean.argmax(dim=1).tolist() + prediction_label = [ids_to_labels[i] for i in predictions] + return prediction_label + + +mlm_model = AutoModelForMaskedLM.from_pretrained("bookpanda/wangchanberta-base-att-spm-uncased-masking") +if use_cuda: + mlm_model = mlm_model.to(device=device) + +def correct(text): + ans = [] + i_f = evaluate_one_text(tagging_model, text) + a = tokenizer(text) + b = a['input_ids'] + c = tokenizer.convert_ids_to_tokens(b) + i_f_len = len(i_f) + for j in range(i_f_len): + if i_f[j] == 'i': + ph = a['input_ids'][j+1] + a['input_ids'][j+1] = 25004 + b = {'input_ids': torch.Tensor([a['input_ids']]).type(torch.int64).to(device), 'attention_mask': torch.Tensor([a['attention_mask']]).type(torch.int64).to(device)} + token_logits = mlm_model(**b).logits + mask_token_index = torch.where(b["input_ids"] == tokenizer.mask_token_id)[1] + mask_token_logits = token_logits[0, mask_token_index, :] + top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist() + ans.append((j, top_5_tokens[0])) + text = ''.join(tokenizer.convert_ids_to_tokens(a['input_ids'])) + a['input_ids'][j+1] = ph + for x,y in ans: + a['input_ids'][x+1] = y + final_output = tokenizer.convert_ids_to_tokens(a['input_ids']) + if "" in final_output: + final_output.remove("") + if "" in final_output: + final_output.remove("") + if "" in final_output: + final_output.remove("") + if final_output[0] == '▁': + final_output.pop(0) + final_output = ''.join(final_output) + final_output = final_output.replace("▁", " ") + final_output = final_output.replace("", "") + return final_output \ No newline at end of file diff --git a/tests/test_spell.py b/tests/test_spell.py index bb273709a..55043930b 100644 --- a/tests/test_spell.py +++ b/tests/test_spell.py @@ -69,6 +69,10 @@ def test_word_correct(self): self.assertIsInstance(result, str) self.assertNotEqual(result, "") + result = correct("ทดสอง", engine="wanchanberta_thai_grammarly") + self.assertIsInstance(result, str) + self.assertNotEqual(result, "") + def test_norvig_spell_checker(self): checker = NorvigSpellChecker(dict_filter=None) self.assertTrue(len(checker.dictionary()) > 0) @@ -132,6 +136,9 @@ def test_correct_sent(self): self.assertIsNotNone( correct_sent(self.spell_sent, engine="symspellpy") ) + self.assertIsNotNone( + correct_sent(self.spell_sent, engine="wanchanberta_thai_grammarly") + ) self.assertIsNotNone( symspellpy.correct_sent(self.spell_sent) )