From 58412819318316b77c3894b153d52e0d446e75c9 Mon Sep 17 00:00:00 2001 From: Aoi-Takahashi Date: Thu, 7 Mar 2024 17:48:28 +0900 Subject: [PATCH] =?UTF-8?q?pmk-gpt=E3=81=AE=E5=9F=BA=E6=9C=AC=E5=AE=9F?= =?UTF-8?q?=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 1 + requirements-dev.lock | 2 + requirements.lock | 2 + src/pm_kun/__init__.py | 80 +++++++++++++++++++----------------- src/pm_kun/pmgpt/__init__.py | 0 src/pm_kun/pmgpt/main.py | 22 ++++++++++ 6 files changed, 69 insertions(+), 38 deletions(-) create mode 100644 src/pm_kun/pmgpt/__init__.py create mode 100644 src/pm_kun/pmgpt/main.py diff --git a/pyproject.toml b/pyproject.toml index d79034d..b8b9b2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "torch>=2.2.1", "torchvision>=0.17.1", "torchaudio>=2.2.1", + "sentencepiece>=0.2.0", ] readme = "README.md" requires-python = ">= 3.8" diff --git a/requirements-dev.lock b/requirements-dev.lock index bf6aa6a..cd214a4 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -99,6 +99,8 @@ requests==2.31.0 # via transformers safetensors==0.4.2 # via transformers +sentencepiece==0.2.0 + # via pm-kun six==1.16.0 # via python-dateutil sympy==1.12 diff --git a/requirements.lock b/requirements.lock index bf6aa6a..cd214a4 100644 --- a/requirements.lock +++ b/requirements.lock @@ -99,6 +99,8 @@ requests==2.31.0 # via transformers safetensors==0.4.2 # via transformers +sentencepiece==0.2.0 + # via pm-kun six==1.16.0 # via python-dateutil sympy==1.12 diff --git a/src/pm_kun/__init__.py b/src/pm_kun/__init__.py index 09b3b3a..b3ea4b7 100644 --- a/src/pm_kun/__init__.py +++ b/src/pm_kun/__init__.py @@ -1,5 +1,7 @@ import datetime from typing import Optional +from pm_kun.pmgpt.main import generate_response +from pm_kun.screen.main import PmkApp from pm_kun.task.todo import BaseTask from pm_kun.type_definition import Priority, Status, ToDo import uuid @@ -7,43 +9,45 @@ def main() -> None: print("This is Entory point for pm-kun.") - is_process = True + is_process = False my_task = BaseTask() - while is_process: - print("1.タスクの追加") - print("2.タスクの削除") - print("3.タスクの更新") - print("4.終了") - select = input("選択してください") - if select == "1": - task = input("タスクを入力してください") - my_Todo: ToDo = { - "id": str(uuid.uuid4()), - "task": "", - "status": Status.未着手, - "period": datetime.date.today(), - "priority": Priority.低, - } - my_Todo["id"] = str(uuid.uuid4()) - my_Todo["task"] = task - my_task.add_task(my_Todo) - print(f"TODOの数:{len(my_task.todo_list)}") - print(my_task.todo_list) - elif select == "2": - task = input("削除するタスクを入力してください") - for todo in my_task.todo_list: - if task in todo["task"]: - id = todo["id"] - my_task.delete_task(id) - print(my_task.todo_list) - elif select == "3": - cuurent_task = input("更新するタスクを入力してください") - update_task = input("更新後のタスクを入力してください") - for todo in my_task.todo_list: - if cuurent_task in todo["task"]: - id = todo["id"] - my_task.update_task(id, task=update_task) - print(my_task.todo_list) - elif select == "4": - break + print(generate_response("吾輩は猫である")) + # PmkApp().run() + # while is_process: + # print("1.タスクの追加") + # print("2.タスクの削除") + # print("3.タスクの更新") + # print("4.終了") + # select = input("選択してください") + # if select == "1": + # task = input("タスクを入力してください") + # my_Todo: ToDo = { + # "id": str(uuid.uuid4()), + # "task": "", + # "status": Status.未着手, + # "period": datetime.date.today(), + # "priority": Priority.低, + # } + # my_Todo["id"] = str(uuid.uuid4()) + # my_Todo["task"] = task + # my_task.add_task(my_Todo) + # print(f"TODOの数:{len(my_task.todo_list)}") + # print(my_task.todo_list) + # elif select == "2": + # task = input("削除するタスクを入力してください") + # for todo in my_task.todo_list: + # if task in todo["task"]: + # id = todo["id"] + # my_task.delete_task(id) + # print(my_task.todo_list) + # elif select == "3": + # cuurent_task = input("更新するタスクを入力してください") + # update_task = input("更新後のタスクを入力してください") + # for todo in my_task.todo_list: + # if cuurent_task in todo["task"]: + # id = todo["id"] + # my_task.update_task(id, task=update_task) + # print(my_task.todo_list) + # elif select == "4": + # break return None diff --git a/src/pm_kun/pmgpt/__init__.py b/src/pm_kun/pmgpt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pm_kun/pmgpt/main.py b/src/pm_kun/pmgpt/main.py new file mode 100644 index 0000000..d111bd6 --- /dev/null +++ b/src/pm_kun/pmgpt/main.py @@ -0,0 +1,22 @@ +import torch +from transformers import T5Tokenizer, GPT2LMHeadModel, pipeline + +tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small") +tokenizer.do_lower_case = True +model = GPT2LMHeadModel.from_pretrained("rinna/japanese-gpt2-small") + + +# TODO: Transformersのfine-tuningを行う +def generate_response(user_input: str) -> str: + generator = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + device=0 if torch.cuda.is_available() else -1, + ) + + response = generator( + user_input, max_length=100, truncation=True, num_return_sequences=1 + ) + + return response[0]["generated_text"]