Skip to content

Commit

Permalink
Merge pull request #56 from studio-recoding/feat/whisper
Browse files Browse the repository at this point in the history
feat: whisper 사용 가능하도록 api 스펙 변경 및 프롬프트 추가
  • Loading branch information
uommou authored May 15, 2024
2 parents b9a1a48 + 3729014 commit ead3aa3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
1 change: 1 addition & 0 deletions app/dto/openai_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
class PromptRequest(BaseModel):
    """Request body for the chat endpoints.

    Carries the user's message plus the context needed to build the
    LLM prompt (persona text and the input channel).
    """
    prompt: str    # the user's chat message
    persona: str   # persona prompt inserted into the template
    chatType: str  # input channel: "STT" (Whisper speech-to-text) or "USER" (typed); any other value is rejected by the router

class ChatResponse(BaseModel):
ness: str
Expand Down
15 changes: 14 additions & 1 deletion app/prompt/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class Template:
Task: User Chat Classification
You are a case classifier integrated in scheduler application.
Please analyze User Chat according to the following criteria and return the appropriate case number (1, 2, 3).
{chat_type}
- Case 1: \
The question is a general information request, advice, or simple conversation, and does not require accessing the user's schedule database.
- Case 2: \
Expand Down Expand Up @@ -71,13 +73,22 @@ class Template:
User Chat: {question}
Answer:
"""
# Prompt fragment spliced into the {chat_type} slot of the case templates
# when the input arrived via speech-to-text: warns the model that the text
# may contain STT transcription errors.
chat_type_stt_template = """
You should keep in mind that this user's input was written using speech to text technology.
Therefore, there may be inaccuracies in the text due to errors in the STT process.
You need to consider this aspect when performing the given task.
"""
# Counterpart for directly typed input: intentionally empty so the
# {chat_type} slot adds nothing for ordinary user messages.
chat_type_user_template = """
"""
case1_template = """
{persona}
YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
{chat_type}
YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time.
User input: {question}
"""
case2_template = """
{persona}
{chat_type}
The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform:
1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
Expand Down Expand Up @@ -109,6 +120,8 @@ class Template:

case3_template = """
{persona}
{chat_type}
Current time is {current_time}. Respond to the user considering the current time.
When responding to user inputs, it's crucial to adapt your responses to the specified output language, maintaining a consistent and accessible communication style. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Your responses should not only be accurate but also display empathy and understanding of the user's needs.
You are equipped with a state-of-the-art RAG (Retrieval-Augmented Generation) technique, enabling you to dynamically pull relevant schedule information from a comprehensive database tailored to the user's specific inquiries. This technique enhances your ability to provide precise, context-aware responses by leveraging real-time data retrieval combined with advanced natural language understanding.
Expand Down
37 changes: 23 additions & 14 deletions app/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,36 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse:
model_name=config_chat['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
question = data.prompt
question = data.prompt # 유저로부터 받은 채팅의 내용
chat_type = data.chatType # 위스퍼 사용 여부 [STT, USER]

# description: give NESS's instruction as for case analysis
my_template = openai_prompt.Template.case_classify_template

# chat type에 따라 적합한 프롬프트를 삽입
if chat_type == "STT":
chat_type_prompt = openai_prompt.Template.chat_type_stt_template
elif chat_type == "USER":
chat_type_prompt = openai_prompt.Template.chat_type_user_template
else:
raise HTTPException(status_code=500, detail="WRONG CHAT TYPE")

prompt = PromptTemplate.from_template(my_template)
case = chat_model.predict(prompt.format(question=question))
case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt))

# 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요
print(case)
case = int(case)
if case == 1:
response = await get_langchain_normal(data)
response = await get_langchain_normal(data, chat_type_prompt)

elif case == 2:
response = await get_langchain_schedule(data)
response = await get_langchain_schedule(data, chat_type_prompt)

elif case == 3:
response = await get_langchain_rag(data)
response = await get_langchain_rag(data, chat_type_prompt)

else:
# print("wrong case classification")
# # 적절한 HTTP 상태 코드와 함께 오류 메시지를 반환하거나, 다른 처리를 할 수 있습니다.
# raise HTTPException(status_code=400, detail="Wrong case classification")
response = "좀 더 명확한 요구가 필요해요. 다시 한 번 얘기해주실 수 있나요?"
case = "Exception"

Expand All @@ -69,7 +76,7 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse:

# case 1 : normal
#@router.post("/case/normal") # 테스트용 엔드포인트
async def get_langchain_normal(data: PromptRequest): # case 1 : normal
async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1 : normal
print("running case 1")
# description: use langchain
config_normal = config['NESS_NORMAL']
Expand All @@ -87,13 +94,14 @@ async def get_langchain_normal(data: PromptRequest): # case 1 : normal
my_template = openai_prompt.Template.case1_template

prompt = PromptTemplate.from_template(my_template)
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question))
current_time = datetime.now()
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response

# case 2 : 일정 생성
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
async def get_langchain_schedule(data: PromptRequest):
async def get_langchain_schedule(data: PromptRequest, chat_type_prompt):
print("running case 2")
# description: use langchain
config_normal = config['NESS_NORMAL']
Expand All @@ -110,13 +118,13 @@ async def get_langchain_schedule(data: PromptRequest):

prompt = PromptTemplate.from_template(case2_template)
current_time = datetime.now()
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time))
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response

# case 3 : rag
#@router.post("/case/rag") # 테스트용 엔드포인트
async def get_langchain_rag(data: PromptRequest):
async def get_langchain_rag(data: PromptRequest, chat_type_prompt):
print("running case 3")
# description: use langchain
config_normal = config['NESS_NORMAL']
Expand All @@ -137,6 +145,7 @@ async def get_langchain_rag(data: PromptRequest):
case3_template = openai_prompt.Template.case3_template

prompt = PromptTemplate.from_template(case3_template)
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule))
current_time = datetime.now()
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response

0 comments on commit ead3aa3

Please sign in to comment.