
Commit bde9b4d

frreiss and jakelorocco authored
fix: minor updates to answer relevance (#245)
* Minor updates to answer relevance

Signed-off-by: Fred Reiss <[email protected]>

* fix: add warnings for model opts for hf intrinsics

---------

Signed-off-by: Fred Reiss <[email protected]>
Co-authored-by: Jake LoRocco <[email protected]>
1 parent e70d307 commit bde9b4d

File tree

mellea/backends/huggingface.py
mellea/backends/openai.py
mellea/stdlib/intrinsics/rag.py
test/stdlib_intrinsics/test_rag/test_rag.py

4 files changed: +39 −2 lines changed

mellea/backends/huggingface.py

Lines changed: 11 additions & 0 deletions
@@ -251,6 +251,11 @@ async def _generate_from_intrinsic(
         if not ctx.is_chat_context:
             raise Exception("Does not yet support non-chat contexts.")
 
+        if len(model_options.items()) > 0:
+            FancyLogger.get_logger().info(
+                "passing in model options when generating with an adapter; some model options may be overwritten / ignored"
+            )
+
         linearized_ctx = ctx.view_for_generation()
         assert linearized_ctx is not None, (
             "If ctx.is_chat_context, then the context should be linearizable."

@@ -311,6 +316,12 @@ async def _generate_from_intrinsic(
             "messages": conversation,
             "extra_body": {"documents": docs},
         }
+
+        # Convert other parameters from Mellea proprietary format to standard format.
+        for model_option in model_options:
+            if model_option == ModelOption.TEMPERATURE:
+                request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)
 
         # TODO: Handle caching here. granite_common doesn't tell us what changed,
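The hunk above inlines the option conversion directly in `_generate_from_intrinsic`. A minimal standalone sketch of the same logic, for illustration only (the helper name `apply_model_options` is hypothetical; the `ModelOption` import matches the rag.py hunk below):

    from mellea.backends.types import ModelOption

    def apply_model_options(request_json: dict, model_options: dict) -> dict:
        # Mirror of the inline loop above: only recognized Mellea options are
        # mapped onto standard request fields; anything else is ignored, which
        # is why the backend now logs an info-level message warning that some
        # model options may be overwritten or ignored.
        for model_option in model_options:
            if model_option == ModelOption.TEMPERATURE:
                request_json["temperature"] = model_options[model_option]
        return request_json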

mellea/backends/openai.py

Lines changed: 6 additions & 0 deletions
@@ -435,6 +435,12 @@ async def _generate_from_intrinsic(
             "extra_body": {"documents": docs},
         }
 
+        # Convert other parameters from Mellea proprietary format to standard format.
+        if model_options is not None:
+            for model_option in model_options:
+                if model_option == ModelOption.TEMPERATURE:
+                    request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)
 
         self.load_adapter(adapter.qualified_name)
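One difference between the two backends is worth noting: the OpenAI variant guards against `model_options` being `None` before iterating, while the HuggingFace variant calls `model_options.items()` unconditionally. A trivial illustration of why the guard matters:

    model_options = None

    # Without the guard, `for model_option in model_options:` would raise
    # TypeError: 'NoneType' object is not iterable.
    if model_options is not None:
        for model_option in model_options:
            pass  # safe: this body never runs when model_options is None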

mellea/stdlib/intrinsics/rag.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,7 @@
     AdapterType,
     GraniteCommonAdapter,
 )
+from mellea.backends.types import ModelOption
 from mellea.stdlib.base import ChatContext, Document
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic

@@ -63,6 +64,7 @@ def _call_intrinsic(
         intrinsic,
         context,
         backend,
+        model_options={ModelOption.TEMPERATURE: 0.0},
         # No rejection sampling, please
         strategy=None,
     )

@@ -277,7 +279,7 @@ def rewrite_answer_for_relevance(
         backend,
         kwargs={
             "answer_relevance_category": result_json["answer_relevance_category"],
-            "answer_relevance_analysis": result_json["answer_relevance_category"],
+            "answer_relevance_analysis": result_json["answer_relevance_analysis"],
             "correction_method": correction_method,
         },
     )
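The last hunk corrects a copy-paste bug: the rewriter intrinsic had been receiving the classifier's category string under the analysis key. A sketch with a hypothetical classifier result (the field values here are invented for illustration; only the key names come from the diff and the test below):

    result_json = {
        "answer_relevance_category": "irrelevant",  # hypothetical value
        "answer_relevance_analysis": "Names attendees not in the documents.",  # hypothetical value
    }

    # Before the fix, the analysis kwarg carried the category string:
    buggy = {"answer_relevance_analysis": result_json["answer_relevance_category"]}

    # After the fix, the analysis text is forwarded intact:
    fixed = {"answer_relevance_analysis": result_json["answer_relevance_analysis"]}

The new `model_options={ModelOption.TEMPERATURE: 0.0}` argument in `_call_intrinsic`, together with the backend changes above, pins intrinsic calls to greedy decoding, presumably so the classifier outputs stay deterministic across runs.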

test/stdlib_intrinsics/test_rag/test_rag.py

Lines changed: 19 additions & 1 deletion
@@ -161,7 +161,13 @@ def test_hallucination_detection(backend):
 def test_answer_relevance(backend):
     """Verify that the answer relevance composite intrinsic functions properly."""
     context, answer, docs = _read_input_json("answer_relevance.json")
-    expected_rewrite = "Alice, Bob, and Carol attended the meeting."
+
+    # Note that this is not the optimal answer. This test is currently using an
+    # outdated LoRA adapter. Releases of new adapters will come after the Mellea
+    # integration has stabilized.
+    expected_rewrite = (
+        "The documents do not provide information about the attendees of the meeting."
+    )
 
     # First call triggers adapter loading
     result = rag.rewrite_answer_for_relevance(answer, docs, context, backend)

@@ -178,5 +184,17 @@ def test_answer_relevance(backend):
     assert result == answer
 
 
+def test_answer_relevance_classifier(backend):
+    """Verify that the first phase of the answer relevance flow behaves as expected."""
+    context, answer, docs = _read_input_json("answer_relevance.json")
+
+    result_json = rag._call_intrinsic(
+        "answer_relevance_classifier",
+        context.add(Message("assistant", answer, documents=list(docs))),
+        backend,
+    )
+    assert result_json["answer_relevance_likelihood"] == 0.0
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
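As a usage note, the new classifier test can be run in isolation with pytest's `-k` filter, using the same entry point the file already provides (this assumes the `backend` fixture is configured in the test environment):

    import pytest

    # Select only the new test by name.
    pytest.main([
        "test/stdlib_intrinsics/test_rag/test_rag.py",
        "-k", "test_answer_relevance_classifier",
    ])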
