Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pythainlp/corpus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"thai_syllables",
"thai_words",
"path_pythainlp_corpus",
"get_path_folder_corpus",
]

import os
Expand Down Expand Up @@ -84,6 +85,7 @@ def corpus_db_path() -> str:
get_corpus_db_detail,
get_corpus_default_db,
get_corpus_path,
get_path_folder_corpus,
remove,
path_pythainlp_corpus,
) # these imports must come before other pythainlp.corpus.* imports
Expand Down
4 changes: 4 additions & 0 deletions pythainlp/corpus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,7 @@ def remove(name: str) -> bool:

db.close()
return False


def get_path_folder_corpus(name, version, *path):
    """
    Build a path to a file or sub-folder that lives inside a corpus folder.

    :param name: corpus name, forwarded to ``get_corpus_path``
    :param version: corpus version, forwarded to ``get_corpus_path``
    :param path: zero or more path components to append below the
        location returned by ``get_corpus_path``
    :return: the corpus location joined with every component of *path*
    """
    corpus_root = get_corpus_path(name, version)
    return os.path.join(corpus_root, *path)
32 changes: 32 additions & 0 deletions pythainlp/tag/lst20_ner_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
from typing import List
from pythainlp.tag.wangchanberta_onnx import WngchanBerta_ONNX


class LST20_NER_ONNX(WngchanBerta_ONNX):
    """
    Named-entity tagger backed by the ``onnx_lst20ner`` ONNX model.

    It reuses the inference pipeline of ``WngchanBerta_ONNX`` and only
    customises how raw SentencePiece pieces and model labels are
    normalised before they are returned.
    """

    def __init__(self, providers: List[str] = ['CPUExecutionProvider']) -> None:
        """
        Load version ``1.0`` of the ``onnx_lst20ner`` model.

        :param providers: ONNX Runtime execution providers, passed through
            unchanged to the base class
        """
        WngchanBerta_ONNX.__init__(
            self,
            model_name="onnx_lst20ner",
            model_version="1.0",
            file_onnx="lst20-ner-model.onnx",
            providers=providers
        )

    def clean_output(self, list_text):
        """
        Strip the SentencePiece word-boundary marker from the pieces.

        :param list_text: list of ``(piece, label)`` pairs
        :return: new list of ``(text, label)`` pairs where a leading
            bare ``"▁"`` piece is dropped, any other bare ``"▁"`` becomes
            a single space, and a ``"▁"`` prefix is removed from pieces
        """
        # Guard the empty case: with no pieces (e.g. empty input text)
        # indexing list_text[0] would raise IndexError.
        if list_text and list_text[0][0] == "▁":
            list_text = list_text[1:]
        cleaned = []
        for piece, label in list_text:
            if piece == "▁":
                piece = " "
            elif piece.startswith("▁"):
                # Only the leading marker is removed.
                piece = piece.replace("▁", "", 1)
            cleaned.append((piece, label))
        return cleaned

    def _config(self, list_ner):
        """
        Convert model labels to the hyphenated form used for tagged text.

        ``E_`` is rewritten to ``I_`` and every underscore becomes a
        hyphen, so labels take the ``B-``/``I-`` shape that the
        tagged-text rendering of the base class tests for.

        :param list_ner: list of ``(text, label)`` pairs
        :return: new list of ``(text, converted_label)`` pairs
        """
        return [
            (word, label.replace('E_', 'I_').replace('_', '-'))
            for word, label in list_ner
        ]
6 changes: 5 additions & 1 deletion pythainlp/tag/named_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class NER:
**Options for engine**
* *thainer* - Thai NER engine
* *wangchanberta* - wangchanberta model
* *lst20_onnx* - LST20 NER model based on WangchanBERTa, run with ONNX Runtime
* *tltk* - wrapper for `TLTK <https://pypi.org/project/tltk/>`_.

**Options for corpus**
Expand All @@ -33,6 +34,9 @@ def load_engine(self, engine: str, corpus: str) -> None:
if engine == "thainer" and corpus == "thainer":
from pythainlp.tag.thainer import ThaiNameTagger
self.engine = ThaiNameTagger()
elif engine == "lst20_onnx":
from pythainlp.tag.lst20_ner_onnx import LST20_NER_ONNX
self.engine = LST20_NER_ONNX()
elif engine == "wangchanberta":
from pythainlp.wangchanberta import ThaiNameTagger
self.engine = ThaiNameTagger(dataset_name=corpus)
Expand Down Expand Up @@ -88,7 +92,7 @@ def tag(
"""wangchanberta is not support part-of-speech tag.
It have not part-of-speech tag in output."""
)
if self.name_engine == "wangchanberta":
if self.name_engine == "wangchanberta" or self.name_engine == "lst20_onnx":
return self.engine.get_ner(text, tag=tag)
else:
return self.engine.get_ner(text, tag=tag, pos=pos)
112 changes: 112 additions & 0 deletions pythainlp/tag/wangchanberta_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
from typing import List
import json
import sentencepiece as spm
import numpy as np
from onnxruntime import (
InferenceSession, SessionOptions, GraphOptimizationLevel
)
from pythainlp.corpus import get_path_folder_corpus


class WngchanBerta_ONNX:
    """
    Token tagger that feeds SentencePiece ids to an ONNX model.

    The model file, the SentencePiece model and ``config.json`` are all
    read from the corpus folder identified by ``model_name`` and
    ``model_version``. Subclasses may override ``clean_output`` and
    ``_config`` to normalise pieces and labels.
    """

    def __init__(
        self,
        model_name: str,
        model_version: str,
        file_onnx: str,
        providers: List[str] = ['CPUExecutionProvider']
    ) -> None:
        """
        Create the inference session and load tokenizer and label map.

        :param model_name: corpus name that holds the model files
        :param model_version: corpus version that holds the model files
        :param file_onnx: file name of the ONNX model inside the corpus
        :param providers: ONNX Runtime execution providers for the session
        """
        self.model_name = model_name
        self.model_version = model_version
        self.options = SessionOptions()
        self.options.graph_optimization_level = (
            GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        self.session = InferenceSession(
            self._corpus_file(file_onnx),
            sess_options=self.options,
            providers=providers
        )
        # NOTE(review): presumably prevents the session from retrying on
        # another execution provider when a run fails - confirm in ORT docs.
        self.session.disable_fallback()
        # Only the first model output is requested at inference time.
        self.outputs_name = self.session.get_outputs()[0].name
        self.sp = spm.SentencePieceProcessor(
            model_file=self._corpus_file("sentencepiece.bpe.model")
        )
        with open(
            self._corpus_file("config.json"),
            encoding='utf-8-sig'
        ) as fh:
            self._json = json.load(fh)
        # Looked up in totag() with the label index converted to str.
        self.id2tag = self._json['id2label']

    def _corpus_file(self, *path):
        """Return the path of a file inside this model's corpus folder."""
        return get_path_folder_corpus(
            self.model_name,
            self.model_version,
            *path
        )

    def build_tokenizer(self, sent):
        """
        Encode *sent* into the input feed for the ONNX session.

        :param sent: text to encode
        :return: dict with ``input_ids`` and ``attention_mask``, each an
            int64 array holding a single row
        """
        # NOTE(review): 5 and 6 are presumably the begin/end-of-sentence
        # ids and +4 the offset between SentencePiece ids and the model
        # vocabulary - TODO confirm against the model's tokenizer.
        ids = [5] + [piece_id + 4 for piece_id in self.sp.encode(sent)] + [6]
        return {
            "input_ids": np.array([ids], dtype=np.int64),
            "attention_mask": np.ones((1, len(ids)), dtype=np.int64),
        }

    def postprocess(self, logits_data):
        """
        Turn the logits of the first batch item into softmax scores.

        :param logits_data: model output; only element ``0`` is used
        :return: scores normalised over the last axis
        """
        logits = np.asarray(logits_data[0])
        # Subtract the row maximum before exponentiating.
        shifted_exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

    def clean_output(self, list_text):
        """Hook for subclasses; the base class returns the pairs as is."""
        return list_text

    def totag(self, post, sent):
        """
        Pair every SentencePiece piece of *sent* with its best label.

        :param post: per-position scores as returned by ``postprocess``
        :param sent: the text that produced *post*
        :return: list of ``(piece, label)`` pairs
        """
        pieces = self.sp.EncodeAsPieces(sent)
        # Row 0 belongs to the extra leading id added in build_tokenizer,
        # so piece i is scored by row i + 1.
        best = np.argmax(np.asarray(post)[1:len(pieces) + 1], axis=-1)
        return [
            (piece, self.id2tag[str(int(index))])
            for piece, index in zip(pieces, best)
        ]

    def _config(self, list_ner):
        """Hook for subclasses; the base class returns the pairs as is."""
        return list_ner

    @staticmethod
    def _render_tagged(tagged):
        """Render ``(text, label)`` pairs as text with ``<X>...</X>`` spans."""
        parts = []
        open_label = ""
        for word, ner in tagged:
            if ner.startswith("B-"):
                # A new entity starts: close the running one first.
                if open_label != "":
                    parts.append("</" + open_label + ">")
                open_label = ner[2:]
                parts.append("<" + open_label + ">")
            elif ner == "O" and open_label != "":
                parts.append("</" + open_label + ">")
                open_label = ""
            parts.append(word)
        # An entity still open after the last piece is closed here.
        if open_label != "":
            parts.append("</" + open_label + ">")
        return "".join(parts)

    def get_ner(self, text: str, tag: bool = False):
        """
        Tag *text* with named-entity labels.

        :param text: text to tag
        :param tag: if ``True`` return one string with inline
            ``<LABEL>...</LABEL>`` markup instead of a list of pairs
        :return: list of ``(text, label)`` pairs, or a tagged string
        """
        # Kept on the instance because the original stored it there.
        self._s = self.build_tokenizer(text)
        logits = self.session.run(
            output_names=[self.outputs_name],
            input_feed=self._s
        )[0]
        tagged = self.clean_output(
            self.totag(self.postprocess(logits), text)
        )
        if not tag:
            return tagged
        return self._render_tagged(self._config(tagged))
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@
"tltk": ["tltk>=1.3.8"],
"oskut": ["oskut>=1.3"],
"nlpo3": ["nlpo3>=1.2.2"],
"onnx": [
"sentencepiece>=0.1.91",
"numpy>=1.16.1",
"onnxruntime>=1.10.0"
],
"full": [
"PyYAML>=5.3.1",
"attacut>=1.0.4",
Expand All @@ -100,6 +105,7 @@
"tltk>=1.3.8",
"oskut>=1.3",
"nlpo3>=1.2.2",
"onnxruntime>=1.10.0",
],
}

Expand Down
3 changes: 3 additions & 0 deletions tests/test_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def test_NER_class(self):
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True))
ner = NER(engine="lst20_onnx")
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", tag=True))
ner = NER(engine="tltk")
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า"))
self.assertIsNotNone(ner.tag("แมวทำอะไรตอนห้าโมงเช้า", pos=False))
Expand Down