From 7ef38e322effff8f2297a579d008c7090fbd66be Mon Sep 17 00:00:00 2001 From: uommou Date: Wed, 15 May 2024 17:20:14 +0900 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20whisper=20=EC=97=B0=EA=B2=B0?= =?UTF-8?q?=EC=9D=84=20=EC=9C=84=ED=95=9C=20api=20=EC=8A=A4=ED=8E=99=20?= =?UTF-8?q?=EB=B3=80=EA=B2=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/dto/openai_dto.py | 1 + app/routers/chat.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/app/dto/openai_dto.py b/app/dto/openai_dto.py index d5e0805..bb04a4c 100644 --- a/app/dto/openai_dto.py +++ b/app/dto/openai_dto.py @@ -4,6 +4,7 @@ class PromptRequest(BaseModel): prompt: str persona: str + chatType: str class ChatResponse(BaseModel): ness: str diff --git a/app/routers/chat.py b/app/routers/chat.py index e8b35ef..2b110d2 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -38,14 +38,24 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: model_name=config_chat['MODEL_NAME'], # 모델명 openai_api_key=OPENAI_API_KEY # API 키 ) - question = data.prompt + question = data.prompt # 유저로부터 받은 채팅의 내용 + chat_type = data.chatType # 위스퍼 사용 여부 [STT, USER] # description: give NESS's instruction as for case analysis my_template = openai_prompt.Template.case_classify_template - prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + # chat type에 따라 적합한 프롬프트를 삽입 + if chat_type == "STT": + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question)) + elif chat_type == "USER": + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question)) + else: + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question)) + # 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요 print(case) case = int(case) if case == 1: From d65978212370de4cadb86a43427682bc04e641a8 Mon Sep 17 00:00:00 2001 From: uommou Date: Wed, 15 May 2024 17:37:21 +0900 Subject: [PATCH 2/3] =?UTF-8?q?feat:=20stt=20=EC=82=AC=EC=9A=A9=EC=8B=9C?= =?UTF-8?q?=20=EC=B6=94=EA=B0=80=ED=95=A0=20=ED=94=84=EB=A1=AC=ED=94=84?= =?UTF-8?q?=ED=8A=B8=20=EC=9E=91=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/prompt/openai_prompt.py | 9 +++++++++ app/routers/chat.py | 12 +++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index 317221d..980bfb7 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -41,6 +41,8 @@ class Template: Task: User Chat Classification You are a case classifier integrated in scheduler application. Please analyze User Chat according to the following criteria and return the appropriate case number (1, 2, 3). + {chat_type} + - Case 1: \ The question is a general information request, advice, or simple conversation, and does not require accessing the user's schedule database. - Case 2: \ @@ -71,6 +73,13 @@ class Template: User Chat: {question} Answer: """ + chat_type_stt_template = """ + You should keep in mind that this user's input was written using speech to text technology. + Therefore, there may be inaccuracies in the text due to errors in the STT process. + You need to consider this aspect when performing the given task. + """ + chat_type_user_template = """ + """ case1_template = """ {persona} YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. diff --git a/app/routers/chat.py b/app/routers/chat.py index 2b110d2..00e4a24 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -47,13 +47,14 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: # chat type에 따라 적합한 프롬프트를 삽입 if chat_type == "STT": prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + chat_type_prompt = openai_prompt.Template.chat_type_stt_template + case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) elif chat_type == "USER": prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + chat_type_prompt = openai_prompt.Template.chat_type_user_template + case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) else: - prompt = PromptTemplate.from_template(my_template) - case = chat_model.predict(prompt.format(question=question)) + raise HTTPException(status_code=500, detail="WRONG CHAT TYPE") # 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요 print(case) @@ -68,9 +69,6 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: response = await get_langchain_rag(data) else: - # print("wrong case classification") - # # 적절한 HTTP 상태 코드와 함께 오류 메시지를 반환하거나, 다른 처리를 할 수 있습니다. - # raise HTTPException(status_code=400, detail="Wrong case classification") response = "좀 더 명확한 요구가 필요해요. 다시 한 번 얘기해주실 수 있나요?" case = "Exception" From 372901467c57d4d2446e9a4e6c58cc2a6e19250c Mon Sep 17 00:00:00 2001 From: uommou Date: Wed, 15 May 2024 17:54:08 +0900 Subject: [PATCH 3/3] =?UTF-8?q?feat:=20whisper=20=EC=82=AC=EC=9A=A9=20?= =?UTF-8?q?=EA=B0=80=EB=8A=A5=ED=95=98=EB=8F=84=EB=A1=9D=20=ED=94=84?= =?UTF-8?q?=EB=A1=AC=ED=94=84=ED=8A=B8=20=EC=9E=91=EC=84=B1=20=EC=99=84?= =?UTF-8?q?=EB=A3=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/prompt/openai_prompt.py | 6 +++++- app/routers/chat.py | 27 ++++++++++++++------------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index 980bfb7..9ee23ca 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -82,11 +82,13 @@ class Template: """ case1_template = """ {persona} - YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. + {chat_type} + YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time. User input: {question} """ case2_template = """ {persona} + {chat_type} The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform: 1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. @@ -118,6 +120,8 @@ class Template: case3_template = """ {persona} + {chat_type} + Current time is {current_time}. Respond to the user considering the current time. When responding to user inputs, it's crucial to adapt your responses to the specified output language, maintaining a consistent and accessible communication style. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Your responses should not only be accurate but also display empathy and understanding of the user's needs. You are equipped with a state-of-the-art RAG (Retrieval-Augmented Generation) technique, enabling you to dynamically pull relevant schedule information from a comprehensive database tailored to the user's specific inquiries. This technique enhances your ability to provide precise, context-aware responses by leveraging real-time data retrieval combined with advanced natural language understanding. diff --git a/app/routers/chat.py b/app/routers/chat.py index 00e4a24..b7c60ae 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -46,27 +46,26 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: # chat type에 따라 적합한 프롬프트를 삽입 if chat_type == "STT": - prompt = PromptTemplate.from_template(my_template) chat_type_prompt = openai_prompt.Template.chat_type_stt_template - case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) elif chat_type == "USER": - prompt = PromptTemplate.from_template(my_template) chat_type_prompt = openai_prompt.Template.chat_type_user_template - case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) else: raise HTTPException(status_code=500, detail="WRONG CHAT TYPE") + prompt = PromptTemplate.from_template(my_template) + case = chat_model.predict(prompt.format(question=question, chat_type=chat_type_prompt)) + # 각 케이스에도 chat type에 따라 적합한 프롬프트 삽입 필요 print(case) case = int(case) if case == 1: - response = await get_langchain_normal(data) + response = await get_langchain_normal(data, chat_type_prompt) elif case == 2: - response = await get_langchain_schedule(data) + response = await get_langchain_schedule(data, chat_type_prompt) elif case == 3: - response = await get_langchain_rag(data) + response = await get_langchain_rag(data, chat_type_prompt) else: response = "좀 더 명확한 요구가 필요해요. 다시 한 번 얘기해주실 수 있나요?" @@ -77,7 +76,7 @@ async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse: # case 1 : normal #@router.post("/case/normal") # 테스트용 엔드포인트 -async def get_langchain_normal(data: PromptRequest): # case 1 : normal +async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1 : normal print("running case 1") # description: use langchain config_normal = config['NESS_NORMAL'] @@ -95,13 +94,14 @@ async def get_langchain_normal(data: PromptRequest): # case 1 : normal my_template = openai_prompt.Template.case1_template prompt = PromptTemplate.from_template(my_template) - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question)) + current_time = datetime.now() + response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt)) print(response) return response # case 2 : 일정 생성 #@router.post("/case/make_schedule") # 테스트용 엔드포인트 -async def get_langchain_schedule(data: PromptRequest): +async def get_langchain_schedule(data: PromptRequest, chat_type_prompt): print("running case 2") # description: use langchain config_normal = config['NESS_NORMAL'] @@ -118,13 +118,13 @@ async def get_langchain_schedule(data: PromptRequest): prompt = PromptTemplate.from_template(case2_template) current_time = datetime.now() - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time)) + response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt)) print(response) return response # case 3 : rag #@router.post("/case/rag") # 테스트용 엔드포인트 -async def get_langchain_rag(data: PromptRequest): +async def get_langchain_rag(data: PromptRequest, chat_type_prompt): print("running case 3") # description: use langchain config_normal = config['NESS_NORMAL'] @@ -145,6 +145,7 @@ async def get_langchain_rag(data: PromptRequest): case3_template = openai_prompt.Template.case3_template prompt = PromptTemplate.from_template(case3_template) - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule)) + current_time = datetime.now() + response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, schedule=schedule, current_time=current_time, chat_type=chat_type_prompt)) print(response) return response