
Commit e8f1d62

Merge branch 'main' into jal/simple-hf-lock
2 parents: a50fd16 + 2120112
File tree

19 files changed: +177 / -84 lines


.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions

@@ -7,11 +7,11 @@ repos:
       - id: ruff-format
         name: "Ruff formatter"
         args: [--config=pyproject.toml]
-        files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
       - id: ruff
         name: "Ruff linter"
         args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
-        files: '^(mellea|tests).*\.(py|ipynb)$'
+        files: '^(mellea).*\.(py|ipynb)$'

   - repo: local
     hooks:
@@ -20,7 +20,7 @@ repos:
         entry: uv run --no-sync mypy mellea
         pass_filenames: false
         language: system
-        files: '\.py$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'

   - repo: https://github.com/astral-sh/uv-pre-commit
     rev: 0.7.8
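
As a quick illustration of what the updated `files` pattern matches, here is a small standalone sketch using Python's `re` module. The candidate paths are illustrative examples only, chosen to show which top-level directories pass the filter under the pattern taken from this diff.

    import re

    # Pattern taken from the updated pre-commit config above.
    pattern = re.compile(r"^(mellea|test|cli|docs).*\.(py|ipynb)$")

    candidates = [
        "mellea/backends/openai.py",          # matches: starts with 'mellea', ends in .py
        "test/backends/test_huggingface.py",  # matches: prefix is now 'test', not 'tests'
        "docs/examples/demo.ipynb",           # matches: notebooks are included
        "scripts/build.py",                   # no match: prefix not in the alternation
    ]

    for path in candidates:
        print(path, bool(pattern.match(path)))

Note that the ruff linter hook is now scoped to `mellea` only, while the formatter and mypy hooks use the broader `mellea|test|cli|docs` prefix list.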

mellea/backends/adapters/adapter.py

Lines changed: 1 addition & 2 deletions

@@ -2,11 +2,10 @@

 import abc
 import pathlib
-from typing import Any, TypeVar
+from typing import TypeVar

 import granite_common.intrinsics
 import yaml
-from litellm import cast

 from mellea.backends import Backend
 from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata

mellea/backends/huggingface.py

Lines changed: 11 additions & 0 deletions

@@ -281,6 +281,11 @@ async def _generate_from_intrinsic(
         if not ctx.is_chat_context:
             raise Exception("Does not yet support non-chat contexts.")

+        if len(model_options.items()) > 0:
+            FancyLogger.get_logger().info(
+                "passing in model options when generating with an adapter; some model options may be overwritten / ignored"
+            )
+
         linearized_ctx = ctx.view_for_generation()
         assert linearized_ctx is not None, (
             "If ctx.is_chat_context, then the context should be linearizable."
@@ -341,6 +346,12 @@ async def _generate_from_intrinsic(
             "messages": conversation,
             "extra_body": {"documents": docs},
         }
+
+        # Convert other parameters from Mellea proprietary format to standard format.
+        for model_option in model_options:
+            if model_option == ModelOption.TEMPERATURE:
+                request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

         # TODO: Handle caching here. granite_common doesn't tell us what changed,
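
To see the new conversion step in isolation, here is a minimal sketch of the same loop with a stand-in options dict and request payload; only the `ModelOption.TEMPERATURE` handling mirrors the diff, and the payload contents are placeholders.

    from mellea.backends.types import ModelOption

    # Hypothetical inputs: options supplied by the caller and the request body
    # that will later be handed to the rewriter.
    model_options = {ModelOption.TEMPERATURE: 0.0, "unknown_option": "left alone"}
    request_json = {"model": "placeholder-model", "messages": []}

    # Copy only the options that have a known standard equivalent; anything else
    # may be overwritten or ignored, per the info message added above.
    for model_option in model_options:
        if model_option == ModelOption.TEMPERATURE:
            request_json["temperature"] = model_options[model_option]

    print(request_json)  # {'model': 'placeholder-model', 'messages': [], 'temperature': 0.0}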

mellea/backends/litellm.py

Lines changed: 2 additions & 2 deletions

@@ -54,12 +54,12 @@ def __init__(
         base_url: str | None = "http://localhost:11434",
         model_options: dict | None = None,
     ):
-        """Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.
+        """Initialize an OpenAI compatible backend using the [LiteLLM Python SDK](https://docs.litellm.ai/docs/#litellm-python-sdk).

         Note: If getting `Unclosed client session`, set `export DISABLE_AIOHTTP_TRANSPORT=True` in your environment. See: https://github.com/BerriAI/litellm/issues/13251.

         Args:
-            model_id : The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables.
+            model_id : The LiteLLM model identifier; in most cases requires some combination of `<provider>/<model_creator>/<model_name>`. Make sure that all necessary credentials are in OS environment variables.
             formatter: A custom formatter based on backend.If None, defaults to TemplateFormatter
             base_url : Base url for LLM API. Defaults to None.
             model_options : Generation options to pass to the LLM. Defaults to None.
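
For reference, a usage sketch of the constructor this docstring documents. The model id and endpoint below are placeholders (an Ollama-served Granite model is assumed); substitute your own provider-prefixed LiteLLM identifier and make sure any credentials LiteLLM needs are exported as environment variables.

    from mellea.backends.litellm import LiteLLMBackend

    # Assumed example values; the keyword arguments mirror how the backend is
    # constructed in test/backends/test_litellm_ollama.py in this same commit.
    backend = LiteLLMBackend(
        model_id="ollama/granite3.3:8b",                      # provider-prefixed LiteLLM model id (placeholder)
        base_url="http://localhost:11434",                    # local Ollama endpoint (placeholder)
        model_options={"api_base": "http://localhost:11434"},
    )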

mellea/backends/openai.py

Lines changed: 6 additions & 0 deletions

@@ -435,6 +435,12 @@ async def _generate_from_intrinsic(
             "extra_body": {"documents": docs},
         }

+        # Convert other parameters from Mellea proprietary format to standard format.
+        if model_options is not None:
+            for model_option in model_options:
+                if model_option == ModelOption.TEMPERATURE:
+                    request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

         self.load_adapter(adapter.qualified_name)

mellea/stdlib/intrinsics/rag.py

Lines changed: 3 additions & 1 deletion

@@ -9,6 +9,7 @@
     AdapterType,
     GraniteCommonAdapter,
 )
+from mellea.backends.types import ModelOption
 from mellea.stdlib.base import ChatContext, Document
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic
@@ -63,6 +64,7 @@ def _call_intrinsic(
         intrinsic,
         context,
         backend,
+        model_options={ModelOption.TEMPERATURE: 0.0},
         # No rejection sampling, please
         strategy=None,
     )
@@ -277,7 +279,7 @@ def rewrite_answer_for_relevance(
         backend,
         kwargs={
             "answer_relevance_category": result_json["answer_relevance_category"],
-            "answer_relevance_analysis": result_json["answer_relevance_category"],
+            "answer_relevance_analysis": result_json["answer_relevance_analysis"],
             "correction_method": correction_method,
         },
     )

mellea/templates/prompts/default/LLMaJRequirement.jinja2

Lines changed: 0 additions & 15 deletions
This file was deleted.

test/backends/test_huggingface.py

Lines changed: 15 additions & 5 deletions

@@ -17,13 +17,22 @@
 from mellea.backends.formatter import TemplateFormatter
 from mellea.backends.huggingface import LocalHFBackend, _assert_correct_adapters
 from mellea.backends.types import ModelOption
-from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk,
-                                SimpleContext)
+from mellea.stdlib.base import (
+    CBlock,
+    ChatContext,
+    Context,
+    ModelOutputThunk,
+    SimpleContext,
+)
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, ValidationResult,
-                                       default_output_to_bool)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)


 @pytest.fixture(scope="module")
@@ -54,6 +63,7 @@ def session(backend):
     yield session
     session.reset()

+
 @pytest.mark.qualitative
 def test_adapters(backend):
     assert len(backend._added_adapters.items()) > 0

test/backends/test_litellm_ollama.py

Lines changed: 5 additions & 10 deletions

@@ -26,9 +26,7 @@ def backend(gh_run: int):
         url = url.replace("127.0.0.1", "http://localhost")

         return LiteLLMBackend(
-            model_id=_MODEL_ID,
-            base_url=url,
-            model_options={"api_base": url},
+            model_id=_MODEL_ID, base_url=url, model_options={"api_base": url}
         )
     else:
         return LiteLLMBackend(model_id=_MODEL_ID)
@@ -111,12 +109,6 @@ def test_litellm_ollama_instruct_options(session):
         ModelOption.SEED: 123,
         ModelOption.TEMPERATURE: 0.5,
         ModelOption.MAX_NEW_TOKENS: 100,
-
-        # Ollama thinking controls currently broken on Granite; see
-        # https://github.com/ollama/ollama/issues/10983
-        # TODO: Re-enable when this upstream bug gets fixed.
-        #ModelOption.THINKING: True,
-        #"reasoning_effort": True,
         "homer_simpson": "option should be kicked out",
     }

@@ -144,6 +136,7 @@ def is_happy(text: str) -> bool:
     # should yield to true - but, of course, is model dependent
     assert h is True

+
 async def test_generate_from_raw(session):
     prompts = [
         "what is 1+1?",
@@ -157,7 +150,9 @@ async def test_generate_from_raw(session):
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )

-    assert len(results) == 1, "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    assert len(results) == 1, (
+        "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    )
     assert results[0].value is not None

test/backends/test_litellm_watsonx.py

Lines changed: 4 additions & 7 deletions

@@ -41,18 +41,15 @@ def test_multiple_sync_funcs(session):

 @pytest.mark.qualitative
 async def test_generate_from_raw(session):
-    prompts = [
-        "what is 1+1?",
-        "what is 2+2?",
-        "what is 3+3?",
-        "what is 4+2+2?",
-    ]
+    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+2+2?"]

     results = await session.backend.generate_from_raw(
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )

-    assert len(results) == 1, "litellm converts a batch request for watsonx into a single message"
+    assert len(results) == 1, (
+        "litellm converts a batch request for watsonx into a single message"
+    )
     assert results[0].value is not None
