Skip to content

Commit 9c455b2

Browse files
committed
acting on GV's comment; adding api_stats in one place; updating import
1 parent a6c6dca commit 9c455b2

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

src/palimpzest/elements/pzlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .elements import Field
1+
from palimpzest.elements import Field
22

33
class ListField(Field, list):
44
"""A field representing a list of elements of specified types, with full list functionality."""

src/palimpzest/generators/generators.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,24 @@ def _get_model(self) -> dsp.LM:
7777

7878
return model
7979

80-
def _get_attn(dspy_lm: dsp.LM):
80+
def _get_usage_and_finish_reason(self, dspy_lm: dsp.LM):
81+
"""
82+
Parse and return the usage statistics and finish reason.
83+
"""
84+
usage, finish_reason = None, None
85+
if self.model_name in [Model.GPT_3_5.value, Model.GPT_4.value]:
86+
usage = dspy_lm.history[-1]['response']['usage']
87+
finish_reason = dspy_lm.history[-1]['response']['choices'][-1]['finish_reason']
88+
elif self.model_name in [Model.GEMINI_1.value]:
89+
usage = {}
90+
finish_reason = dspy_lm.history[-1]['response'][0]._result.candidates[0].finish_reason
91+
elif self.model_name in [Model.MIXTRAL.value]:
92+
usage = dspy_lm.history[-1]['response']['usage']
93+
finish_reason = dspy_lm.history[-1]['response']['finish_reason']
94+
95+
return usage, finish_reason
96+
97+
def _get_attn(self, dspy_lm: dsp.LM):
8198
"""
8299
TODO
83100
"""
@@ -147,16 +164,7 @@ def generate(self, context: str, question: str) -> GenerationOutput:
147164

148165
# extract the log probabilities for the actual result(s) which are returned
149166
answer_log_probs = self._get_answer_log_probs(dspy_lm, pred.answer)
150-
# TODO refactor this out in a method like in ImageTextGenerator?
151-
if self.model_name in [Model.GPT_3_5.value, Model.GPT_4.value]:
152-
usage = dspy_lm.history[-1]['response']['usage']
153-
finish_reason = dspy_lm.history[-1]['response']['choices'][-1]['finish_reason']
154-
elif self.model_name in [Model.GEMINI_1.value]:
155-
usage = {}
156-
finish_reason = dspy_lm.history[-1]['response'][0]._result.candidates[0].finish_reason
157-
elif self.model_name in [Model.MIXTRAL.value]:
158-
usage = dspy_lm.history[-1]['response']['usage']
159-
finish_reason = dspy_lm.history[-1]['response']['finish_reason']
167+
usage, finish_reason = self._get_usage_and_finish_reason(dspy_lm)
160168

161169
# collect statistics on prompt, usage, and timing
162170
stats = GenerationStats(

src/palimpzest/solver/solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _fileToXLS(candidate: DataRecord):
178178
dr.sheet_names = xls.sheet_names
179179

180180
if shouldProfile:
181-
candidate._stats[td.op_id] = InduceNonLLMStats()
181+
candidate._stats[td.op_id] = InduceNonLLMStats(api_stats=api_stats)
182182
return [dr]
183183
return _fileToXLS
184184

0 commit comments

Comments (0)