Skip to content

Commit

Permalink
Merge pull request #79 from studio-recoding/feat/chat-completion
Browse files Browse the repository at this point in the history
feat: chat completion
  • Loading branch information
uommou authored Jul 15, 2024
2 parents 9717190 + 5ef5add commit 0f49174
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 52 deletions.
29 changes: 29 additions & 0 deletions app/database/connect_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,32 @@ def fetch_category_classification_data(member_id):
return result
finally:
connection.close()

def fetch_previous_conversations(member_id):
connection = get_rds_connection()
try:
with connection.cursor() as cursor:
sql = """
SELECT c.text, c.chat_type
FROM chat c
WHERE c.member_id = %s
ORDER BY c.created_date DESC
LIMIT 5
"""
cursor.execute(sql, (member_id,))
result = cursor.fetchall()
formatted_result = []
for chat in result:
# chat_type에 따라 'system' 또는 'human'으로 설정
sender_type = 'ai' if chat['chat_type'] == 'AI' else 'human'
# 메시지 포맷팅
formatted_result.append((sender_type, chat['text']))
return formatted_result[::-1] # 최신 메시지가 마지막에 오도록 순서를 뒤집습니다.
return result
except Exception as e:
print("Error fetching conversations:", e, file=sys.stderr)
return [] # 오류 발생 시 빈 리스트 반환
finally:
connection.close()


10 changes: 1 addition & 9 deletions app/prompt/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ class Template:
{persona}
{chat_type}
YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time.
User input: {question}
"""

case2_template = """
Expand Down Expand Up @@ -164,13 +163,9 @@ class Template:
"id": 2,
"color": "#00FF00"
}},
"search keyword": "top Italian wines"
"searchKeyword": "top Italian wines"
}}
]
User input: {question}
Response to user:
"""

case3_template = """
Expand All @@ -186,9 +181,6 @@ class Template:
Response: Good morning! You have two meetings scheduled for tomorrow: the project status update at 10 AM and the client discussion at 3 PM. Would you like reminders for these, or is there anything else I can assist you with?
Now respond to following User input, based on RAG Retrieval.
User input: {question},
RAG Retrieval: {schedule}
Response:
"""

case4_template = """
Expand Down
171 changes: 128 additions & 43 deletions app/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException, status
from langchain_community.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain.schema import SystemMessage, HumanMessage, AIMessage
from datetime import datetime

from app.database.connect_rds import fetch_category_classification_data
from app.database.connect_rds import fetch_category_classification_data, fetch_previous_conversations
from app.dto.openai_dto import PromptRequest, ChatResponse, ChatCaseResponse
from app.prompt import openai_prompt, persona_prompt

Expand Down Expand Up @@ -98,67 +99,121 @@ async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
# 유저 입력
question = data.prompt
# 페르소나 프롬프트
persona = data.persona
user_persona_prompt = persona_prompt.Template.from_persona(persona)

# description: give NESS's ideal instruction as template
my_template = openai_prompt.Template.case1_template

prompt = PromptTemplate.from_template(my_template)
# 현재 시각 설정
seoul_timezone = pytz.timezone('Asia/Seoul')
current_time = datetime.now(seoul_timezone)
print(f'current time: {current_time}')
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response

# 이전 대화내역 불러오기
previous_conversations = fetch_previous_conversations(data.member_id)
print(previous_conversations)

# case 1 프롬프트
case1_template = openai_prompt.Template.case1_template
#prompt = PromptTemplate.from_template(my_template)
# system, human, ai
chat_prompt = ChatPromptTemplate.from_messages(
previous_conversations + [
("system", case1_template),
("human", "{question}")
]
)

# 프롬프트와 모델을 chaining
chain = chat_prompt | chat_model

#response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
response = chain.invoke({
"persona": user_persona_prompt,
"output_language": "Korean",
"current_time": current_time,
"chat_type": chat_type_prompt,
"question": question
})
print(response.content)
return response.content

# case 2 : 일정 생성
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
async def get_langchain_schedule(data: PromptRequest, chat_type_prompt):
try:
print("running case 2")
config_normal = config['NESS_NORMAL']

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)

member_id = data.member_id

# 카테고리 가져오기
categories = fetch_category_classification_data(member_id)
# 카테고리 데이터를 딕셔너리 형태로 변환
categories_dict = [
{"name": category['name'], "id": category['category_id'], "color": category['color']}
for category in categories
]

# description: use langchain
config_normal = config['NESS_NORMAL']
# 시간 가져오기
seoul_timezone = pytz.timezone('Asia/Seoul')
current_time = datetime.now(seoul_timezone)
print(f'current time: {current_time}')

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
# 이전 대화 내역 가져오기
previous_conversations = fetch_previous_conversations(member_id)
print(previous_conversations)

question = data.prompt
persona = data.persona
user_persona_prompt = persona_prompt.Template.from_persona(persona)
case2_template = openai_prompt.Template.case2_template

prompt = PromptTemplate.from_template(case2_template)
seoul_timezone = pytz.timezone('Asia/Seoul')
current_time = datetime.now(seoul_timezone)
print(f'current time: {current_time}')

# OpenAI 프롬프트에 데이터 통합
response = chat_model.predict(
prompt.format(
persona=user_persona_prompt,
output_language="Korean",
question=question,
current_time=current_time,
chat_type=chat_type_prompt,
categories=categories_dict # 카테고리 데이터를 프롬프트에 포함
)
# case 2 프롬프트
case2_template = openai_prompt.Template.case2_template
# prompt = PromptTemplate.from_template(case2_template)
# system, human, ai
chat_prompt = ChatPromptTemplate.from_messages(
previous_conversations + [
("system", case2_template),
("human", "User input: {question}")
]
)

print(response)
return response

# # OpenAI 프롬프트에 데이터 통합
# response = chat_model.predict(
# prompt.format(
# persona=user_persona_prompt,
# output_language="Korean",
# question=question,
# current_time=current_time,
# chat_type=chat_type_prompt,
# categories=categories_dict # 카테고리 데이터를 프롬프트에 포함
# )
# )

# 프롬프트와 모델을 chaining
chain = chat_prompt | chat_model

# response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
response = chain.invoke({
"persona": user_persona_prompt,
"output_language": "Korean",
"current_time": current_time,
"chat_type": chat_type_prompt,
"categories": categories_dict,
"question": question
})

print(response.content)
return response.content

except Exception as e:
print(e)
Expand All @@ -176,24 +231,54 @@ async def get_langchain_rag(data: PromptRequest, chat_type_prompt):
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)

member_id = data.member_id
question = data.prompt
persona = data.persona
user_persona_prompt = persona_prompt.Template.from_persona(persona)

# vectordb.search_db_query를 비동기적으로 호출합니다.
# 시간 가져오기
seoul_timezone = pytz.timezone('Asia/Seoul')
current_time = datetime.now(seoul_timezone)
print(f'current time: {current_time}')

# 관련 스케줄 가져오기
schedule = await vectordb.search_db_query(member_id, question) # vector db에서 검색

# 이전 대화 내역 불러오기
previous_conversations = fetch_previous_conversations(member_id)
print(previous_conversations)

# description: give NESS's ideal instruction as template
case3_template = openai_prompt.Template.case3_template

prompt = PromptTemplate.from_template(case3_template)
seoul_timezone = pytz.timezone('Asia/Seoul')
current_time = datetime.now(seoul_timezone)
print(f'current time: {current_time}')
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response
case3_user = """
User input: {question},
RAG Retrieval: {schedule}
Response:
"""
#prompt = PromptTemplate.from_template(case3_template)
# system, human, ai
chat_prompt = ChatPromptTemplate.from_messages(
previous_conversations + [
("system", case3_template),
("human", case3_user)
]
)

#response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule, current_time=current_time, chat_type=chat_type_prompt))
# 프롬프트와 모델을 chaining
chain = chat_prompt | chat_model

response = chain.invoke({
"persona": user_persona_prompt,
"output_language": "Korean",
"current_time": current_time,
"chat_type": chat_type_prompt,
"schedule": schedule,
"question": question
})
print(response.content)
return response.content

# case 4 : delete schedule
async def delete_schedule(data: PromptRequest, chat_type_prompt):
Expand Down

0 comments on commit 0f49174

Please sign in to comment.