11 changes: 11 additions & 0 deletions mellea/backends/huggingface.py
@@ -251,6 +251,11 @@ async def _generate_from_intrinsic(
        if not ctx.is_chat_context:
            raise Exception("Does not yet support non-chat contexts.")

        if len(model_options) > 0:
            FancyLogger.get_logger().info(
                "passing in model options when generating with an adapter; some model options may be overwritten / ignored"
            )

        linearized_ctx = ctx.view_for_generation()
        assert linearized_ctx is not None, (
            "If ctx.is_chat_context, then the context should be linearizable."
@@ -311,6 +316,12 @@ async def _generate_from_intrinsic(
"messages": conversation,
"extra_body": {"documents": docs},
}

# Convert other parameters from Mellea proprietary format to standard format.
for model_option in model_options:
if model_option == ModelOption.TEMPERATURE:
request_json["temperature"] = model_options[model_option]

rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

# TODO: Handle caching here. granite_common doesn't tell us what changed,
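Both backend diffs repeat the same translation step, so the pattern generalizes. A minimal sketch of that conversion, assuming only the `ModelOption` import shown in this PR; the helper name is hypothetical, and stdlib `logging` stands in for mellea's `FancyLogger`:

```python
import logging

from mellea.backends.types import ModelOption


def apply_standard_options(model_options: dict | None, request_json: dict) -> None:
    """Copy supported Mellea model options into an OpenAI-style request body.

    Hypothetical helper mirroring the loops added in huggingface.py and
    openai.py; unsupported options are silently skipped.
    """
    if not model_options:
        return
    logging.getLogger(__name__).info(
        "passing in model options when generating with an adapter; "
        "some model options may be overwritten / ignored"
    )
    if ModelOption.TEMPERATURE in model_options:
        request_json["temperature"] = model_options[ModelOption.TEMPERATURE]
```

Note that openai.py additionally guards against `model_options` being `None`; folding that check into a shared helper like the one above would keep the two backends from drifting.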
6 changes: 6 additions & 0 deletions mellea/backends/openai.py
@@ -435,6 +435,12 @@ async def _generate_from_intrinsic(
"extra_body": {"documents": docs},
}

# Convert other parameters from Mellea proprietary format to standard format.
if model_options is not None:
for model_option in model_options:
if model_option == ModelOption.TEMPERATURE:
request_json["temperature"] = model_options[model_option]

rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

self.load_adapter(adapter.qualified_name)
4 changes: 3 additions & 1 deletion mellea/stdlib/intrinsics/rag.py
@@ -9,6 +9,7 @@
    AdapterType,
    GraniteCommonAdapter,
)
from mellea.backends.types import ModelOption
from mellea.stdlib.base import ChatContext, Document
from mellea.stdlib.chat import Message
from mellea.stdlib.intrinsics.intrinsic import Intrinsic
@@ -63,6 +64,7 @@ def _call_intrinsic(
        intrinsic,
        context,
        backend,
        model_options={ModelOption.TEMPERATURE: 0.0},
        # No rejection sampling, please
        strategy=None,
    )
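Pinning `ModelOption.TEMPERATURE` to 0.0 makes every intrinsic call greedy, so structured outputs such as `answer_relevance_likelihood` are reproducible across runs. With the backend changes above, the option lands in the request body roughly as follows (an illustrative sketch; only the fields visible in this diff are shown, and the message and document contents are placeholders):

```python
# Effective request body assembled by the backend for an intrinsic call.
request_json = {
    "messages": [{"role": "assistant", "content": "..."}],  # linearized chat context
    "extra_body": {"documents": [{"text": "..."}]},         # RAG documents
    "temperature": 0.0,  # injected from ModelOption.TEMPERATURE by the new loop
}
```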
@@ -277,7 +279,7 @@ def rewrite_answer_for_relevance(
        backend,
        kwargs={
            "answer_relevance_category": result_json["answer_relevance_category"],
            "answer_relevance_analysis": result_json["answer_relevance_category"],
            "answer_relevance_analysis": result_json["answer_relevance_analysis"],
            "correction_method": correction_method,
        },
    )
20 changes: 19 additions & 1 deletion test/stdlib_intrinsics/test_rag/test_rag.py
@@ -161,7 +161,13 @@ def test_hallucination_detection(backend):
def test_answer_relevance(backend):
"""Verify that the answer relevance composite intrinsic functions properly."""
context, answer, docs = _read_input_json("answer_relevance.json")
expected_rewrite = "Alice, Bob, and Carol attended the meeting."

# Note that this is not the optimal answer. This test is currently using an
# outdated LoRA adapter. Releases of new adapters will come after the Mellea
# integration has stabilized.
expected_rewrite = (
"The documents do not provide information about the attendees of the meeting."
)

# First call triggers adapter loading
result = rag.rewrite_answer_for_relevance(answer, docs, context, backend)
@@ -178,5 +184,17 @@ def test_answer_relevance(backend):
    assert result == answer


def test_answer_relevance_classifier(backend):
    """Verify that the first phase of the answer relevance flow behaves as expected."""
    context, answer, docs = _read_input_json("answer_relevance.json")

    result_json = rag._call_intrinsic(
        "answer_relevance_classifier",
        context.add(Message("assistant", answer, documents=list(docs))),
        backend,
    )
    assert result_json["answer_relevance_likelihood"] == 0.0


if __name__ == "__main__":
    pytest.main([__file__])
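While iterating on the adapter, the new classifier test can be run in isolation with pytest's `-k` filter, e.g.:

```python
# Run only the new classifier test from this module.
pytest.main([__file__, "-k", "test_answer_relevance_classifier"])
```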