Commit 67e19a3

Merge branch 'main' into jal/small-fixes
2 parents: fcb5137 + 633bfd7

16 files changed: +163 -60 lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
@@ -7,11 +7,11 @@ repos:
       - id: ruff-format
         name: "Ruff formatter"
         args: [--config=pyproject.toml]
-        files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
       - id: ruff
         name: "Ruff linter"
         args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
-        files: '^(mellea|tests).*\.(py|ipynb)$'
+        files: '^(mellea).*\.(py|ipynb)$'

   - repo: local
     hooks:
@@ -20,7 +20,7 @@ repos:
         entry: uv run --no-sync mypy mellea
         pass_filenames: false
         language: system
-        files: '\.py$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'

   - repo: https://github.com/astral-sh/uv-pre-commit
     rev: 0.7.8
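
As a quick sanity check (not part of the commit), the updated file-scoping regexes can be exercised with Python's re module; the paths below are purely illustrative:

import re

# Patterns copied from the updated .pre-commit-config.yaml above.
RUFF_FORMAT_FILES = r"^(mellea|test|cli|docs).*\.(py|ipynb)$"
RUFF_LINT_FILES = r"^(mellea).*\.(py|ipynb)$"

# Illustrative repo-relative paths.
paths = [
    "mellea/backends/openai.py",
    "test/backends/test_huggingface.py",
    "docs/examples/notebook.ipynb",
    "scripts/helper.py",
]

for path in paths:
    fmt = bool(re.match(RUFF_FORMAT_FILES, path))
    lint = bool(re.match(RUFF_LINT_FILES, path))
    print(f"{path}: ruff-format={fmt}, ruff={lint}")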

mellea/backends/adapters/adapter.py

Lines changed: 1 addition & 2 deletions
@@ -2,11 +2,10 @@

 import abc
 import pathlib
-from typing import Any, TypeVar
+from typing import TypeVar

 import granite_common.intrinsics
 import yaml
-from litellm import cast

 from mellea.backends import Backend
 from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata

mellea/backends/huggingface.py

Lines changed: 11 additions & 0 deletions
@@ -251,6 +251,11 @@ async def _generate_from_intrinsic(
         if not ctx.is_chat_context:
             raise Exception("Does not yet support non-chat contexts.")

+        if len(model_options.items()) > 0:
+            FancyLogger.get_logger().info(
+                "passing in model options when generating with an adapter; some model options may be overwritten / ignored"
+            )
+
         linearized_ctx = ctx.view_for_generation()
         assert linearized_ctx is not None, (
             "If ctx.is_chat_context, then the context should be linearizable."
@@ -311,6 +316,12 @@ async def _generate_from_intrinsic(
             "messages": conversation,
             "extra_body": {"documents": docs},
         }
+
+        # Convert other parameters from Mellea proprietary format to standard format.
+        for model_option in model_options:
+            if model_option == ModelOption.TEMPERATURE:
+                request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

         # TODO: Handle caching here. granite_common doesn't tell us what changed,

mellea/backends/openai.py

Lines changed: 6 additions & 0 deletions
@@ -435,6 +435,12 @@ async def _generate_from_intrinsic(
             "extra_body": {"documents": docs},
         }

+        # Convert other parameters from Mellea proprietary format to standard format.
+        if model_options is not None:
+            for model_option in model_options:
+                if model_option == ModelOption.TEMPERATURE:
+                    request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

         self.load_adapter(adapter.qualified_name)
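
Both backend changes follow the same pattern: Mellea's proprietary ModelOption keys are mapped onto standard request fields before the request is handed to the rewriter. A minimal, standalone sketch of that mapping follows; the dict literals are illustrative stand-ins, not the backends' actual state:

from mellea.backends.types import ModelOption

# Illustrative inputs; in the backends these come from the caller and the chat context.
model_options = {ModelOption.TEMPERATURE: 0.0}
request_json = {"messages": [], "extra_body": {"documents": []}}

# Same conversion loop as added in huggingface.py / openai.py above.
if model_options is not None:
    for model_option in model_options:
        if model_option == ModelOption.TEMPERATURE:
            request_json["temperature"] = model_options[model_option]

print(request_json)
# {'messages': [], 'extra_body': {'documents': []}, 'temperature': 0.0}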

mellea/stdlib/intrinsics/rag.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,7 @@
     AdapterType,
     GraniteCommonAdapter,
 )
+from mellea.backends.types import ModelOption
 from mellea.stdlib.base import ChatContext, Document
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic
@@ -63,6 +64,7 @@ def _call_intrinsic(
         intrinsic,
         context,
         backend,
+        model_options={ModelOption.TEMPERATURE: 0.0},
         # No rejection sampling, please
         strategy=None,
     )
@@ -277,7 +279,7 @@ def rewrite_answer_for_relevance(
         backend,
         kwargs={
             "answer_relevance_category": result_json["answer_relevance_category"],
-            "answer_relevance_analysis": result_json["answer_relevance_category"],
+            "answer_relevance_analysis": result_json["answer_relevance_analysis"],
             "correction_method": correction_method,
         },
     )
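
The kwargs change in rewrite_answer_for_relevance fixes a copy/paste bug: the analysis field was previously populated from the category key. A toy illustration of the difference, using made-up result_json values:

# Made-up example output from the answer-relevance intrinsic.
result_json = {
    "answer_relevance_category": "partially_relevant",
    "answer_relevance_analysis": "The answer addresses only one of the two questions asked.",
}

# Before: both kwargs carried the category value (the bug).
before = {
    "answer_relevance_category": result_json["answer_relevance_category"],
    "answer_relevance_analysis": result_json["answer_relevance_category"],
}

# After: the analysis field carries the analysis text.
after = {
    "answer_relevance_category": result_json["answer_relevance_category"],
    "answer_relevance_analysis": result_json["answer_relevance_analysis"],
}

assert before["answer_relevance_analysis"] == "partially_relevant"
assert after["answer_relevance_analysis"].startswith("The answer addresses")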

test/backends/test_huggingface.py

Lines changed: 15 additions & 5 deletions
@@ -10,11 +10,20 @@
 from mellea.backends.formatter import TemplateFormatter
 from mellea.backends.huggingface import LocalHFBackend
 from mellea.backends.types import ModelOption
-from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk,
-                                SimpleContext)
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, ValidationResult,
-                                       default_output_to_bool)
+from mellea.stdlib.base import (
+    CBlock,
+    ChatContext,
+    Context,
+    ModelOutputThunk,
+    SimpleContext,
+)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)


 @pytest.fixture(scope="module")
@@ -40,6 +49,7 @@ def session(backend):
     yield session
     session.reset()

+
 @pytest.mark.qualitative
 def test_adapters(backend):
     assert len(backend._added_adapters.items()) > 0

test/backends/test_litellm_ollama.py

Lines changed: 5 additions & 4 deletions
@@ -26,9 +26,7 @@ def backend(gh_run: int):
         url = url.replace("127.0.0.1", "http://localhost")

         return LiteLLMBackend(
-            model_id=_MODEL_ID,
-            base_url=url,
-            model_options={"api_base": url},
+            model_id=_MODEL_ID, base_url=url, model_options={"api_base": url}
         )
     else:
         return LiteLLMBackend(model_id=_MODEL_ID)
@@ -138,6 +136,7 @@ def is_happy(text: str) -> bool:
     # should yield to true - but, of course, is model dependent
     assert h is True

+
 async def test_generate_from_raw(session):
     prompts = [
         "what is 1+1?",
@@ -151,7 +150,9 @@ async def test_generate_from_raw(session):
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )

-    assert len(results) == 1, "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    assert len(results) == 1, (
+        "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    )
     assert results[0].value is not None


test/backends/test_litellm_watsonx.py

Lines changed: 4 additions & 7 deletions
@@ -41,18 +41,15 @@ def test_multiple_sync_funcs(session):

 @pytest.mark.qualitative
 async def test_generate_from_raw(session):
-    prompts = [
-        "what is 1+1?",
-        "what is 2+2?",
-        "what is 3+3?",
-        "what is 4+2+2?",
-    ]
+    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+2+2?"]

     results = await session.backend.generate_from_raw(
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )

-    assert len(results) == 1, "litellm converts a batch request for watsonx into a single message"
+    assert len(results) == 1, (
+        "litellm converts a batch request for watsonx into a single message"
+    )
     assert results[0].value is not None

test/backends/test_openai_ollama.py

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ async def test_generate_from_raw(m_session):
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx
     )

+
 # Default OpenAI implementation doesn't support structured outputs for the completions API.
 # def test_generate_from_raw_with_format(self):
 #     prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

test/backends/test_openai_vllm/test_openai_vllm.py

Lines changed: 11 additions & 4 deletions
@@ -11,8 +11,12 @@
 from mellea.backends.openai import OpenAIBackend
 from mellea.backends.types import ModelOption, _ServerType
 from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, req)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    req,
+)

 # The vllm tests are disabled by default, because we need a test environment with the vLLM server running.
 # We use an env var VLLM_TESTS_ENABLED to enable these tests.
@@ -138,8 +142,11 @@ class TestOpenAIALoraStuff:
             base_url="http://localhost:8000/v1",
             api_key="EMPTY",
         )
-        backend.add_adapter(GraniteCommonAdapter("requirement_check",
-                                                 base_model_name=backend.base_model_name))
+        backend.add_adapter(
+            GraniteCommonAdapter(
+                "requirement_check", base_model_name=backend.base_model_name
+            )
+        )

         m = MelleaSession(backend, ctx=ChatContext())
