
Commit 0102164

Merge branch 'main' into jal/simple-hf-lock

2 parents: a50fd16 + 2120112

19 files changed: +230 additions, -112 deletions

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
@@ -7,11 +7,11 @@ repos:
       - id: ruff-format
         name: "Ruff formatter"
         args: [--config=pyproject.toml]
-        files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
       - id: ruff
         name: "Ruff linter"
         args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
-        files: '^(mellea|tests).*\.(py|ipynb)$'
+        files: '^(mellea).*\.(py|ipynb)$'

   - repo: local
     hooks:
@@ -20,7 +20,7 @@ repos:
         entry: uv run --no-sync mypy mellea
         pass_filenames: false
         language: system
-        files: '\.py$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'

   - repo: https://github.com/astral-sh/uv-pre-commit
     rev: 0.7.8
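To see what the tightened `files` patterns cover, here is a small, hedged Python check. The paths are illustrative only; pre-commit applies these regexes to repository-relative paths. Note that the test directory in this commit is `test/`, not `tests/`, so the old `tests` alternative never matched it.

import re

# Patterns copied from the updated hooks above.
ruff_format = re.compile(r"^(mellea|test|cli|docs).*\.(py|ipynb)$")
ruff_lint = re.compile(r"^(mellea).*\.(py|ipynb)$")
mypy_hook = re.compile(r"^(mellea|test|cli|docs).*\.(py|ipynb)$")

# Illustrative paths; only the first is guaranteed to exist in this diff.
for path in [
    "mellea/backends/openai.py",
    "test/backends/test_huggingface.py",
    "docs/examples/demo.ipynb",
]:
    print(
        path,
        bool(ruff_format.match(path)),
        bool(ruff_lint.match(path)),
        bool(mypy_hook.match(path)),
    )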

mellea/backends/adapters/adapter.py

Lines changed: 1 addition & 2 deletions
@@ -2,11 +2,10 @@

 import abc
 import pathlib
-from typing import Any, TypeVar
+from typing import TypeVar

 import granite_common.intrinsics
 import yaml
-from litellm import cast

 from mellea.backends import Backend
 from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata

mellea/backends/huggingface.py

Lines changed: 11 additions & 0 deletions
@@ -281,6 +281,11 @@ async def _generate_from_intrinsic(
         if not ctx.is_chat_context:
             raise Exception("Does not yet support non-chat contexts.")

+        if len(model_options.items()) > 0:
+            FancyLogger.get_logger().info(
+                "passing in model options when generating with an adapter; some model options may be overwritten / ignored"
+            )
+
         linearized_ctx = ctx.view_for_generation()
         assert linearized_ctx is not None, (
             "If ctx.is_chat_context, then the context should be linearizable."
@@ -341,6 +346,12 @@ async def _generate_from_intrinsic(
             "messages": conversation,
             "extra_body": {"documents": docs},
         }
+
+        # Convert other parameters from Mellea proprietary format to standard format.
+        for model_option in model_options:
+            if model_option == ModelOption.TEMPERATURE:
+                request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

         # TODO: Handle caching here. granite_common doesn't tell us what changed,
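As a standalone illustration of the translation step above, a minimal hedged sketch; the helper name is hypothetical, since the actual change inlines this loop inside `_generate_from_intrinsic`:

from mellea.backends.types import ModelOption


def apply_model_options(request_json: dict, model_options: dict) -> dict:
    """Copy recognized Mellea model options onto an OpenAI-style request body.

    Only temperature is translated for now; other options may be overwritten or
    ignored, which is what the new FancyLogger message warns the caller about.
    """
    for model_option in model_options:
        if model_option == ModelOption.TEMPERATURE:
            request_json["temperature"] = model_options[model_option]
    return request_json


# Example: the deterministic options used by the RAG intrinsics become temperature=0.0.
request = apply_model_options(
    {"messages": [], "extra_body": {"documents": []}},
    {ModelOption.TEMPERATURE: 0.0},
)
assert request["temperature"] == 0.0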

mellea/backends/litellm.py

Lines changed: 2 additions & 2 deletions
@@ -54,12 +54,12 @@ def __init__(
         base_url: str | None = "http://localhost:11434",
         model_options: dict | None = None,
     ):
-        """Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.
+        """Initialize an OpenAI compatible backend using the [LiteLLM Python SDK](https://docs.litellm.ai/docs/#litellm-python-sdk).

         Note: If getting `Unclosed client session`, set `export DISABLE_AIOHTTP_TRANSPORT=True` in your environment. See: https://github.com/BerriAI/litellm/issues/13251.

         Args:
-            model_id : The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables.
+            model_id : The LiteLLM model identifier; in most cases requires some combination of `<provider>/<model_creator>/<model_name>`. Make sure that all necessary credentials are in OS environment variables.
             formatter: A custom formatter based on backend.If None, defaults to TemplateFormatter
             base_url : Base url for LLM API. Defaults to None.
             model_options : Generation options to pass to the LLM. Defaults to None.
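For context, a hedged usage sketch of the `<provider>/<model_creator>/<model_name>` convention described in the updated docstring. The class name `LiteLLMBackend` and the example model id are assumptions for illustration, not taken from this diff:

# Hedged sketch; class name and model id below are illustrative assumptions.
from mellea.backends.litellm import LiteLLMBackend  # assumed class name

backend = LiteLLMBackend(
    # <provider>/<model_creator>/<model_name>, credentials read from the environment.
    model_id="ollama/ibm-granite/granite-3.3-8b-instruct",
    base_url="http://localhost:11434",  # default shown in the diff
    model_options={"temperature": 0.0},
)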

mellea/backends/openai.py

Lines changed: 6 additions & 0 deletions
@@ -435,6 +435,12 @@ async def _generate_from_intrinsic(
             "extra_body": {"documents": docs},
         }

+        # Convert other parameters from Mellea proprietary format to standard format.
+        if model_options is not None:
+            for model_option in model_options:
+                if model_option == ModelOption.TEMPERATURE:
+                    request_json["temperature"] = model_options[model_option]
+
         rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs)

         self.load_adapter(adapter.qualified_name)

mellea/stdlib/intrinsics/rag.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,7 @@
     AdapterType,
     GraniteCommonAdapter,
 )
+from mellea.backends.types import ModelOption
 from mellea.stdlib.base import ChatContext, Document
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic
@@ -63,6 +64,7 @@ def _call_intrinsic(
         intrinsic,
         context,
         backend,
+        model_options={ModelOption.TEMPERATURE: 0.0},
         # No rejection sampling, please
         strategy=None,
     )
@@ -277,7 +279,7 @@ def rewrite_answer_for_relevance(
         backend,
         kwargs={
             "answer_relevance_category": result_json["answer_relevance_category"],
-            "answer_relevance_analysis": result_json["answer_relevance_category"],
+            "answer_relevance_analysis": result_json["answer_relevance_analysis"],
             "correction_method": correction_method,
         },
     )
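The pinned option is small but deliberate: intrinsic calls are generated with no rejection sampling, presumably so their JSON outputs stay reproducible. A minimal hedged sketch of what `_call_intrinsic` now forwards (the variable name is illustrative):

from mellea.backends.types import ModelOption

# What _call_intrinsic now passes through to the backend; per the backend
# changes above, this becomes request_json["temperature"] = 0.0.
deterministic_options = {ModelOption.TEMPERATURE: 0.0}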

mellea/templates/prompts/default/LLMaJRequirement.jinja2

Lines changed: 0 additions & 15 deletions
This file was deleted.

test/backends/test_huggingface.py

Lines changed: 68 additions & 33 deletions
@@ -17,13 +17,22 @@
 from mellea.backends.formatter import TemplateFormatter
 from mellea.backends.huggingface import LocalHFBackend, _assert_correct_adapters
 from mellea.backends.types import ModelOption
-from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk,
-                                SimpleContext)
+from mellea.stdlib.base import (
+    CBlock,
+    ChatContext,
+    Context,
+    ModelOutputThunk,
+    SimpleContext,
+)
 from mellea.stdlib.chat import Message
 from mellea.stdlib.intrinsics.intrinsic import Intrinsic
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, ValidationResult,
-                                       default_output_to_bool)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)


 @pytest.fixture(scope="module")
@@ -40,9 +49,7 @@ def backend():
         )
     )
     backend.add_adapter(
-        GraniteCommonAdapter(
-            "answerability", base_model_name=backend.base_model_name
-        )
+        GraniteCommonAdapter("answerability", base_model_name=backend.base_model_name)
     )
     return backend

@@ -54,6 +61,7 @@ def session(backend):
     yield session
     session.reset()

+
 @pytest.mark.qualitative
 def test_adapters(backend):
     assert len(backend._added_adapters.items()) > 0
@@ -305,6 +313,7 @@ async def test_async_avalue(session):
     assert m1_final_val is not None
     assert m1_final_val == mot1.value

+
 @pytest.mark.qualitative
 async def test_generate_with_lock(backend):
     # Enable the faulthandler for this test.
@@ -319,23 +328,20 @@ async def test_generate_with_lock(backend):
     b._added_adapters = {}
     b._loaded_adapters = {}
     b.add_adapter(
-        GraniteCommonAdapter(
-            "requirement_check", base_model_name=b.base_model_name
-        )
+        GraniteCommonAdapter("requirement_check", base_model_name=b.base_model_name)
     )
     b.add_adapter(
-        GraniteCommonAdapter(
-            "answerability", base_model_name=b.base_model_name
-        )
+        GraniteCommonAdapter("answerability", base_model_name=b.base_model_name)
     )

     memoized = dict()
     gen_func = model.generate
+
     def mock_func(input_ids, *args, **kwargs):
         """Mocks the generate function. Must call `populate_mocked_dict` with each input that must be cached before using this."""
         for key, val in memoized.items():
             if torch.equal(key, input_ids):
-                time.sleep(random.uniform(.1, .5)) # Simulate a bit of work.
+                time.sleep(random.uniform(0.1, 0.5))  # Simulate a bit of work.
                 return val
         assert False, "did not get a cached response"

@@ -347,7 +353,9 @@ def populate_mocked_dict(input_ids, *args, **kwargs):
         return output

     model.generate = Mock(side_effect=populate_mocked_dict)
-    assert not isinstance(backend._model, Mock), "mocking went wrong; backend fixture changed; other tests may fail"
+    assert not isinstance(backend._model, Mock), (
+        "mocking went wrong; backend fixture changed; other tests may fail"
+    )

     # Set up the inputs.
     ctx = ChatContext().add(Message("user", "hello"))
@@ -362,18 +370,22 @@ def call_backend_generate():
             b.generate_from_context(act, ctx),
             b.generate_from_context(req_intrinsic, ctx),
             b.generate_from_context(answerability_intrinsic, ctx),
-            b.generate_from_raw([raw_act], ctx, model_options={ModelOption.MAX_NEW_TOKENS: 3})
+            b.generate_from_raw(
+                [raw_act], ctx, model_options={ModelOption.MAX_NEW_TOKENS: 3}
+            ),
         ]

     # Call once to populate the memoized mock.
     outputs = await asyncio.gather(*call_backend_generate())
     for output in outputs:
         mot = output[0]
-        await mot.avalue() # Ensure all values are computed.
+        await mot.avalue()  # Ensure all values are computed.

     # Use the memoized mock that errors if not precomputed.
     model.generate = Mock(side_effect=mock_func)
-    count = 5 # Use a high number to try to put pressure on the lock and catch deadlocks.
+    count = (
+        5  # Use a high number to try to put pressure on the lock and catch deadlocks.
+    )
     coros: list[Coroutine[Any, Any, tuple[ModelOutputThunk, Context]]] = []
     for _ in range(count):
         coros.extend(call_backend_generate())
@@ -388,10 +400,11 @@ def call_backend_generate():

     faulthandler.disable()

+
 @pytest.mark.qualitative
 async def test_generate_with_lock_does_not_block_when_awaiting_value(backend):
-    """This is a tricky test to setup. 
-    
+    """This is a tricky test to setup.
+
     It's purpose is to ensure that a long-running generation doesn't get blocked
     when awaiting the `model_output_thunk.avalue()` of a different generation request.

@@ -417,14 +430,28 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend):
     # - a streaming generation that will take a long time to resolve.
     # - a regular generation that should be able to happen while the streaming is happening.
     # - two intrinsics that shouldn't be able to happen concurrently.
-    reg_mot_stream, _ = await backend.generate_from_context(act, ctx, model_options={ModelOption.STREAM: True, ModelOption.MAX_NEW_TOKENS: token_generation_length, "min_length": token_generation_length})
+    reg_mot_stream, _ = await backend.generate_from_context(
+        act,
+        ctx,
+        model_options={
+            ModelOption.STREAM: True,
+            ModelOption.MAX_NEW_TOKENS: token_generation_length,
+            "min_length": token_generation_length,
+        },
+    )
     reg_mot, _ = await backend.generate_from_context(act, ctx)
-    req_mot, _ = await backend.generate_from_context(req_intrinsic, ctx, model_options={ModelOption.STREAM: True})
-    answerability_mot, _ = await backend.generate_from_context(answerability_intrinsic, ctx, model_options={ModelOption.STREAM: True})
+    req_mot, _ = await backend.generate_from_context(
+        req_intrinsic, ctx, model_options={ModelOption.STREAM: True}
+    )
+    answerability_mot, _ = await backend.generate_from_context(
+        answerability_intrinsic, ctx, model_options={ModelOption.STREAM: True}
+    )

     # Ensure the stream is generating but not yet completing.
     await reg_mot_stream.astream()
-    assert not reg_mot_stream.is_computed(), "generation completed too early, see test for more details"
+    assert not reg_mot_stream.is_computed(), (
+        "generation completed too early, see test for more details"
+    )

     # Awaiting this shouldn't cause a deadlock. Add the timeout so the test can fail.
     # If the test fails, this means that the streaming generation wasn't able to complete,
@@ -442,11 +469,12 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend):
             raise e
         else:
             raise Exception("timeout ended too early, see test for more details")
-    
+
     for output in [reg_mot_stream, reg_mot, req_mot, answerability_mot]:
         if not output.is_computed():
             await output.avalue()  # Ensure everything gets computed.

+
 @pytest.mark.qualitative
 async def test_error_during_generate_with_lock(backend):
     # Create local versions of these objects so that mocking
@@ -459,20 +487,21 @@ async def test_error_during_generate_with_lock(backend):
     b._added_adapters = {}
     b._loaded_adapters = {}
     b.add_adapter(
-        GraniteCommonAdapter(
-            "requirement_check", base_model_name=b.base_model_name
-        )
+        GraniteCommonAdapter("requirement_check", base_model_name=b.base_model_name)
     )

     regular_generate = b._model.generate
+
     def generate_and_raise_exc(*args, **kwargs):
         """Will generate like usual for the intrinsic request. Will fail for the regular generation request."""
         if "max_new_tokens" in kwargs:
             return regular_generate(*args, **kwargs)  # type: ignore
         raise Exception("Oops!")

     b._model.generate = Mock(side_effect=generate_and_raise_exc)
-    assert not isinstance(backend._model, Mock), "mocking went wrong; backend fixture changed; other tests may fail"
+    assert not isinstance(backend._model, Mock), (
+        "mocking went wrong; backend fixture changed; other tests may fail"
+    )

     # Set up the inputs.
     ctx = ChatContext().add(Message("user", "hello"))
@@ -487,9 +516,10 @@ def generate_and_raise_exc(*args, **kwargs):

     await req_mot.avalue()

+
 def test_assert_correct_adapters():
     model = Mock()
-    
+
     # Test scenarios with no active adapters.
     model.active_adapters = Mock(return_value=[])
     _assert_correct_adapters("", model)
@@ -505,11 +535,16 @@ def test_assert_correct_adapters():
     _assert_correct_adapters("new", model)

     # Test scenarios when no adapters have been loaded.
-    model.active_adapters = Mock(side_effect=ValueError("No adapter loaded. Please load an adapter first."))
-    _assert_correct_adapters("", model) # This will fail if peft ever changes the error message.
+    model.active_adapters = Mock(
+        side_effect=ValueError("No adapter loaded. Please load an adapter first.")
+    )
+    _assert_correct_adapters(
+        "", model
+    )  # This will fail if peft ever changes the error message.
     with pytest.raises(AssertionError):
         _assert_correct_adapters("new", model)

+
 if __name__ == "__main__":
     import pytest

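Most of this diff is ruff reformatting, but the lock tests rely on a memoized-mock pattern worth calling out: `model.generate` is first wrapped to cache real outputs keyed by `input_ids`, then swapped for a replay mock so many concurrent requests can hammer the generation lock without doing real inference. A condensed, hedged sketch of that pattern follows; the factory name and wiring are illustrative, not the exact test code:

import random
import time

import torch


def make_memoized_generate(model):
    """Return (populate, replay) wrappers around model.generate.

    `populate` runs the real generate and caches each output keyed by its input_ids;
    `replay` serves only cached outputs (with a small sleep to simulate work) and
    fails loudly on anything it has not seen before.
    """
    memoized: dict = {}
    real_generate = model.generate

    def populate(input_ids, *args, **kwargs):
        output = real_generate(input_ids, *args, **kwargs)
        memoized[input_ids] = output
        return output

    def replay(input_ids, *args, **kwargs):
        for key, val in memoized.items():
            if torch.equal(key, input_ids):
                time.sleep(random.uniform(0.1, 0.5))  # simulate a bit of work
                return val
        raise AssertionError("did not get a cached response")

    return populate, replay


# Usage, roughly mirroring test_generate_with_lock:
#   populate, replay = make_memoized_generate(backend._model)
#   backend._model.generate = Mock(side_effect=populate)  # warm the cache once
#   ... run each request once and await every ModelOutputThunk ...
#   backend._model.generate = Mock(side_effect=replay)    # then replay under concurrency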