
Commit

Basic implementation of pmk-gpt (pmk-gptの基本実装)
Aoi-Takahashi committed Mar 7, 2024
1 parent f6d25d0 commit 5841281
Showing 6 changed files with 69 additions and 38 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
     "torch>=2.2.1",
     "torchvision>=0.17.1",
     "torchaudio>=2.2.1",
+    "sentencepiece>=0.2.0",
 ]
 readme = "README.md"
 requires-python = ">= 3.8"
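The new sentencepiece dependency is what lets the tokenizer added in src/pm_kun/pmgpt/main.py (below) load at all: rinna/japanese-gpt2-small ships a SentencePiece vocabulary, and transformers' T5Tokenizer needs the sentencepiece package at runtime. A minimal sanity check for the environment, as a sketch (assumes the packages were installed from the lockfiles below):

import sentencepiece as spm

# T5Tokenizer.from_pretrained fails with an import error when this package is
# missing, so verify it is importable and matches the pinned version.
print(spm.__version__)  # expect >= 0.2.0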
2 changes: 2 additions & 0 deletions requirements-dev.lock
@@ -99,6 +99,8 @@ requests==2.31.0
     # via transformers
 safetensors==0.4.2
     # via transformers
+sentencepiece==0.2.0
+    # via pm-kun
 six==1.16.0
     # via python-dateutil
 sympy==1.12
2 changes: 2 additions & 0 deletions requirements.lock
@@ -99,6 +99,8 @@ requests==2.31.0
     # via transformers
 safetensors==0.4.2
     # via transformers
+sentencepiece==0.2.0
+    # via pm-kun
 six==1.16.0
     # via python-dateutil
 sympy==1.12
80 changes: 42 additions & 38 deletions src/pm_kun/__init__.py
@@ -1,49 +1,53 @@
 import datetime
 from typing import Optional
 from pm_kun.pmgpt.main import generate_response
 from pm_kun.screen.main import PmkApp
 from pm_kun.task.todo import BaseTask
 from pm_kun.type_definition import Priority, Status, ToDo
 import uuid
 
+
 def main() -> None:
     print("This is the entry point for pm-kun.")
-    is_process = True
+    is_process = False
     my_task = BaseTask()
-    while is_process:
-        print("1.タスクの追加")
-        print("2.タスクの削除")
-        print("3.タスクの更新")
-        print("4.終了")
-        select = input("選択してください")
-        if select == "1":
-            task = input("タスクを入力してください")
-            my_Todo: ToDo = {
-                "id": str(uuid.uuid4()),
-                "task": "",
-                "status": Status.未着手,
-                "period": datetime.date.today(),
-                "priority": Priority.低,
-            }
-            my_Todo["id"] = str(uuid.uuid4())
-            my_Todo["task"] = task
-            my_task.add_task(my_Todo)
-            print(f"TODOの数:{len(my_task.todo_list)}")
-            print(my_task.todo_list)
-        elif select == "2":
-            task = input("削除するタスクを入力してください")
-            for todo in my_task.todo_list:
-                if task in todo["task"]:
-                    id = todo["id"]
-                    my_task.delete_task(id)
-            print(my_task.todo_list)
-        elif select == "3":
-            current_task = input("更新するタスクを入力してください")
-            update_task = input("更新後のタスクを入力してください")
-            for todo in my_task.todo_list:
-                if current_task in todo["task"]:
-                    id = todo["id"]
-                    my_task.update_task(id, task=update_task)
-            print(my_task.todo_list)
-        elif select == "4":
-            break
+    print(generate_response("吾輩は猫である"))
+    # PmkApp().run()
+    # while is_process:
+    #     print("1.タスクの追加")
+    #     print("2.タスクの削除")
+    #     print("3.タスクの更新")
+    #     print("4.終了")
+    #     select = input("選択してください")
+    #     if select == "1":
+    #         task = input("タスクを入力してください")
+    #         my_Todo: ToDo = {
+    #             "id": str(uuid.uuid4()),
+    #             "task": "",
+    #             "status": Status.未着手,
+    #             "period": datetime.date.today(),
+    #             "priority": Priority.低,
+    #         }
+    #         my_Todo["id"] = str(uuid.uuid4())
+    #         my_Todo["task"] = task
+    #         my_task.add_task(my_Todo)
+    #         print(f"TODOの数:{len(my_task.todo_list)}")
+    #         print(my_task.todo_list)
+    #     elif select == "2":
+    #         task = input("削除するタスクを入力してください")
+    #         for todo in my_task.todo_list:
+    #             if task in todo["task"]:
+    #                 id = todo["id"]
+    #                 my_task.delete_task(id)
+    #         print(my_task.todo_list)
+    #     elif select == "3":
+    #         current_task = input("更新するタスクを入力してください")
+    #         update_task = input("更新後のタスクを入力してください")
+    #         for todo in my_task.todo_list:
+    #             if current_task in todo["task"]:
+    #                 id = todo["id"]
+    #                 my_task.update_task(id, task=update_task)
+    #         print(my_task.todo_list)
+    #     elif select == "4":
+    #         break
+    return None
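The removed loop (and its commented-out replacement) depends on the types imported from pm_kun.type_definition, which this diff does not touch. For orientation, here is a sketch of definitions consistent with the usage above. The field names and the Status.未着手 / Priority.低 members appear in the code; the value types and any additional enum members are assumptions:

import datetime
import enum
from typing import TypedDict


class Status(enum.Enum):
    未着手 = "未着手"  # "not started": the default when a task is created


class Priority(enum.Enum):
    低 = "低"  # "low": the default priority; other levels are assumed


class ToDo(TypedDict):
    id: str  # filled with str(uuid.uuid4())
    task: str
    status: Status
    period: datetime.date  # datetime.date.today() at creation
    priority: Priority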
Empty file added src/pm_kun/pmgpt/__init__.py
22 changes: 22 additions & 0 deletions src/pm_kun/pmgpt/main.py
@@ -0,0 +1,22 @@
+import torch
+from transformers import T5Tokenizer, GPT2LMHeadModel, pipeline
+
+tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
+tokenizer.do_lower_case = True
+model = GPT2LMHeadModel.from_pretrained("rinna/japanese-gpt2-small")
+
+
+# TODO: fine-tune the model with Transformers
+def generate_response(user_input: str) -> str:
+    generator = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        device=0 if torch.cuda.is_available() else -1,
+    )
+
+    response = generator(
+        user_input, max_length=100, truncation=True, num_return_sequences=1
+    )
+
+    return response[0]["generated_text"]
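One detail worth noting in generate_response: the pipeline object is rebuilt on every call, and pipeline() re-runs model and device setup each time. A possible refinement (a sketch, not part of this commit) is to build the generator once at module level, so main() can call generate_response repeatedly without repeating that setup:

import torch
from transformers import T5Tokenizer, GPT2LMHeadModel, pipeline

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-small")
tokenizer.do_lower_case = True
model = GPT2LMHeadModel.from_pretrained("rinna/japanese-gpt2-small")

# Build the text-generation pipeline once at import time; device=0 selects the
# first CUDA GPU, -1 keeps everything on CPU.
_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)


def generate_response(user_input: str) -> str:
    # Same decoding settings as the committed version: one returned sequence,
    # capped at 100 tokens, with over-long inputs truncated.
    response = _generator(
        user_input, max_length=100, truncation=True, num_return_sequences=1
    )
    return response[0]["generated_text"]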
