Skip to content

Commit

Permalink
removed local connectors tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fractalego committed Dec 23, 2023
1 parent 6f28782 commit 42a3cd2
Showing 1 changed file with 0 additions and 35 deletions.
35 changes: 0 additions & 35 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import asyncio
import os
import wave

from unittest import TestCase

import numpy as np

from wafl.config import Configuration
from wafl.connectors.bridges.llm_chitchat_answer_bridge import LLMChitChatAnswerBridge
from wafl.connectors.remote.remote_llm_connector import RemoteLLMConnector
from wafl.listener.whisper_listener import WhisperListener
from wafl.speaker.fairseq_speaker import FairSeqSpeaker

_path = os.path.dirname(__file__)
Expand Down Expand Up @@ -51,33 +46,3 @@ def test__connection_to_generative_model_can_generate_a_python_list(self):
prediction = asyncio.run(connector.predict(prompt))
print(prediction)
assert len(prediction) > 0

def test__local_llm_connector_can_generate_a_python_list(self):
config = Configuration.load_from_filename("local_config.json")
connector = LocalLLMConnector(config.get_value("llm_model"))
connector._num_prediction_tokens = 200
prompt = "Generate a list of 4 chapters names for a space opera book. The output needs to be a python list of strings: "
prediction = asyncio.run(connector.predict(prompt))
print(prediction)
assert len(prediction) > 0

def test__chit_chat_bridge_can_run_locally(self):
config = Configuration.load_from_filename("local_config.json")
dialogue_bridge = LLMChitChatAnswerBridge(config)
answer = asyncio.run(dialogue_bridge.get_answer("", "", "bot: hello"))
assert len(answer) > 0

def test__listener_local_connector(self):
config = Configuration.load_from_filename("local_config.json")
listener = WhisperListener(config)
f = wave.open(os.path.join(_path, "data/1002.wav"), "rb")
waveform = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16) / 32768
result = asyncio.run(listener.input_waveform(waveform))
expected = "DELETE BATTERIES FROM THE GROCERY LIST"
assert expected.lower() in result

def test__speaker_local_connector(self):
config = Configuration.load_from_filename("local_config.json")
speaker = FairSeqSpeaker(config)
text = "Hello world"
asyncio.run(speaker.speak(text))

0 comments on commit 42a3cd2

Please sign in to comment.