Skip to content

Commit 60d558e

Browse files
authored
Update utils.py
1 parent 97dbdf9 commit 60d558e

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

utils.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212

1313
DEBUG = int(os.environ.get("DEBUG", "0"))
1414

15-
1615
def generate_together(
17-
model,
18-
messages,
19-
max_tokens=2048,
20-
temperature=0.7,
21-
streaming=False,
16+
model,
17+
messages,
18+
max_tokens=2048,
19+
temperature=0.7,
20+
streaming=False,
2221
):
2322

2423
output = None
@@ -77,10 +76,10 @@ def generate_together(
7776

7877

7978
def generate_together_stream(
80-
model,
81-
messages,
82-
max_tokens=2048,
83-
temperature=0.7,
79+
model,
80+
messages,
81+
max_tokens=2048,
82+
temperature=0.7,
8483
):
8584
endpoint = "https://api.groq.com/openai/v1/"
8685
client = openai.OpenAI(
@@ -99,10 +98,10 @@ def generate_together_stream(
9998

10099

101100
def generate_openai(
102-
model,
103-
messages,
104-
max_tokens=2048,
105-
temperature=0.7,
101+
model,
102+
messages,
103+
max_tokens=2048,
104+
temperature=0.7,
106105
):
107106

108107
client = openai.OpenAI(
@@ -137,8 +136,8 @@ def generate_openai(
137136

138137

139138
def inject_references_to_messages(
140-
messages,
141-
references,
139+
messages,
140+
references,
142141
):
143142

144143
messages = copy.deepcopy(messages)
@@ -163,21 +162,27 @@ def inject_references_to_messages(
163162

164163

165164
def generate_with_references(
    model,
    messages,
    references=None,
    max_tokens=2048,
    temperature=0.7,
    generate_fn=generate_together_stream,
):
    """Generate a chat completion, optionally injecting reference answers.

    When *references* is non-empty, the prior answers are folded into the
    message list via ``inject_references_to_messages`` before calling the
    backend *generate_fn*.

    Parameters
    ----------
    model :
        Model identifier, forwarded to ``generate_fn``.
    messages :
        Chat messages (OpenAI-style list of dicts), forwarded to
        ``generate_fn`` after optional reference injection.
    references :
        Optional list of prior model answers to inject. ``None`` (the
        default) is treated as "no references".
    max_tokens :
        Completion length limit, forwarded to ``generate_fn``.
    temperature :
        Sampling temperature, forwarded to ``generate_fn``.
    generate_fn :
        Backend callable producing the completion; defaults to the
        streaming backend defined elsewhere in this module.

    Returns
    -------
    The backend response unchanged when it already exposes a ``choices``
    attribute; otherwise the raw output wrapped in a single
    streaming-style chunk (``[{"choices": [{"delta": {"content": ...}}]}]``).
    """
    # Fix: the original used a mutable default (`references=[]`), which is
    # a single shared list evaluated once at definition time. Use None as
    # the sentinel and create a fresh list per call.
    if references is None:
        references = []

    if len(references) > 0:
        messages = inject_references_to_messages(messages, references)

    # Generate response using the provided generate function.
    response = generate_fn(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # Normalize the result: pass through objects already shaped like an
    # OpenAI response; otherwise wrap plain output as one stream chunk so
    # callers can iterate it uniformly.
    if hasattr(response, "choices"):
        return response
    return [{"choices": [{"delta": {"content": response}}]}]
188+

0 commit comments

Comments
 (0)