diff --git a/docs/api/chat.rst b/docs/api/chat.rst
new file mode 100644
index 000000000..3d0d9d2e2
--- /dev/null
+++ b/docs/api/chat.rst
@@ -0,0 +1,7 @@
+.. currentmodule:: pythainlp.chat
+
+pythainlp.chat
+==============
+
+.. autoclass:: ChatBotModel
+    :members:
\ No newline at end of file
diff --git a/docs/api/generate.rst b/docs/api/generate.rst
index 02459dfc3..910bba27d 100644
--- a/docs/api/generate.rst
+++ b/docs/api/generate.rst
@@ -13,4 +13,6 @@ Modules
     :members:
 .. autoclass:: Trigram
     :members:
-.. autofunction:: pythainlp.generate.thai2fit.gen_sentence
\ No newline at end of file
+.. autofunction:: pythainlp.generate.thai2fit.gen_sentence
+.. autoclass:: pythainlp.generate.wangchanglm.WangChanGLM
+    :members:
\ No newline at end of file
diff --git a/docs/notes/installation.rst b/docs/notes/installation.rst
index 92bbd436e..b8d596482 100644
--- a/docs/notes/installation.rst
+++ b/docs/notes/installation.rst
@@ -36,6 +36,7 @@ where ``extras`` can be
   - ``transformers_ud`` (to support transformers_ud engine)
   - ``dependency_parsing`` (to support dependency parsing with all engine)
   - ``coreference_resolution`` (to support coreference resolution with all engine)
+  - ``wangchanglm`` (to support the WangChanGLM model)
   - ``wsd`` (to support pythainlp.wsd)
   - ``full`` (install everything)
 
diff --git a/notebooks/test-chat.ipynb b/notebooks/test-chat.ipynb
new file mode 100644
index 000000000..d3c64f3c1
--- /dev/null
+++ b/notebooks/test-chat.ipynb
@@ -0,0 +1,236 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "3ad128a6-2959-431f-b5ff-d9e15761c9c0",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from pythainlp.chat.core import ChatBotModel\n",
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "35760aec-f47a-4d33-ad1c-a8230194180c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chatbot = ChatBotModel()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "99129184-3a9f-4871-bfb9-ce611e80ff55",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Setting ds_accelerator to cuda (auto detect)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "54dd6a2c6afa41959bfb11ec98b30562",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards:   0%|          | 0/98 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "chatbot.load_model()\n",
+    "m = chatbot.model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "('…', '<human>: ขอวิธีทำข้าวผัดหน่อย\\n<bot>: ')"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m.instruct_generate(instruct=\"ขอวิธีทำข้าวผัดหน่อย\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "d4fbd1e6-8a41-4b46-aa12-a23f8df9bcb0",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "('ข้าวผัดน้ําพริกลงเรือ', '<human>: ขอวิธีทำข้าวผัดหน่อย\\n<bot>: ')"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m.instruct_generate(instruct=\"ขอวิธีทำข้าวผัดหน่อย\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "fe71b834-4f12-406e-8a74-41829d8a7d9d",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "('เป้าหมายของคุณคือการลดน้ําหนักหรือไม่?', '<human>: ขอลดน้ำหนัก\\n<bot>: ')"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m.instruct_generate(instruct=\"ขอลดน้ำหนัก\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
"2cd5063d-21b6-40fb-8e4e-c54fb07ac613", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "('ลดน้ําหนักให้ได้ผล ต้องทําอย่างค่อยเป็นค่อยไป ปรับเปลี่ยนพฤติกรรมการกินอาหาร ออกกําลังกายอย่างสม่ําเสมอ และพักผ่อนให้เพียงพอ ที่สําคัญควรหลีกเลี่ยงอาหารที่มีแคลอรี่สูง เช่น อาหารทอด อาหารมัน อาหารที่มีน้ําตาลสูง และเครื่องดื่มแอลกอฮอล์',\n", + " ': ขอวิธีลดน้ำหนัก\\n: ')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m.instruct_generate(instruct=\"ขอวิธีลดน้ำหนัก\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5b54b24-59b8-400e-89ff-0b8b67dce71f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pythainlp/chat/__init__.py b/pythainlp/chat/__init__.py new file mode 100644 index 000000000..8c594795c --- /dev/null +++ b/pythainlp/chat/__init__.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2016-2023 PyThaiNLP Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +pythainlp.chat +""" + +__all__ = ["ChatBotModel"] + +from pythainlp.chat.core import ChatBotModel diff --git a/pythainlp/chat/core.py b/pythainlp/chat/core.py new file mode 100644 index 000000000..8eed4685e --- /dev/null +++ b/pythainlp/chat/core.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2016-2023 PyThaiNLP Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch
+
+
+class ChatBotModel:
+    def __init__(self):
+        """
+        Chat with an AI text-generation model.
+        """
+        self.history = []
+
+    def reset_chat(self):
+        """
+        Reset the chat by clearing the history.
+        """
+        self.history = []
+
+    def load_model(
+        self,
+        model_name: str = "wangchanglm",
+        return_dict: bool = True,
+        load_in_8bit: bool = False,
+        device: str = "cuda",
+        torch_dtype=torch.float16,
+        offload_folder: str = "./",
+        low_cpu_mem_usage: bool = True
+    ):
+        """
+        Load the model.
+
+        :param str model_name: model name (currently, only wangchanglm is supported)
+        :param bool return_dict: return_dict
+        :param bool load_in_8bit: load the model in 8-bit
+        :param str device: device (cpu, cuda, or other)
+        :param torch_dtype torch_dtype: torch dtype
+        :param str offload_folder: offload folder
+        :param bool low_cpu_mem_usage: use low CPU memory
+        """
+        if model_name == "wangchanglm":
+            from pythainlp.generate.wangchanglm import WangChanGLM
+            self.model = WangChanGLM()
+            self.model.load_model(
+                model_path="pythainlp/wangchanglm-7.5B-sft-en-sharded",
+                return_dict=return_dict,
+                load_in_8bit=load_in_8bit,
+                offload_folder=offload_folder,
+                device=device,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=low_cpu_mem_usage
+            )
+        else:
+            raise NotImplementedError(f"{model_name} is not supported.")
+
+    def chat(self, text: str) -> str:
+        """
+        Chat with the chatbot.
+
+        :param str text: text to ask the chatbot
+        :return: the answer from the chatbot
+        :rtype: str
+        """
+        # Replay the conversation so far, then append the new turn with an
+        # empty bot slot for the model to complete.
+        _temp = ""
+        if self.history != []:
+            for h, b in self.history:
+                _temp += self.model.PROMPT_DICT['prompt_chatbot'].format_map(
+                    {"human": h, "bot": b}
+                ) + self.model.stop_token
+        _temp += self.model.PROMPT_DICT['prompt_chatbot'].format_map(
+            {"human": text, "bot": ""}
+        )
+        _bot = self.model.gen_instruct(_temp)
+        self.history.append((text, _bot))
+        return _bot
diff --git a/pythainlp/generate/wangchanglm.py b/pythainlp/generate/wangchanglm.py
new file mode 100644
index 000000000..77e0043e1
--- /dev/null
+++ b/pythainlp/generate/wangchanglm.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2016-2023 PyThaiNLP Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+import torch
+
+
+class WangChanGLM:
+    def __init__(self):
+        # Matches any run of non-Thai characters; used to decide which
+        # vocabulary entries to suppress when thai_only=True.
+        self.exclude_pattern = re.compile(r'[^ก-๙]+')
+        self.stop_token = "\n"
+        self.PROMPT_DICT = {
+            "prompt_input": (
+                "<context>: {input}\n<human>: {instruction}\n<bot>: "
+            ),
+            "prompt_no_input": (
+                "<human>: {instruction}\n<bot>: "
+            ),
+            "prompt_chatbot": (
+                "<human>: {human}\n<bot>: {bot}"
+            ),
+        }
+
+    def is_exclude(self, text: str) -> bool:
+        return bool(self.exclude_pattern.search(text))
+
+    def load_model(
+        self,
+        model_path: str = "pythainlp/wangchanglm-7.5B-sft-en-sharded",
+        return_dict: bool = True,
+        load_in_8bit: bool = False,
+        device: str = "cuda",
+        torch_dtype=torch.float16,
+        offload_folder: str = "./",
+        low_cpu_mem_usage: bool = True
+    ):
+        """
+        Load the model.
+
+        :param str model_path: model path
+        :param bool return_dict: return_dict
+        :param bool load_in_8bit: load the model in 8-bit
+        :param str device: device (cpu, cuda, or other)
+        :param torch_dtype torch_dtype: torch dtype
+        :param str offload_folder: offload folder
+        :param bool low_cpu_mem_usage: use low CPU memory
+        """
+        import pandas as pd
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        self.device = device
+        self.torch_dtype = torch_dtype
+        self.model_path = model_path
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            return_dict=return_dict,
+            load_in_8bit=load_in_8bit,
+            device_map=device,
+            torch_dtype=torch_dtype,
+            offload_folder=offload_folder,
+            low_cpu_mem_usage=low_cpu_mem_usage
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+        # Collect the ids of all vocabulary entries that contain non-Thai
+        # characters; gen_instruct() suppresses these when thai_only=True.
+        self.df = pd.DataFrame(self.tokenizer.vocab.items(), columns=['text', 'idx'])
+        self.df['is_exclude'] = self.df.text.map(self.is_exclude)
+        self.exclude_ids = self.df[self.df.is_exclude].idx.tolist()
+
+    def gen_instruct(
+        self,
+        text: str,
+        max_new_tokens: int = 512,
+        top_p: float = 0.95,
+        temperature: float = 0.9,
+        top_k: int = 50,
+        no_repeat_ngram_size: int = 2,
+        typical_p: float = 1.0,
+        thai_only: bool = True,
+        skip_special_tokens: bool = True
+    ):
+        """
+        Generate text from a fully formatted prompt.
+
+        :param str text: the prompt text
+        :param int max_new_tokens: maximum number of new tokens
+        :param float top_p: top-p
+        :param float temperature: temperature
+        :param int top_k: top-k
+        :param int no_repeat_ngram_size: no-repeat n-gram size
+        :param float typical_p: typical p
+        :param bool thai_only: suppress non-Thai tokens at the start of generation
+        :param bool skip_special_tokens: skip special tokens
+        :return: the generated answer.
+        :rtype: str
+        """
+        batch = self.tokenizer(text, return_tensors="pt")
+        with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
+            if thai_only:
+                output_tokens = self.model.generate(
+                    input_ids=batch["input_ids"],
+                    max_new_tokens=max_new_tokens,  # 512
+                    begin_suppress_tokens=self.exclude_ids,
+                    no_repeat_ngram_size=no_repeat_ngram_size,
+                    # oasst k50
+                    top_k=top_k,
+                    top_p=top_p,  # 0.95
+                    typical_p=typical_p,
+                    temperature=temperature,  # 0.9
+                )
+            else:
+                output_tokens = self.model.generate(
+                    input_ids=batch["input_ids"],
+                    max_new_tokens=max_new_tokens,  # 512
+                    no_repeat_ngram_size=no_repeat_ngram_size,
+                    # oasst k50
+                    top_k=top_k,
+                    top_p=top_p,  # 0.95
+                    typical_p=typical_p,
+                    temperature=temperature,  # 0.9
+                )
+        # Return only the newly generated tokens, not the prompt.
+        return self.tokenizer.decode(
+            output_tokens[0][len(batch["input_ids"][0]):],
+            skip_special_tokens=skip_special_tokens
+        )
+
+    def instruct_generate(
+        self,
+        instruct: str,
+        context: str = None,
+        max_new_tokens: int = 512,
+        temperature: float = 0.9,
+        top_p: float = 0.95,
+        top_k: int = 50,
+        no_repeat_ngram_size: int = 2,
+        typical_p: float = 1.0,
+        thai_only: bool = True,
+        skip_special_tokens: bool = True
+    ):
+        """
+        Generate text from an instruction, with an optional context.
+
+        :param str instruct: instruction
+        :param str context: context
+        :param int max_new_tokens: maximum number of new tokens
+        :param float temperature: temperature
+        :param float top_p: top-p
+        :param int top_k: top-k
+        :param int no_repeat_ngram_size: no-repeat n-gram size
+        :param float typical_p: typical p
+        :param bool thai_only: suppress non-Thai tokens at the start of generation
+        :param bool skip_special_tokens: skip special tokens
+        :return: the generated answer.
+        :rtype: str
+        """
+        if context is None or context == "":
+            prompt = self.PROMPT_DICT['prompt_no_input'].format_map(
+                {'instruction': instruct, 'input': ''}
+            )
+        else:
+            prompt = self.PROMPT_DICT['prompt_input'].format_map(
+                {'instruction': instruct, 'input': context}
+            )
+        result = self.gen_instruct(
+            prompt,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            typical_p=typical_p,
+            thai_only=thai_only,
+            skip_special_tokens=skip_special_tokens
+        )
+        return result
diff --git a/setup.py b/setup.py
index 16bb96c1a..2b750c0a9 100644
--- a/setup.py
+++ b/setup.py
@@ -117,6 +117,11 @@
     "word_approximation":{
         "panphon>=0.20.0"
     },
+    "wangchanglm": [
+        "transformers>=4.6.0",
+        "sentencepiece>=0.1.91",
+        "pandas>=0.24"
+    ],
     "wsd":{
         "sentence-transformers>=2.2.2"
     },
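
For reference, a minimal end-to-end usage sketch of the API this diff introduces (not part of the change itself). It assumes ``pythainlp`` is installed with the ``wangchanglm`` extra (``pip install pythainlp[wangchanglm]``) and a CUDA device with enough memory for the 7.5B-parameter checkpoint; the follow-up question in the second ``chat()`` call is an illustrative example, not taken from the notebook above:

    import torch

    from pythainlp.chat import ChatBotModel

    bot = ChatBotModel()
    # Downloads and loads pythainlp/wangchanglm-7.5B-sft-en-sharded via
    # WangChanGLM.load_model(); on machines without CUDA, try device="cpu"
    # with torch_dtype=torch.bfloat16 instead.
    bot.load_model(model_name="wangchanglm", device="cuda", torch_dtype=torch.float16)

    # Multi-turn chat: each (question, answer) pair is stored in bot.history
    # and replayed through PROMPT_DICT["prompt_chatbot"] on the next call.
    print(bot.chat("ขอวิธีลดน้ำหนัก"))  # "How can I lose weight?"
    print(bot.chat("ช่วยสรุปสั้น ๆ"))  # "Please summarize briefly" (illustrative)
    bot.reset_chat()  # clear the history to start a new conversation

    # One-shot instruction following via the underlying WangChanGLM wrapper:
    answer = bot.model.instruct_generate(
        instruct="ขอวิธีทำข้าวผัดหน่อย",  # "How do I make fried rice?"
        thai_only=True,  # suppress non-Thai tokens at the start of generation
    )
    print(answer)

Since ``chat()`` re-sends the entire history on every turn, long conversations will grow the prompt until it eventually exhausts the model's context window; callers should ``reset_chat()`` periodically.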