@@ -77,7 +77,24 @@ def _get_model(self) -> dsp.LM:
77
77
78
78
return model
79
79
80
- def _get_attn (dspy_lm : dsp .LM ):
80
+ def _get_usage_and_finish_reason (self , dspy_lm : dsp .LM ):
81
+ """
82
+ Parse and return the usage statistics and finish reason.
83
+ """
84
+ usage , finish_reason = None , None
85
+ if self .model_name in [Model .GPT_3_5 .value , Model .GPT_4 .value ]:
86
+ usage = dspy_lm .history [- 1 ]['response' ]['usage' ]
87
+ finish_reason = dspy_lm .history [- 1 ]['response' ]['choices' ][- 1 ]['finish_reason' ]
88
+ elif self .model_name in [Model .GEMINI_1 .value ]:
89
+ usage = {}
90
+ finish_reason = dspy_lm .history [- 1 ]['response' ][0 ]._result .candidates [0 ].finish_reason
91
+ elif self .model_name in [Model .MIXTRAL .value ]:
92
+ usage = dspy_lm .history [- 1 ]['response' ]['usage' ]
93
+ finish_reason = dspy_lm .history [- 1 ]['response' ]['finish_reason' ]
94
+
95
+ return usage , finish_reason
96
+
97
+ def _get_attn (self , dspy_lm : dsp .LM ):
81
98
"""
82
99
TODO
83
100
"""
@@ -147,16 +164,7 @@ def generate(self, context: str, question: str) -> GenerationOutput:
147
164
148
165
# extract the log probabilities for the actual result(s) which are returned
149
166
answer_log_probs = self ._get_answer_log_probs (dspy_lm , pred .answer )
150
- # TODO refactor this out in a method like in ImageTextGenerator?
151
- if self .model_name in [Model .GPT_3_5 .value , Model .GPT_4 .value ]:
152
- usage = dspy_lm .history [- 1 ]['response' ]['usage' ]
153
- finish_reason = dspy_lm .history [- 1 ]['response' ]['choices' ][- 1 ]['finish_reason' ]
154
- elif self .model_name in [Model .GEMINI_1 .value ]:
155
- usage = {}
156
- finish_reason = dspy_lm .history [- 1 ]['response' ][0 ]._result .candidates [0 ].finish_reason
157
- elif self .model_name in [Model .MIXTRAL .value ]:
158
- usage = dspy_lm .history [- 1 ]['response' ]['usage' ]
159
- finish_reason = dspy_lm .history [- 1 ]['response' ]['finish_reason' ]
167
+ usage , finish_reason = self ._get_usage_and_finish_reason (dspy_lm )
160
168
161
169
# collect statistics on prompt, usage, and timing
162
170
stats = GenerationStats (
0 commit comments