diff --git a/pythainlp/tag/pos_tag.py b/pythainlp/tag/pos_tag.py index abdfe5fc2..d415c0805 100644 --- a/pythainlp/tag/pos_tag.py +++ b/pythainlp/tag/pos_tag.py @@ -180,12 +180,13 @@ def pos_tag_sents( def pos_tag_transformers( - words: str, engine: str = "bert-base-th-cased-blackboard" + words: str, engine: str = "phayathai" ): """ - "wangchanberta-ud-thai-pud-upos", - "mdeberta-v3-ud-thai-pud-upos", - "bert-base-th-cased-blackboard", + "wangchanberta", + "mdeberta", + "bert_cased", + "phayathai", """ @@ -199,17 +200,21 @@ def pos_tag_transformers( if not words: return [] - if engine == "wangchanberta-ud-thai-pud-upos": + if engine == "wangchanberta": model = AutoModelForTokenClassification.from_pretrained( "Pavarissy/wangchanberta-ud-thai-pud-upos") tokenizer = AutoTokenizer.from_pretrained("Pavarissy/wangchanberta-ud-thai-pud-upos") - elif engine == "mdeberta-v3-ud-thai-pud-upos": + elif engine == "mdeberta": model = AutoModelForTokenClassification.from_pretrained( "Pavarissy/mdeberta-v3-ud-thai-pud-upos") tokenizer = AutoTokenizer.from_pretrained("Pavarissy/mdeberta-v3-ud-thai-pud-upos") - elif engine == "bert-base-th-cased-blackboard": + elif engine == "bert_cased": model = AutoModelForTokenClassification.from_pretrained("lunarlist/pos_thai") tokenizer = AutoTokenizer.from_pretrained("lunarlist/pos_thai") + elif engine == "phayathai": + model = AutoModelForTokenClassification.from_pretrained( + "lunarlist/pos_thai_phayathai") + tokenizer = AutoTokenizer.from_pretrained("lunarlist/pos_thai_phayathai") else: raise ValueError( "pos_tag_transformers not support {0} engine.".format( diff --git a/tests/test_tag.py b/tests/test_tag.py index 8d1755b18..95eff9a45 100644 --- a/tests/test_tag.py +++ b/tests/test_tag.py @@ -367,10 +367,12 @@ def test_NNER_class(self): def test_pos_tag_transformers(self): self.assertIsNotNone(pos_tag_transformers( - words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert-base-th-cased-blackboard")) + words="แมวทำอะไรตอนห้าโมงเช้า", engine="bert_cased")) self.assertIsNotNone(pos_tag_transformers( - words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta-v3-ud-thai-pud-upos")) + words="แมวทำอะไรตอนห้าโมงเช้า", engine="mdeberta")) self.assertIsNotNone(pos_tag_transformers( - words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta-ud-thai-pud-upos")) + words="แมวทำอะไรตอนห้าโมงเช้า", engine="wangchanberta")) + self.assertIsNotNone(pos_tag_transformers( + words="แมวทำอะไรตอนห้าโมงเช้า", engine="phayathai")) with self.assertRaises(ValueError): pos_tag_transformers(words="แมวทำอะไรตอนห้าโมงเช้า", engine="non-existing-engine") \ No newline at end of file