1 file changed: +30 −25 lines changed
@@ -12,13 +12,12 @@
 
 DEBUG = int(os.environ.get("DEBUG", "0"))
 
-
 def generate_together(
-    model,
-    messages,
-    max_tokens=2048,
-    temperature=0.7,
-    streaming=False,
+    model,
+    messages,
+    max_tokens=2048,
+    temperature=0.7,
+    streaming=False,
 ):
 
     output = None
@@ -77,10 +76,10 @@ def generate_together(
 
 
 def generate_together_stream(
-    model,
-    messages,
-    max_tokens=2048,
-    temperature=0.7,
+    model,
+    messages,
+    max_tokens=2048,
+    temperature=0.7,
 ):
     endpoint = "https://api.groq.com/openai/v1/"
     client = openai.OpenAI(
@@ -99,10 +98,10 @@ def generate_together_stream(
 
 
 def generate_openai(
-    model,
-    messages,
-    max_tokens=2048,
-    temperature=0.7,
+    model,
+    messages,
+    max_tokens=2048,
+    temperature=0.7,
 ):
 
     client = openai.OpenAI(
@@ -137,8 +136,8 @@ def generate_openai(
 
 
 def inject_references_to_messages(
-    messages,
-    references,
+    messages,
+    references,
 ):
 
     messages = copy.deepcopy(messages)
@@ -163,21 +162,27 @@ def inject_references_to_messages(
 
 
 def generate_with_references(
-    model,
-    messages,
-    references=[],
-    max_tokens=2048,
-    temperature=0.7,
-    generate_fn=generate_together,
+    model,
+    messages,
+    references=[],
+    max_tokens=2048,
+    temperature=0.7,
+    generate_fn=generate_together_stream,
 ):
-
     if len(references) > 0:
-
         messages = inject_references_to_messages(messages, references)
 
-    return generate_fn(
+    # Generate response using the provided generate function
+    response = generate_fn(
         model=model,
         messages=messages,
         temperature=temperature,
         max_tokens=max_tokens,
     )
+
+    # Check if the response is in the expected format
+    if hasattr(response, 'choices'):
+        return response
+    else:
+        return [{"choices": [{"delta": {"content": response}}]}]
+
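
For context on the generate_together_stream hunk: it drives Groq's OpenAI-compatible endpoint through the OpenAI SDK. A minimal sketch of that pattern, assuming openai>=1.0, a GROQ_API_KEY environment variable, and an example model name (none of which are pinned down by the diff itself):

import os

import openai

# Point the OpenAI client at Groq's OpenAI-compatible endpoint,
# mirroring the endpoint used in generate_together_stream.
client = openai.OpenAI(
    base_url="https://api.groq.com/openai/v1/",
    api_key=os.environ.get("GROQ_API_KEY"),
)

# Stream a chat completion; each chunk carries a partial delta.
stream = client.chat.completions.create(
    model="llama3-70b-8192",  # example model name, not from the diff
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=2048,
    temperature=0.7,
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")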
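
The inject_references_to_messages hunk only touches whitespace in the signature; the one body line visible in context is messages = copy.deepcopy(messages). For readers unfamiliar with the pattern, a plausible sketch of such a function, assuming the usual approach of folding the reference answers into a system message (the prompt wording is illustrative, not taken from this file):

import copy

def inject_references_to_messages(messages, references):
    # Deep-copy so the caller's message list is not mutated (as in the diff).
    messages = copy.deepcopy(messages)

    # Illustrative aggregation prompt; the real wording is not shown here.
    system = "Responses from other models are listed below. Synthesize them into a single, high-quality answer."
    for i, reference in enumerate(references):
        system += f"\n{i + 1}. {reference}"

    if messages[0]["role"] == "system":
        messages[0]["content"] += "\n\n" + system
    else:
        messages = [{"role": "system", "content": system}] + messages
    return messages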
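
The substantive change is the tail of generate_with_references: the default generate_fn moves from generate_together to generate_together_stream, and the result is normalized before returning. Anything exposing a .choices attribute passes through untouched, while a plain-string result is wrapped into a single stream-shaped chunk so downstream consumers can keep indexing choices[0]["delta"]["content"]. A small demonstration of the wrapper's effect, using the same shape as the diff's else-branch:

def normalize(response):
    # Same logic as the new tail of generate_with_references.
    if hasattr(response, "choices"):
        return response
    return [{"choices": [{"delta": {"content": response}}]}]

chunks = normalize("final aggregated answer")
assert chunks[0]["choices"][0]["delta"]["content"] == "final aggregated answer"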