diff --git a/pythainlp/corpus/__init__.py b/pythainlp/corpus/__init__.py
index 9a2d2c812..cef277a28 100644
--- a/pythainlp/corpus/__init__.py
+++ b/pythainlp/corpus/__init__.py
@@ -27,6 +27,7 @@
     "thai_syllables",
     "thai_words",
     "path_pythainlp_corpus",
+    "get_path_folder_corpus",
 ]

 import os
@@ -84,6 +85,7 @@ def corpus_db_path() -> str:
     get_corpus_db_detail,
     get_corpus_default_db,
     get_corpus_path,
+    get_path_folder_corpus,
     remove,
     path_pythainlp_corpus,
 )  # these imports must come before other pythainlp.corpus.* imports
diff --git a/pythainlp/corpus/core.py b/pythainlp/corpus/core.py
index 74f976809..0196acdbb 100644
--- a/pythainlp/corpus/core.py
+++ b/pythainlp/corpus/core.py
@@ -522,3 +522,7 @@ def remove(name: str) -> bool:
     db.close()

     return False
+
+
+def get_path_folder_corpus(name, version, *path):
+    return os.path.join(get_corpus_path(name, version), *path)
diff --git a/pythainlp/tag/lst20_ner_onnx.py b/pythainlp/tag/lst20_ner_onnx.py
new file mode 100644
index 000000000..eaa7f2ab7
--- /dev/null
+++ b/pythainlp/tag/lst20_ner_onnx.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+from typing import List
+from pythainlp.tag.wangchanberta_onnx import WngchanBerta_ONNX
+
+
+class LST20_NER_ONNX(WngchanBerta_ONNX):
+    def __init__(self, providers: List[str] = ['CPUExecutionProvider']) -> None:
+        WngchanBerta_ONNX.__init__(
+            self,
+            model_name="onnx_lst20ner",
+            model_version="1.0",
+            file_onnx="lst20-ner-model.onnx",
+            providers=providers
+        )
+
+    def clean_output(self, list_text):
+        new_list = []
+        if list_text[0][0] == "▁":
+            list_text = list_text[1:]
+        for i, j in list_text:
+            if i.startswith("▁") and i != '▁':
+                i = i.replace("▁", "", 1)
+            elif i == '▁':
+                i = " "
+            new_list.append((i, j))
+        return new_list
+
+    def _config(self, list_ner):
+        _n = []
+        for i, j in list_ner:
+            _n.append((i, j.replace('E_', 'I_').replace('_', '-')))
+        return _n
diff --git a/pythainlp/tag/named_entity.py b/pythainlp/tag/named_entity.py
index 8a0f1b842..d745b4d3e 100644
--- a/pythainlp/tag/named_entity.py
+++ b/pythainlp/tag/named_entity.py
@@ -16,6 +16,7 @@ class NER:
     **Options for engine**
         * *thainer* - Thai NER engine
         * *wangchanberta* - wangchanberta model
+        * *lst20_onnx* - LST20 NER model by wangchanberta with ONNX runtime
         * *tltk* - wrapper for `TLTK <https://pypi.org/project/tltk/>`_.

     **Options for corpus**
@@ -33,6 +34,9 @@ def load_engine(self, engine: str, corpus: str) -> None:
         if engine == "thainer" and corpus == "thainer":
             from pythainlp.tag.thainer import ThaiNameTagger
             self.engine = ThaiNameTagger()
+        elif engine == "lst20_onnx":
+            from pythainlp.tag.lst20_ner_onnx import LST20_NER_ONNX
+            self.engine = LST20_NER_ONNX()
         elif engine == "wangchanberta":
             from pythainlp.wangchanberta import ThaiNameTagger
             self.engine = ThaiNameTagger(dataset_name=corpus)
@@ -88,7 +92,7 @@ def tag(
                 """wangchanberta is not support part-of-speech tag.
                 It have not part-of-speech tag in output."""
             )
-        if self.name_engine == "wangchanberta":
+        if self.name_engine == "wangchanberta" or self.name_engine == "lst20_onnx":
             return self.engine.get_ner(text, tag=tag)
         else:
             return self.engine.get_ner(text, tag=tag, pos=pos)
diff --git a/pythainlp/tag/wangchanberta_onnx.py b/pythainlp/tag/wangchanberta_onnx.py
new file mode 100644
index 000000000..ae619918d
--- /dev/null
+++ b/pythainlp/tag/wangchanberta_onnx.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+from typing import List
+import json
+import sentencepiece as spm
+import numpy as np
+from onnxruntime import (
+    InferenceSession, SessionOptions, GraphOptimizationLevel
+)
+from pythainlp.corpus import get_path_folder_corpus
+
+
+class WngchanBerta_ONNX:
+    def __init__(self, model_name: str, model_version: str, file_onnx: str, providers: List[str] = ['CPUExecutionProvider']) -> None:
+        self.model_name = model_name
+        self.model_version = model_version
+        self.options = SessionOptions()
+        self.options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+        self.session = InferenceSession(
+            get_path_folder_corpus(
+                self.model_name,
+                self.model_version,
+                file_onnx
+            ),
+            sess_options=self.options,
+            providers=providers
+        )
+        self.session.disable_fallback()
+        self.outputs_name = self.session.get_outputs()[0].name
+        self.sp = spm.SentencePieceProcessor(
+            model_file=get_path_folder_corpus(
+                self.model_name,
+                self.model_version,
+                "sentencepiece.bpe.model"
+            )
+        )
+        with open(
+            get_path_folder_corpus(
+                self.model_name,
+                self.model_version,
+                "config.json"
+            ),
+            encoding='utf-8-sig'
+        ) as fh:
+            self._json = json.load(fh)
+        self.id2tag = self._json['id2label']
+
+    def build_tokenizer(self, sent):
+        _t = [5]+[i+4 for i in self.sp.encode(sent)]+[6]
+        model_inputs = {}
+        model_inputs["input_ids"] = np.array([_t], dtype=np.int64)
+        model_inputs["attention_mask"] = np.array(
+            [[1]*len(_t)], dtype=np.int64
+        )
+        return model_inputs
+
+    def postprocess(self, logits_data):
+        logits_t = logits_data[0]
+        maxes = np.max(logits_t, axis=-1, keepdims=True)
+        shifted_exp = np.exp(logits_t - maxes)
+        scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+        return scores
+
+    def clean_output(self, list_text):
+        return list_text
+
+    def totag(self, post, sent):
+        tag = []
+        _s = self.sp.EncodeAsPieces(sent)
+        for i in range(len(_s)):
+            tag.append(
+                (
+                    _s[i],
+                    self.id2tag[
+                        str(list(post[i+1]).index(max(list(post[i+1]))))
+                    ]
+                )
+            )
+        return tag
+
+    def _config(self, list_ner):
+        return list_ner
+
+    def get_ner(self, text: str, tag: bool = False):
+        self._s = self.build_tokenizer(text)
+        logits = self.session.run(
+            output_names=[self.outputs_name],
+            input_feed=self._s
+        )[0]
+        _tag = self.clean_output(self.totag(self.postprocess(logits), text))
+        if tag:
+            _tag = self._config(_tag)
+            temp = ""
+            sent = ""
+            for idx, (word, ner) in enumerate(_tag):
+                if ner.startswith("B-") and temp != "":
+                    sent += "</" + temp + ">"
+                    temp = ner[2:]
+                    sent += "<" + temp + ">"
+                elif ner.startswith("B-"):
+                    temp = ner[2:]
+                    sent += "<" + temp + ">"
+                elif ner == "O" and temp != "":
+                    sent += "</" + temp + ">"
+                    temp = ""
+                sent += word
+
+                if idx == len(_tag) - 1 and temp != "":
+                    sent += "</" + temp + ">"
+
+            return sent
+        else:
+            return _tag
diff --git a/setup.py b/setup.py
index 1dc67d129..29bbf51b1 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,11 @@
     "tltk": ["tltk>=1.3.8"],
     "oskut": ["oskut>=1.3"],
     "nlpo3": ["nlpo3>=1.2.2"],
+    "onnx": [
+        "sentencepiece>=0.1.91",
+        "numpy>=1.16.1",
+        "onnxruntime>=1.10.0"
+    ],
     "full": [
         "PyYAML>=5.3.1",
"attacut>=1.0.4", @@ -100,6 +105,7 @@ "tltk>=1.3.8", "oskut>=1.3", "nlpo3>=1.2.2", + "onnxruntime>=1.10.0", ], } diff --git a/tests/test_tag.py b/tests/test_tag.py index 536f65eae..63e887747 100644 --- a/tests/test_tag.py +++ b/tests/test_tag.py @@ -361,6 +361,9 @@ def test_NER_class(self): self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า")) self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False)) self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True)) + ner = NER(engine="lst20_onnx") + self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า")) + self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True)) ner = NER(engine="tltk") self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า")) self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))