Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid interfaces #90

Merged
merged 4 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions wafl/connectors/remote/remote_whisper_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ async def predict(self, waveform, hotword=None) -> Dict[str, float]:
async with session.post(self._server_url, json=payload) as response:
data = await response.text()
prediction = json.loads(data)
if "transcription" not in prediction:
raise RuntimeError("No transcription found in prediction. Is your microphone working?")

transcription = prediction["transcription"]
score = prediction["score"]
logp = prediction["logp"]
Expand Down
4 changes: 3 additions & 1 deletion wafl/events/conversation_events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
import traceback

from wafl.events.answerer_creator import create_answerer
from wafl.simple_text_processing.normalize import normalized
Expand Down Expand Up @@ -59,6 +60,7 @@ async def _process_query(self, text: str):

if (
not text_is_question
and self._interface.get_utterances_list()
and self._interface.get_utterances_list()[-1].find("user:") == 0
):
await self._interface.output("I don't know what to reply")
Expand Down Expand Up @@ -108,7 +110,7 @@ def reload_knowledge(self):

def reset_discourse_memory(self):
self._answerer = create_answerer(
self._config, self._knowledge, self._interface, logger
self._config, self._knowledge, self._interface, self._logger
)

def _activation_word_in_text(self, activation_word, text):
Expand Down
13 changes: 10 additions & 3 deletions wafl/interface/base_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ async def input(self) -> str:
def bot_has_spoken(self, to_set: bool = None):
raise NotImplementedError

async def insert_input(self, text: str):
pass

def is_listening(self):
return self._is_listening

Expand All @@ -32,9 +35,6 @@ def deactivate(self):
self._facts = []
self._utterances = []

def add_hotwords(self, hotwords: List[str]):
raise NotImplementedError

async def add_choice(self, text):
self._choices.append((time.time(), text))
await self.output(f"Making the choice: {text}", silent=True)
Expand All @@ -60,8 +60,15 @@ def reset_history(self):
self._choices = []
self._facts = []

def add_hotwords(self, hotwords):
pass

def _decorate_reply(self, text: str) -> str:
if not self._decorator:
return text

return self._decorator.extract(text, self._utterances)

def _insert_utterance(self, speaker, text: str):
if self._utterances == [] or text != self._utterances[-1][1].replace(f"{speaker}: ", ""):
self._utterances.append((time.time(), f"{speaker}: {text}"))
54 changes: 54 additions & 0 deletions wafl/interface/list_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import asyncio
from typing import List

from wafl.interface.base_interface import BaseInterface


class ListInterface(BaseInterface):
def __init__(self, interfaces_list: List[BaseInterface]):
super().__init__()
self._interfaces_list = interfaces_list
self._synchronize_interfaces()

async def output(self, text: str, silent: bool = False):
await asyncio.wait(
[interface.output(text, silent) for interface in self._interfaces_list],
return_when=asyncio.ALL_COMPLETED
)

async def input(self) -> str:
done, pending = await asyncio.wait(
[interface.input() for interface in self._interfaces_list],
return_when=asyncio.FIRST_COMPLETED
)
return done.pop().result()

async def insert_input(self, text: str):
await asyncio.wait(
[interface.insert_input(text) for interface in self._interfaces_list],
return_when=asyncio.ALL_COMPLETED
)

def bot_has_spoken(self, to_set: bool = None):
for interface in self._interfaces_list:
interface.bot_has_spoken(to_set)

def activate(self):
for interface in self._interfaces_list:
interface.activate()
super().activate()

def deactivate(self):
for interface in self._interfaces_list:
interface.deactivate()
super().deactivate()
self._synchronize_interfaces()


def add_hotwords(self, hotwords):
for interface in self._interfaces_list:
interface.add_hotwords(hotwords)

def _synchronize_interfaces(self):
for interface in self._interfaces_list:
interface._utterances = self._utterances
11 changes: 6 additions & 5 deletions wafl/interface/queue_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import time

from wafl.interface.base_interface import BaseInterface

Expand All @@ -16,19 +15,21 @@ async def output(self, text: str, silent: bool = False):
self.output_queue.append({"text": text, "silent": True})
return

utterance = text
self.output_queue.append({"text": utterance, "silent": False})
self._utterances.append((time.time(), f"bot: {text}"))
self.output_queue.append({"text": text, "silent": False})
self._insert_utterance("bot", text)
self.bot_has_spoken(True)

async def input(self) -> str:
while not self.input_queue:
await asyncio.sleep(0.1)

text = self.input_queue.pop(0)
self._utterances.append((time.time(), f"user: {text}"))
self._insert_utterance("user", text)
return text

async def insert_input(self, text: str):
self.input_queue.append(text)

def bot_has_spoken(self, to_set: bool = None):
if to_set != None:
self._bot_has_spoken = to_set
Expand Down
21 changes: 4 additions & 17 deletions wafl/interface/voice_interface.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
import random
import re
import time

from wafl.events.utils import remove_text_between_brackets
from wafl.simple_text_processing.deixis import from_bot_to_user
from wafl.interface.base_interface import BaseInterface
from wafl.interface.utils import get_most_common_words, not_good_enough
from wafl.interface.utils import not_good_enough
from wafl.listener.whisper_listener import WhisperListener
from wafl.speaker.fairseq_speaker import FairSeqSpeaker
from wafl.speaker.soundfile_speaker import SoundFileSpeaker
Expand Down Expand Up @@ -42,17 +40,6 @@ def __init__(self, config):
self._bot_has_spoken = False
self._utterances = []

async def add_hotwords_from_knowledge(
self, knowledge: "Knowledge", max_num_words: int = 100, count_threshold: int = 5
):
hotwords = get_most_common_words(
knowledge.get_facts_and_rule_as_text(),
max_num_words=max_num_words,
count_threshold=count_threshold,
)
hotwords = [word.lower() for word in hotwords]
self._listener.add_hotwords(hotwords)

def add_hotwords(self, hotwords):
self._listener.add_hotwords(hotwords)

Expand All @@ -65,8 +52,8 @@ async def output(self, text: str, silent: bool = False):
return

self._listener.activate()
text = from_bot_to_user(text)
self._utterances.append((time.time(), f"bot: {text}"))
text = text
self._insert_utterance("bot", text)
print(COLOR_START + "bot> " + text + COLOR_END)
await self._speaker.speak(text)
self.bot_has_spoken(True)
Expand All @@ -89,7 +76,7 @@ async def input(self) -> str:
print(COLOR_START + "user> " + text + COLOR_END)
utterance = remove_text_between_brackets(text)
if utterance.strip():
self._utterances.append((time.time(), f"user: {text}"))
self._insert_utterance("user", text)

return text

Expand Down
7 changes: 6 additions & 1 deletion wafl/listener/whisper_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ async def input(self):

while True:
await asyncio.sleep(0)
inp = self.stream.read(self._chunk)
try:
inp = self.stream.read(self._chunk)
except IOError:
self.activate()
inp = self.stream.read(self._chunk)

rms_val = _rms(inp)
if rms_val > self._volume_threshold:
waveform = self.record(start_with=inp)
Expand Down
56 changes: 2 additions & 54 deletions wafl/runners/routes.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
import asyncio
import os
import random
import sys
import threading

from flask import Flask, render_template, redirect, url_for
from flask import Flask
from flask_cors import CORS
from wafl.config import Configuration
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.queue_interface import QueueInterface
from wafl.logger.local_file_logger import LocalFileLogger
from wafl.scheduler.conversation_loop import ConversationLoop
from wafl.scheduler.scheduler import Scheduler
from wafl.scheduler.web_loop import WebLoop

_path = os.path.dirname(__file__)
_logger = LocalFileLogger()
app = Flask(
__name__,
static_url_path="",
Expand All @@ -25,51 +13,11 @@
CORS(app)


@app.route("/create_new_instance", methods=["POST"])
def create_new_instance():
conversation_id = random.randint(0, sys.maxsize)
result = create_scheduler_and_webserver_loop(conversation_id)
add_new_rules(app, conversation_id, result["web_server_loop"])
thread = threading.Thread(target=result["scheduler"].run)
thread.start()
return redirect(url_for(f"index_{conversation_id}"))


@app.route("/")
async def index():
return render_template("selector.html")


def get_app():
return app


def create_scheduler_and_webserver_loop(conversation_id):
config = Configuration.load_local_config()
interface = QueueInterface()
interface.activate()
conversation_events = ConversationEvents(
config=config,
interface=interface,
logger=_logger,
)
conversation_loop = ConversationLoop(
interface,
conversation_events,
_logger,
activation_word="",
max_misses=-1,
deactivate_on_closed_conversation=False,
)
asyncio.run(interface.output("Hello. How may I help you?"))
web_loop = WebLoop(interface, conversation_id, conversation_events)
return {
"scheduler": Scheduler([conversation_loop, web_loop]),
"web_server_loop": web_loop,
}


def add_new_rules(app, conversation_id, web_server_loop):
def add_new_rules(app: Flask, conversation_id: int, web_server_loop: "WebLoop"):
app.add_url_rule(
f"/{conversation_id}/",
f"index_{conversation_id}",
Expand Down
63 changes: 63 additions & 0 deletions wafl/runners/run_web_and_audio_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import asyncio
import random
import sys
import threading

from flask import render_template, redirect, url_for

from wafl.interface.list_interface import ListInterface
from wafl.interface.voice_interface import VoiceInterface
from wafl.scheduler.scheduler import Scheduler
from wafl.scheduler.web_loop import WebLoop
from wafl.scheduler.conversation_loop import ConversationLoop
from wafl.logger.local_file_logger import LocalFileLogger
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.queue_interface import QueueInterface
from wafl.config import Configuration
from wafl.runners.routes import get_app, add_new_rules


app = get_app()
_logger = LocalFileLogger()


def run_app():
@app.route("/create_new_instance", methods=["POST"])
def create_new_instance():
conversation_id = random.randint(0, sys.maxsize)
result = create_scheduler_and_webserver_loop(conversation_id)
add_new_rules(app, conversation_id, result["web_server_loop"])
thread = threading.Thread(target=result["scheduler"].run)
thread.start()
return redirect(url_for(f"index_{conversation_id}"))

@app.route("/")
async def index():
return render_template("selector.html")

def create_scheduler_and_webserver_loop(conversation_id):
config = Configuration.load_local_config()
interface = ListInterface([VoiceInterface(config), QueueInterface()])
interface.activate()
conversation_events = ConversationEvents(
config=config,
interface=interface,
logger=_logger,
)
conversation_loop = ConversationLoop(
interface,
conversation_events,
_logger,
activation_word=config.get_value("waking_up_word"),
)
web_loop = WebLoop(interface, conversation_id, conversation_events)
return {
"scheduler": Scheduler([conversation_loop, web_loop]),
"web_server_loop": web_loop,
}

app.run(host="0.0.0.0", port=8889)


if __name__ == "__main__":
run_app()
Loading
Loading