diff --git a/pythainlp/tag/pos_tag.py b/pythainlp/tag/pos_tag.py index abdfe5fc2..9fcb8ff01 100644 --- a/pythainlp/tag/pos_tag.py +++ b/pythainlp/tag/pos_tag.py @@ -180,7 +180,9 @@ def pos_tag_sents( def pos_tag_transformers( - words: str, engine: str = "bert-base-th-cased-blackboard" + words: str, + engine: str = "bert", + corpus: str = "blackboard", ): """ "wangchanberta-ud-thai-pud-upos", @@ -199,21 +201,27 @@ def pos_tag_transformers( if not words: return [] - if engine == "wangchanberta-ud-thai-pud-upos": - model = AutoModelForTokenClassification.from_pretrained( - "Pavarissy/wangchanberta-ud-thai-pud-upos") - tokenizer = AutoTokenizer.from_pretrained("Pavarissy/wangchanberta-ud-thai-pud-upos") - elif engine == "mdeberta-v3-ud-thai-pud-upos": - model = AutoModelForTokenClassification.from_pretrained( - "Pavarissy/mdeberta-v3-ud-thai-pud-upos") - tokenizer = AutoTokenizer.from_pretrained("Pavarissy/mdeberta-v3-ud-thai-pud-upos") - elif engine == "bert-base-th-cased-blackboard": - model = AutoModelForTokenClassification.from_pretrained("lunarlist/pos_thai") - tokenizer = AutoTokenizer.from_pretrained("lunarlist/pos_thai") + _blackboard_support_engine = { + "bert" : "lunarlist/pos_thai", + } + + _pud_support_engine = { + "wangchanberta" : "Pavarissy/wangchanberta-ud-thai-pud-upos", + "mdeberta" : "Pavarissy/mdeberta-v3-ud-thai-pud-upos", + } + + if corpus == 'blackboard' and engine in _blackboard_support_engine.keys(): + base_model = _blackboard_support_engine.get(engine) + model = AutoModelForTokenClassification.from_pretrained(base_model) + tokenizer = AutoTokenizer.from_pretrained(base_model) + elif corpus == 'pud' and engine in _pud_support_engine.keys(): + base_model = _pud_support_engine.get(engine) + model = AutoModelForTokenClassification.from_pretrained(base_model) + tokenizer = AutoTokenizer.from_pretrained(base_model) else: raise ValueError( - "pos_tag_transformers not support {0} engine.".format( - engine + "pos_tag_transformers not support {0} engine or {1} corpus.".format( + engine, corpus ) ) diff --git a/tests/test_tag.py b/tests/test_tag.py index 8d1755b18..b5529ec5b 100644 --- a/tests/test_tag.py +++ b/tests/test_tag.py @@ -367,10 +367,13 @@ def test_NNER_class(self): def test_pos_tag_transformers(self): self.assertIsNotNone(pos_tag_transformers( - words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert-base-th-cased-blackboard")) + words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert", corpus="blackboard")) self.assertIsNotNone(pos_tag_transformers( - words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta-v3-ud-thai-pud-upos")) + words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta", corpus="pud")) self.assertIsNotNone(pos_tag_transformers( - words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta-ud-thai-pud-upos")) + words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta", corpus="pud")) with self.assertRaises(ValueError): - pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine") \ No newline at end of file + pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine") + with self.assertRaises(ValueError): + pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert", + corpus="non-existing corpus") \ No newline at end of file