1717from mellea .backends .formatter import TemplateFormatter
1818from mellea .backends .huggingface import LocalHFBackend , _assert_correct_adapters
1919from mellea .backends .types import ModelOption
20- from mellea .stdlib .base import (CBlock , ChatContext , Context , ModelOutputThunk ,
21- SimpleContext )
20+ from mellea .stdlib .base import (
21+ CBlock ,
22+ ChatContext ,
23+ Context ,
24+ ModelOutputThunk ,
25+ SimpleContext ,
26+ )
2227from mellea .stdlib .chat import Message
2328from mellea .stdlib .intrinsics .intrinsic import Intrinsic
24- from mellea .stdlib .requirement import (ALoraRequirement , LLMaJRequirement ,
25- Requirement , ValidationResult ,
26- default_output_to_bool )
29+ from mellea .stdlib .requirement import (
30+ ALoraRequirement ,
31+ LLMaJRequirement ,
32+ Requirement ,
33+ ValidationResult ,
34+ default_output_to_bool ,
35+ )
2736
2837
2938@pytest .fixture (scope = "module" )
@@ -40,9 +49,7 @@ def backend():
4049 )
4150 )
4251 backend .add_adapter (
43- GraniteCommonAdapter (
44- "answerability" , base_model_name = backend .base_model_name
45- )
52+ GraniteCommonAdapter ("answerability" , base_model_name = backend .base_model_name )
4653 )
4754 return backend
4855
@@ -54,6 +61,7 @@ def session(backend):
5461 yield session
5562 session .reset ()
5663
64+
5765@pytest .mark .qualitative
5866def test_adapters (backend ):
5967 assert len (backend ._added_adapters .items ()) > 0
@@ -305,6 +313,7 @@ async def test_async_avalue(session):
305313 assert m1_final_val is not None
306314 assert m1_final_val == mot1 .value
307315
316+
308317@pytest .mark .qualitative
309318async def test_generate_with_lock (backend ):
310319 # Enable the faulthandler for this test.
@@ -319,23 +328,20 @@ async def test_generate_with_lock(backend):
319328 b ._added_adapters = {}
320329 b ._loaded_adapters = {}
321330 b .add_adapter (
322- GraniteCommonAdapter (
323- "requirement_check" , base_model_name = b .base_model_name
324- )
331+ GraniteCommonAdapter ("requirement_check" , base_model_name = b .base_model_name )
325332 )
326333 b .add_adapter (
327- GraniteCommonAdapter (
328- "answerability" , base_model_name = b .base_model_name
329- )
334+ GraniteCommonAdapter ("answerability" , base_model_name = b .base_model_name )
330335 )
331336
332337 memoized = dict ()
333338 gen_func = model .generate
339+
334340 def mock_func (input_ids , * args , ** kwargs ):
335341 """Mocks the generate function. Must call `populate_mocked_dict` with each input that must be cached before using this."""
336342 for key , val in memoized .items ():
337343 if torch .equal (key , input_ids ):
338- time .sleep (random .uniform (.1 , .5 )) # Simulate a bit of work.
344+ time .sleep (random .uniform (0 .1 , 0 .5 )) # Simulate a bit of work.
339345 return val
340346 assert False , "did not get a cached response"
341347
@@ -347,7 +353,9 @@ def populate_mocked_dict(input_ids, *args, **kwargs):
347353 return output
348354
349355 model .generate = Mock (side_effect = populate_mocked_dict )
350- assert not isinstance (backend ._model , Mock ), "mocking went wrong; backend fixture changed; other tests may fail"
356+ assert not isinstance (backend ._model , Mock ), (
357+ "mocking went wrong; backend fixture changed; other tests may fail"
358+ )
351359
352360 # Set up the inputs.
353361 ctx = ChatContext ().add (Message ("user" , "hello" ))
@@ -362,18 +370,22 @@ def call_backend_generate():
362370 b .generate_from_context (act , ctx ),
363371 b .generate_from_context (req_intrinsic , ctx ),
364372 b .generate_from_context (answerability_intrinsic , ctx ),
365- b .generate_from_raw ([raw_act ], ctx , model_options = {ModelOption .MAX_NEW_TOKENS : 3 })
373+ b .generate_from_raw (
374+ [raw_act ], ctx , model_options = {ModelOption .MAX_NEW_TOKENS : 3 }
375+ ),
366376 ]
367377
368378 # Call once to populate the memoized mock.
369379 outputs = await asyncio .gather (* call_backend_generate ())
370380 for output in outputs :
371381 mot = output [0 ]
372- await mot .avalue () # Ensure all values are computed.
382+ await mot .avalue () # Ensure all values are computed.
373383
374384 # Use the memoized mock that errors if not precomputed.
375385 model .generate = Mock (side_effect = mock_func )
376- count = 5 # Use a high number to try to put pressure on the lock and catch deadlocks.
386+ count = (
387+ 5 # Use a high number to try to put pressure on the lock and catch deadlocks.
388+ )
377389 coros : list [Coroutine [Any , Any , tuple [ModelOutputThunk , Context ]]] = []
378390 for _ in range (count ):
379391 coros .extend (call_backend_generate ())
@@ -388,10 +400,11 @@ def call_backend_generate():
388400
389401 faulthandler .disable ()
390402
403+
391404@pytest .mark .qualitative
392405async def test_generate_with_lock_does_not_block_when_awaiting_value (backend ):
393- """This is a tricky test to setup.
394-
406+ """This is a tricky test to set up.
407+
395408 Its purpose is to ensure that a long-running generation doesn't get blocked
396409 when awaiting the `model_output_thunk.avalue()` of a different generation request.
397410
@@ -417,14 +430,28 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend):
417430 # - a streaming generation that will take a long time to resolve.
418431 # - a regular generation that should be able to happen while the streaming is happening.
419432 # - two intrinsics that shouldn't be able to happen concurrently.
420- reg_mot_stream , _ = await backend .generate_from_context (act , ctx , model_options = {ModelOption .STREAM : True , ModelOption .MAX_NEW_TOKENS : token_generation_length , "min_length" : token_generation_length })
433+ reg_mot_stream , _ = await backend .generate_from_context (
434+ act ,
435+ ctx ,
436+ model_options = {
437+ ModelOption .STREAM : True ,
438+ ModelOption .MAX_NEW_TOKENS : token_generation_length ,
439+ "min_length" : token_generation_length ,
440+ },
441+ )
421442 reg_mot , _ = await backend .generate_from_context (act , ctx )
422- req_mot , _ = await backend .generate_from_context (req_intrinsic , ctx , model_options = {ModelOption .STREAM : True })
423- answerability_mot , _ = await backend .generate_from_context (answerability_intrinsic , ctx , model_options = {ModelOption .STREAM : True })
443+ req_mot , _ = await backend .generate_from_context (
444+ req_intrinsic , ctx , model_options = {ModelOption .STREAM : True }
445+ )
446+ answerability_mot , _ = await backend .generate_from_context (
447+ answerability_intrinsic , ctx , model_options = {ModelOption .STREAM : True }
448+ )
424449
425450 # Ensure the stream is generating but not yet completing.
426451 await reg_mot_stream .astream ()
427- assert not reg_mot_stream .is_computed (), "generation completed too early, see test for more details"
452+ assert not reg_mot_stream .is_computed (), (
453+ "generation completed too early, see test for more details"
454+ )
428455
429456 # Awaiting this shouldn't cause a deadlock. Add the timeout so the test can fail.
430457 # If the test fails, this means that the streaming generation wasn't able to complete,
@@ -442,11 +469,12 @@ async def test_generate_with_lock_does_not_block_when_awaiting_value(backend):
442469 raise e
443470 else :
444471 raise Exception ("timeout ended too early, see test for more details" )
445-
472+
446473 for output in [reg_mot_stream , reg_mot , req_mot , answerability_mot ]:
447474 if not output .is_computed ():
448475 await output .avalue () # Ensure everything gets computed.
449476
477+
450478@pytest .mark .qualitative
451479async def test_error_during_generate_with_lock (backend ):
452480 # Create local versions of these objects so that mocking
@@ -459,20 +487,21 @@ async def test_error_during_generate_with_lock(backend):
459487 b ._added_adapters = {}
460488 b ._loaded_adapters = {}
461489 b .add_adapter (
462- GraniteCommonAdapter (
463- "requirement_check" , base_model_name = b .base_model_name
464- )
490+ GraniteCommonAdapter ("requirement_check" , base_model_name = b .base_model_name )
465491 )
466492
467493 regular_generate = b ._model .generate
494+
468495 def generate_and_raise_exc (* args , ** kwargs ):
469496 """Will generate like usual for the intrinsic request. Will fail for the regular generation request."""
470497 if "max_new_tokens" in kwargs :
471498 return regular_generate (* args , ** kwargs ) # type: ignore
472499 raise Exception ("Oops!" )
473500
474501 b ._model .generate = Mock (side_effect = generate_and_raise_exc )
475- assert not isinstance (backend ._model , Mock ), "mocking went wrong; backend fixture changed; other tests may fail"
502+ assert not isinstance (backend ._model , Mock ), (
503+ "mocking went wrong; backend fixture changed; other tests may fail"
504+ )
476505
477506 # Set up the inputs.
478507 ctx = ChatContext ().add (Message ("user" , "hello" ))
@@ -487,9 +516,10 @@ def generate_and_raise_exc(*args, **kwargs):
487516
488517 await req_mot .avalue ()
489518
519+
490520def test_assert_correct_adapters ():
491521 model = Mock ()
492-
522+
493523 # Test scenarios with no active adapters.
494524 model .active_adapters = Mock (return_value = [])
495525 _assert_correct_adapters ("" , model )
@@ -505,11 +535,16 @@ def test_assert_correct_adapters():
505535 _assert_correct_adapters ("new" , model )
506536
507537 # Test scenarios when no adapters have been loaded.
508- model .active_adapters = Mock (side_effect = ValueError ("No adapter loaded. Please load an adapter first." ))
509- _assert_correct_adapters ("" , model ) # This will fail if peft ever changes the error message.
538+ model .active_adapters = Mock (
539+ side_effect = ValueError ("No adapter loaded. Please load an adapter first." )
540+ )
541+ _assert_correct_adapters (
542+ "" , model
543+ ) # This will fail if peft ever changes the error message.
510544 with pytest .raises (AssertionError ):
511545 _assert_correct_adapters ("new" , model )
512546
547+
513548if __name__ == "__main__" :
514549 import pytest
515550
0 commit comments