Skip to content

Commit

Permalink
feat: whisper 사용 가능하도록 프롬프트 작성 완료
Browse files Browse the repository at this point in the history
  • Loading branch information
uommou committed May 15, 2024
1 parent d659782 commit 3729014
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
6 changes: 5 additions & 1 deletion app/prompt/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ class Template:
"""
case1_template = """
{persona}
YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
{chat_type}
YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time.
User input: {question}
"""
case2_template = """
{persona}
{chat_type}
The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform:
1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
Expand Down Expand Up @@ -118,6 +120,8 @@ class Template:

case3_template = """
{persona}
{chat_type}
Current time is {current_time}. Respond to the user considering the current time.
When responding to user inputs, it's crucial to adapt your responses to the specified output language, maintaining a consistent and accessible communication style. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Your responses should not only be accurate but also display empathy and understanding of the user's needs.
You are equipped with a state-of-the-art RAG (Retrieval-Augmented Generation) technique, enabling you to dynamically pull relevant schedule information from a comprehensive database tailored to the user's specific inquiries. This technique enhances your ability to provide precise, context-aware responses by leveraging real-time data retrieval combined with advanced natural language understanding.
Expand Down
27 changes: 14 additions & 13 deletions app/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,26 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse:

# chat type에 따라 적합한 프롬프트를 삽입
if chat_type == "STT":
prompt = PromptTemplate.from_template(my_template)
chat_type_prompt = openai_prompt.Template.chat_type_stt_template
case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt))
elif chat_type == "USER":
prompt = PromptTemplate.from_template(my_template)
chat_type_prompt = openai_prompt.Template.chat_type_user_template
case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt))
else:
raise HTTPException(status_code=500, detail="WRONG CHAT TYPE")

prompt = PromptTemplate.from_template(my_template)
case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt))

# 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요
print(case)
case = int(case)
if case == 1:
response = await get_langchain_normal(data)
response = await get_langchain_normal(data, chat_type_prompt)

elif case == 2:
response = await get_langchain_schedule(data)
response = await get_langchain_schedule(data, chat_type_prompt)

elif case == 3:
response = await get_langchain_rag(data)
response = await get_langchain_rag(data, chat_type_prompt)

else:
response = "좀 더 명확한 요구가 필요해요. 다시 한 번 얘기해주실 수 있나요?"
Expand All @@ -77,7 +76,7 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse:

# case 1 : normal
#@router.post("/case/normal") # 테스트용 엔드포인트
async def get_langchain_normal(data: PromptRequest): # case 1 : normal
async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1 : normal
print("running case 1")
# description: use langchain
config_normal = config['NESS_NORMAL']
Expand All @@ -95,13 +94,14 @@ async def get_langchain_normal(data: PromptRequest): # case 1 : normal
my_template = openai_prompt.Template.case1_template

prompt = PromptTemplate.from_template(my_template)
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question))
current_time = datetime.now()
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response

# case 2 : 일정 생성
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
async def get_langchain_schedule(data: PromptRequest):
async def get_langchain_schedule(data: PromptRequest, chat_type_prompt):
print("running case 2")
# description: use langchain
config_normal = config['NESS_NORMAL']
Expand All @@ -118,13 +118,13 @@ async def get_langchain_schedule(data: PromptRequest):

prompt = PromptTemplate.from_template(case2_template)
current_time = datetime.now()
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time))
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response

# case 3 : rag
#@router.post("/case/rag") # 테스트용 엔드포인트
async def get_langchain_rag(data: PromptRequest):
async def get_langchain_rag(data: PromptRequest, chat_type_prompt):
print("running case 3")
# description: use langchain
config_normal = config['NESS_NORMAL']
Expand All @@ -145,6 +145,7 @@ async def get_langchain_rag(data: PromptRequest):
case3_template = openai_prompt.Template.case3_template

prompt = PromptTemplate.from_template(case3_template)
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule))
current_time = datetime.now()
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response

0 comments on commit 3729014

Please sign in to comment.