Skip to content

Commit c723048

Browse files
authored
Merge pull request #41 from studio-recoding/fix/case2
[fix] case 넘버와 함께 반환
2 parents 4dd3a88 + b8f9e39 commit c723048

File tree

4 files changed

+69
-42
lines changed

4 files changed

+69
-42
lines changed

app/dto/openai_dto.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,8 @@ class PromptRequest(BaseModel):
55
prompt: str
66

77
class ChatResponse(BaseModel):
8-
ness: str
8+
ness: str
9+
10+
class ChatCaseResponse(BaseModel):
11+
ness: str
12+
case: int

app/prompt/openai_config.ini

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
[NESS_NORMAL]
22
TEMPERATURE = 0
33
MAX_TOKENS = 2048
4+
MODEL_NAME = gpt-3.5-turbo-1106
5+
6+
[NESS_CASE]
7+
TEMPERATURE = 0
8+
MAX_TOKENS = 2048
9+
MODEL_NAME = gpt-4
10+
11+
[NESS_RECOMMENDATION]
12+
TEMPERATURE = 0
13+
MAX_TOKENS = 2048
414
MODEL_NAME = gpt-3.5-turbo-1106

app/routers/chat.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from langchain_community.chat_models import ChatOpenAI
88
from langchain_core.prompts import PromptTemplate
99

10-
from app.dto.openai_dto import PromptRequest, ChatResponse
10+
from app.dto.openai_dto import PromptRequest, ChatResponse, ChatCaseResponse
1111
from app.prompt import openai_prompt
1212

1313
import app.database.chroma_db as vectordb
@@ -26,15 +26,15 @@
2626
config = configparser.ConfigParser()
2727
config.read(CONFIG_FILE_PATH)
2828

29-
@router.post("/case", status_code=status.HTTP_200_OK, response_model=ChatResponse)
30-
async def get_langchain_case(data: PromptRequest) -> ChatResponse:
29+
@router.post("/case", status_code=status.HTTP_200_OK, response_model=ChatCaseResponse)
30+
async def get_langchain_case(data: PromptRequest) -> ChatCaseResponse:
3131
# description: use langchain
3232

33-
config_normal = config['NESS_NORMAL']
33+
config_chat = config['NESS_CHAT']
3434

35-
chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
36-
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
37-
model_name=config_normal['MODEL_NAME'], # 모델명
35+
chat_model = ChatOpenAI(temperature=config_chat['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
36+
max_tokens=config_chat['MAX_TOKENS'], # 최대 토큰수
37+
model_name=config_chat['MODEL_NAME'], # 모델명
3838
openai_api_key=OPENAI_API_KEY # API 키
3939
)
4040
question = data.prompt
@@ -48,28 +48,32 @@ async def get_langchain_case(data: PromptRequest) -> ChatResponse:
4848
print(case)
4949
case = int(case)
5050
if case == 1:
51-
return await get_langchain_normal(data)
51+
response = await get_langchain_normal(data)
5252

5353
elif case == 2:
54-
return await get_langchain_schedule(data)
54+
response = await get_langchain_schedule(data)
5555

5656
elif case == 3:
57-
return await get_langchain_rag(data)
57+
response = await get_langchain_rag(data)
5858

5959
else:
6060
print("wrong case classification")
6161
# 적절한 HTTP 상태 코드와 함께 오류 메시지를 반환하거나, 다른 처리를 할 수 있습니다.
6262
raise HTTPException(status_code=400, detail="Wrong case classification")
6363

64+
return ChatCaseResponse(ness=response, case=case)
65+
6466

6567
# case 1 : normal
6668
#@router.post("/case/normal") # 테스트용 엔드포인트
67-
async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 : normal
69+
async def get_langchain_normal(data: PromptRequest): # case 1 : normal
6870
print("running case 1")
6971
# description: use langchain
70-
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
71-
max_tokens=2048, # 최대 토큰수
72-
model_name='gpt-3.5-turbo-1106', # 모델명
72+
config_normal = config['NESS_NORMAL']
73+
74+
chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
75+
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
76+
model_name=config_normal['MODEL_NAME'], # 모델명
7377
openai_api_key=OPENAI_API_KEY # API 키
7478
)
7579
question = data.prompt
@@ -82,16 +86,18 @@ async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 :
8286
prompt = PromptTemplate.from_template(my_template)
8387
response = chat_model.predict(prompt.format(output_language="Korean", question=question))
8488
print(response)
85-
return ChatResponse(ness=response)
89+
return response
8690

8791
# case 2 : 일정 생성
8892
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
89-
async def get_langchain_schedule(data: PromptRequest) -> ChatResponse:
93+
async def get_langchain_schedule(data: PromptRequest):
9094
print("running case 2")
9195
# description: use langchain
92-
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
93-
max_tokens=2048, # 최대 토큰수
94-
model_name='gpt-3.5-turbo-1106', # 모델명
96+
config_normal = config['NESS_NORMAL']
97+
98+
chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
99+
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
100+
model_name=config_normal['MODEL_NAME'], # 모델명
95101
openai_api_key=OPENAI_API_KEY # API 키
96102
)
97103
question = data.prompt
@@ -100,16 +106,18 @@ async def get_langchain_schedule(data: PromptRequest) -> ChatResponse:
100106
prompt = PromptTemplate.from_template(case2_template)
101107
response = chat_model.predict(prompt.format(output_language="Korean", question=question))
102108
print(response)
103-
return ChatResponse(ness=response)
109+
return response
104110

105111
# case 3 : rag
106112
#@router.post("/case/rag") # 테스트용 엔드포인트
107-
async def get_langchain_rag(data: PromptRequest) -> ChatResponse:
113+
async def get_langchain_rag(data: PromptRequest):
108114
print("running case 3")
109115
# description: use langchain
110-
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
111-
max_tokens=2048, # 최대 토큰수
112-
model_name='gpt-3.5-turbo-1106', # 모델명
116+
config_normal = config['NESS_NORMAL']
117+
118+
chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
119+
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
120+
model_name=config_normal['MODEL_NAME'], # 모델명
113121
openai_api_key=OPENAI_API_KEY # API 키
114122
)
115123
question = data.prompt
@@ -123,4 +131,4 @@ async def get_langchain_rag(data: PromptRequest) -> ChatResponse:
123131
prompt = PromptTemplate.from_template(case3_template)
124132
response = chat_model.predict(prompt.format(output_language="Korean", question=question, schedule=schedule))
125133
print(response)
126-
return ChatResponse(ness=response)
134+
return response

app/routers/recommendation.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33

44
from dotenv import load_dotenv
5-
from fastapi import APIRouter, Depends, status
5+
from fastapi import APIRouter, Depends, status, HTTPException
66
from langchain_community.chat_models import ChatOpenAI
77
from langchain_core.prompts import PromptTemplate
88

@@ -27,23 +27,28 @@
2727

2828
@router.post("/main", status_code=status.HTTP_200_OK)
2929
async def get_recommendation(user_data: RecommendationMainRequestDTO) -> ChatResponse:
30+
try:
31+
# 모델
32+
config_recommendation = config['NESS_RECOMMENDATION']
3033

31-
# 모델
32-
chat_model = ChatOpenAI(temperature=0.5, # 창의성 (0.0 ~ 2.0)
33-
max_tokens=2048, # 최대 토큰수
34-
model_name='gpt-3.5-turbo-1106', # 모델명
35-
openai_api_key=OPENAI_API_KEY # API 키
36-
)
34+
chat_model = ChatOpenAI(temperature=config_recommendation['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
35+
max_tokens=config_recommendation['MAX_TOKENS'], # 최대 토큰수
36+
model_name=config_recommendation['MODEL_NAME'], # 모델명
37+
openai_api_key=OPENAI_API_KEY # API 키
38+
)
3739

38-
# vectordb에서 유저의 정보를 가져온다.
39-
schedule = await vectordb.db_recommendation_main(user_data)
40+
# vectordb에서 유저의 정보를 가져온다.
41+
schedule = await vectordb.db_recommendation_main(user_data)
4042

41-
print(schedule)
43+
print(schedule)
4244

43-
# 템플릿
44-
recommendation_template = openai_prompt.Template.recommendation_template
45+
# 템플릿
46+
recommendation_template = openai_prompt.Template.recommendation_template
4547

46-
prompt = PromptTemplate.from_template(recommendation_template)
47-
result = chat_model.predict(prompt.format(output_language="Korean", schedule=schedule))
48-
print(result)
49-
return ChatResponse(ness=result)
48+
prompt = PromptTemplate.from_template(recommendation_template)
49+
result = chat_model.predict(prompt.format(output_language="Korean", schedule=schedule))
50+
print(result)
51+
return ChatResponse(ness=result)
52+
53+
except Exception as e:
54+
raise HTTPException(status_code=500, detail=str(e))

0 commit comments

Comments
 (0)