7
7
from langchain_community .chat_models import ChatOpenAI
8
8
from langchain_core .prompts import PromptTemplate
9
9
10
- from app .dto .openai_dto import PromptRequest , ChatResponse
10
+ from app .dto .openai_dto import PromptRequest , ChatResponse , ChatCaseResponse
11
11
from app .prompt import openai_prompt
12
12
13
13
import app .database .chroma_db as vectordb
26
26
config = configparser .ConfigParser ()
27
27
config .read (CONFIG_FILE_PATH )
28
28
29
- @router .post ("/case" , status_code = status .HTTP_200_OK , response_model = ChatResponse )
30
- async def get_langchain_case (data : PromptRequest ) -> ChatResponse :
29
+ @router .post ("/case" , status_code = status .HTTP_200_OK , response_model = ChatCaseResponse )
30
+ async def get_langchain_case (data : PromptRequest ) -> ChatCaseResponse :
31
31
# description: use langchain
32
32
33
- config_normal = config ['NESS_NORMAL ' ]
33
+ config_chat = config ['NESS_CHAT ' ]
34
34
35
- chat_model = ChatOpenAI (temperature = config_normal ['TEMPERATURE' ], # 창의성 (0.0 ~ 2.0)
36
- max_tokens = config_normal ['MAX_TOKENS' ], # 최대 토큰수
37
- model_name = config_normal ['MODEL_NAME' ], # 모델명
35
+ chat_model = ChatOpenAI (temperature = config_chat ['TEMPERATURE' ], # 창의성 (0.0 ~ 2.0)
36
+ max_tokens = config_chat ['MAX_TOKENS' ], # 최대 토큰수
37
+ model_name = config_chat ['MODEL_NAME' ], # 모델명
38
38
openai_api_key = OPENAI_API_KEY # API 키
39
39
)
40
40
question = data .prompt
@@ -48,28 +48,32 @@ async def get_langchain_case(data: PromptRequest) -> ChatResponse:
48
48
print (case )
49
49
case = int (case )
50
50
if case == 1 :
51
- return await get_langchain_normal (data )
51
+ response = await get_langchain_normal (data )
52
52
53
53
elif case == 2 :
54
- return await get_langchain_schedule (data )
54
+ response = await get_langchain_schedule (data )
55
55
56
56
elif case == 3 :
57
- return await get_langchain_rag (data )
57
+ response = await get_langchain_rag (data )
58
58
59
59
else :
60
60
print ("wrong case classification" )
61
61
# 적절한 HTTP 상태 코드와 함께 오류 메시지를 반환하거나, 다른 처리를 할 수 있습니다.
62
62
raise HTTPException (status_code = 400 , detail = "Wrong case classification" )
63
63
64
+ return ChatCaseResponse (ness = response , case = case )
65
+
64
66
65
67
# case 1 : normal
66
68
#@router.post("/case/normal") # 테스트용 엔드포인트
67
- async def get_langchain_normal (data : PromptRequest ) -> ChatResponse : # case 1 : normal
69
+ async def get_langchain_normal (data : PromptRequest ): # case 1 : normal
68
70
print ("running case 1" )
69
71
# description: use langchain
70
- chat_model = ChatOpenAI (temperature = 0 , # 창의성 (0.0 ~ 2.0)
71
- max_tokens = 2048 , # 최대 토큰수
72
- model_name = 'gpt-3.5-turbo-1106' , # 모델명
72
+ config_normal = config ['NESS_NORMAL' ]
73
+
74
+ chat_model = ChatOpenAI (temperature = config_normal ['TEMPERATURE' ], # 창의성 (0.0 ~ 2.0)
75
+ max_tokens = config_normal ['MAX_TOKENS' ], # 최대 토큰수
76
+ model_name = config_normal ['MODEL_NAME' ], # 모델명
73
77
openai_api_key = OPENAI_API_KEY # API 키
74
78
)
75
79
question = data .prompt
@@ -82,16 +86,18 @@ async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 :
82
86
prompt = PromptTemplate .from_template (my_template )
83
87
response = chat_model .predict (prompt .format (output_language = "Korean" , question = question ))
84
88
print (response )
85
- return ChatResponse ( ness = response )
89
+ return response
86
90
87
91
# case 2 : 일정 생성
88
92
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
89
- async def get_langchain_schedule (data : PromptRequest ) -> ChatResponse :
93
+ async def get_langchain_schedule (data : PromptRequest ):
90
94
print ("running case 2" )
91
95
# description: use langchain
92
- chat_model = ChatOpenAI (temperature = 0 , # 창의성 (0.0 ~ 2.0)
93
- max_tokens = 2048 , # 최대 토큰수
94
- model_name = 'gpt-3.5-turbo-1106' , # 모델명
96
+ config_normal = config ['NESS_NORMAL' ]
97
+
98
+ chat_model = ChatOpenAI (temperature = config_normal ['TEMPERATURE' ], # 창의성 (0.0 ~ 2.0)
99
+ max_tokens = config_normal ['MAX_TOKENS' ], # 최대 토큰수
100
+ model_name = config_normal ['MODEL_NAME' ], # 모델명
95
101
openai_api_key = OPENAI_API_KEY # API 키
96
102
)
97
103
question = data .prompt
@@ -100,16 +106,18 @@ async def get_langchain_schedule(data: PromptRequest) -> ChatResponse:
100
106
prompt = PromptTemplate .from_template (case2_template )
101
107
response = chat_model .predict (prompt .format (output_language = "Korean" , question = question ))
102
108
print (response )
103
- return ChatResponse ( ness = response )
109
+ return response
104
110
105
111
# case 3 : rag
106
112
#@router.post("/case/rag") # 테스트용 엔드포인트
107
- async def get_langchain_rag (data : PromptRequest ) -> ChatResponse :
113
+ async def get_langchain_rag (data : PromptRequest ):
108
114
print ("running case 3" )
109
115
# description: use langchain
110
- chat_model = ChatOpenAI (temperature = 0 , # 창의성 (0.0 ~ 2.0)
111
- max_tokens = 2048 , # 최대 토큰수
112
- model_name = 'gpt-3.5-turbo-1106' , # 모델명
116
+ config_normal = config ['NESS_NORMAL' ]
117
+
118
+ chat_model = ChatOpenAI (temperature = config_normal ['TEMPERATURE' ], # 창의성 (0.0 ~ 2.0)
119
+ max_tokens = config_normal ['MAX_TOKENS' ], # 최대 토큰수
120
+ model_name = config_normal ['MODEL_NAME' ], # 모델명
113
121
openai_api_key = OPENAI_API_KEY # API 키
114
122
)
115
123
question = data .prompt
@@ -123,4 +131,4 @@ async def get_langchain_rag(data: PromptRequest) -> ChatResponse:
123
131
prompt = PromptTemplate .from_template (case3_template )
124
132
response = chat_model .predict (prompt .format (output_language = "Korean" , question = question , schedule = schedule ))
125
133
print (response )
126
- return ChatResponse ( ness = response )
134
+ return response
0 commit comments