Skip to content

Commit

Permalink
Merge pull request #2 from liberate-org/exp/interrupt
Browse files Browse the repository at this point in the history
Exp/interrupt
  • Loading branch information
rde8026 authored Feb 24, 2024
2 parents 96465e6 + 5a6fce9 commit fbf1874
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 2 deletions.
3 changes: 2 additions & 1 deletion vocode/streaming/agent/chat_gpt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
else:
self.openaiAsyncClient = AsyncOpenAI(
base_url = "https://api.openai.com/v1",
api_key = openai_api_key or getenv("OPENAI_API_KEY")
api_key = openai_api_key or getenv("OPENAI_API_KEY"),
)
self.openaiSyncClient = OpenAI(
base_url = "https://api.openai.com/v1",
Expand Down Expand Up @@ -141,6 +141,7 @@ async def respond(
text = self.first_response
else:
chat_parameters = self.get_chat_parameters()
chat_parameters["stream"] = True
# chat_completion = await openai.ChatCompletion.acreate(**chat_parameters)
chat_completion = await self.openaiAsyncClient.chat.completions.create(**chat_parameters)
text = chat_completion.choices[0].message.content
Expand Down
2 changes: 2 additions & 0 deletions vocode/streaming/models/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class SynthesizerConfig(TypedModel, type=SynthesizerType.BASE.value):
audio_encoding: AudioEncoding
should_encode_as_wav: bool = False
sentiment_config: Optional[SentimentConfig] = None
reengage_timeout: Optional[float] = None
reengage_options: Optional[List[str]] = None

class Config:
arbitrary_types_allowed = True
Expand Down
92 changes: 91 additions & 1 deletion vocode/streaming/streaming_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@

OutputDeviceType = TypeVar("OutputDeviceType", bound=BaseOutputDevice)


class StreamingConversation(Generic[OutputDeviceType]):
class QueueingInterruptibleEventFactory(InterruptibleEventFactory):
def __init__(self, conversation: "StreamingConversation"):
Expand Down Expand Up @@ -119,11 +118,27 @@ def __init__(
self.conversation = conversation
self.interruptible_event_factory = interruptible_event_factory

def kill_tasks_when_human_is_talking(self):
    """Cancel any in-flight synthesis / agent-response tasks.

    Called (currently commented out at the call sites) when the human starts
    talking, so the bot's pending speech work is abandoned instead of playing
    over the human. Cancels each worker's ``current_task`` if it exists and
    has not finished.
    """
    # Hoist the long attribute chains to locals so each task object is read
    # once, and use a uniform `is not None` check for both workers (the
    # original mixed `is not None` with a bare truthiness test).
    synthesis_task = self.conversation.synthesis_results_worker.current_task
    if synthesis_task is not None and not synthesis_task.done():
        self.conversation.logger.info("###### Synthesis task is running, attempting to cancel it ######")
        synthesis_task.cancel()
        self.conversation.logger.info("###### Synthesis task is running, has been canceled ######")
    agent_task = self.conversation.agent_responses_worker.current_task
    if agent_task is not None and not agent_task.done():
        self.conversation.logger.info("&&&&&&& Agent Response task is running, attempting to cancel it &&&&&&&")
        agent_task.cancel()
        self.conversation.logger.info("&&&&&&& Agent Response task is running, has been canceled &&&&&&&")

async def process(self, transcription: Transcription):
self.conversation.mark_last_action_timestamp()
if transcription.message.strip() == "":
self.conversation.logger.info("Ignoring empty transcription")
return
elif transcription.message.strip() == "<INTERRUPT>" and transcription.confidence == 1.0:
# self.kill_tasks_when_human_is_talking()
self.conversation.broadcast_interrupt()

if transcription.is_final:
self.conversation.logger.debug(
"Got transcription: {}, confidence: {}".format(
Expand Down Expand Up @@ -156,6 +171,10 @@ async def process(self, transcription: Transcription):
)
)
self.output_queue.put_nowait(event)
self.conversation.mark_last_final_transcript_from_human()
# else:
# self.kill_tasks_when_human_is_talking()
# self.conversation.broadcast_interrupt()

class FillerAudioWorker(InterruptibleAgentResponseWorker):
"""
Expand Down Expand Up @@ -365,6 +384,7 @@ async def process(
await self.conversation.terminate()
except asyncio.TimeoutError:
pass
self.conversation.mark_last_agent_response()
except asyncio.CancelledError:
pass

Expand Down Expand Up @@ -508,6 +528,12 @@ async def start(self, mark_ready: Optional[Callable[[], Awaitable[None]]] = None
self.check_for_idle_task = asyncio.create_task(self.check_for_idle())
if len(self.events_manager.subscriptions) > 0:
self.events_task = asyncio.create_task(self.events_manager.start())
if (
self.synthesizer.get_synthesizer_config().reengage_timeout and
(self.synthesizer.get_synthesizer_config().reengage_options and
len(self.synthesizer.get_synthesizer_config().reengage_options) > 0)
):
self.human_prompt_checker = asyncio.create_task(self.check_if_human_should_be_prompted())

async def send_initial_message(self, initial_message: BaseMessage):
# TODO: configure if initial message is interruptible
Expand Down Expand Up @@ -571,6 +597,13 @@ def warmup_synthesizer(self):
def mark_last_action_timestamp(self):
    # Record the wall-clock time of the most recent conversation activity
    # (stamped from the transcription worker's process()); presumably read by
    # the check_for_idle task started in start() — confirm outside this view.
    self.last_action_timestamp = time.time()

def mark_last_final_transcript_from_human(self):
    # Record when the human last produced a final transcript; read by
    # check_if_human_should_be_prompted() to decide whether the reengage
    # timeout has elapsed.
    self.last_final_transcript_from_human = time.time()

def mark_last_agent_response(self):
    # Record when the agent last responded (stamped after speech finishes
    # playing and after each reengage prompt); read by
    # check_if_human_should_be_prompted() to gate re-prompting.
    self.last_agent_response = time.time()


def broadcast_interrupt(self):
"""Stops all inflight events and cancels all workers that are sending output
Expand All @@ -588,13 +621,32 @@ def broadcast_interrupt(self):
break
self.agent.cancel_current_task()
self.agent_responses_worker.cancel_current_task()

# Clearing these queues cuts time from finishing interruption talking to bot talking cut by 1 second from ~4.5 to ~3.5 seconds.
self.clear_queue(self.agent.output_queue, 'agent.output_queue')
self.clear_queue(self.agent_responses_worker.output_queue, 'agent_responses_worker.output_queue')
self.clear_queue(self.agent_responses_worker.input_queue, 'agent_responses_worker.input_queue')
self.clear_queue(self.synthesis_results_worker.output_queue, 'synthesis_results_worker.output_queue')
self.clear_queue(self.synthesis_results_worker.input_queue, 'synthesis_results_worker.input_queue')
if hasattr(self.output_device, 'queue'):
self.clear_queue(self.output_device.queue, 'output_device.queue')

return num_interrupts > 0

def is_interrupt(self, transcription: Transcription):
    """Return True when the transcription is confident enough to count as an interruption.

    Uses the transcriber config's ``min_interrupt_confidence`` as the
    threshold, treating an unset (falsy) value as 0 so any confidence passes.
    """
    threshold = self.transcriber.get_transcriber_config().min_interrupt_confidence or 0
    return transcription.confidence >= threshold

@staticmethod
def clear_queue(q: asyncio.Queue, queue_name: str):
    """Discard every item currently sitting in ``q``, logging each removal.

    Used by broadcast_interrupt() to flush pending agent/synthesis output so
    the bot stops speaking sooner after an interruption.
    """
    while True:
        if q.empty():
            break
        logging.info(f'Clearing queue {queue_name} with size {q.qsize()}')
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            # Another consumer drained it between empty() and get_nowait();
            # just re-check emptiness.
            pass

async def send_speech_to_output(
self,
message: str,
Expand Down Expand Up @@ -726,3 +778,41 @@ async def terminate(self):

def is_active(self):
    # Whether the conversation loop is still running; self.active presumably
    # set in start() and cleared in terminate() — both outside this view.
    return self.active

async def check_if_human_should_be_prompted(self):
    """Background task that re-engages a silent human.

    Polls once per second while the conversation is active. When BOTH the
    human's last final transcript and the agent's last response are older
    than ``reengage_timeout`` seconds, synthesizes a randomly chosen line
    from ``reengage_options`` and queues it for playback like a normal agent
    response. Started from start() only when both config values are set.
    """
    self.logger.debug("starting should prompt user task")
    # Until both timestamps have been set at least once (by
    # mark_last_agent_response / mark_last_final_transcript_from_human),
    # no re-engagement is attempted — the branch below requires both truthy.
    self.last_agent_response = None
    self.last_final_transcript_from_human = None
    reengage_timeout = self.synthesizer.get_synthesizer_config().reengage_timeout
    reengage_options = self.synthesizer.get_synthesizer_config().reengage_options
    while self.active:
        if self.last_agent_response and self.last_final_transcript_from_human:
            # Seconds since each side last spoke.
            last_human_touchpoint = time.time() - self.last_final_transcript_from_human
            last_agent_touchpoint = time.time() - self.last_agent_response
            if last_human_touchpoint >= reengage_timeout and last_agent_touchpoint >= reengage_timeout:
                reengage_statement = random.choice(reengage_options)
                self.logger.debug(f"Prompting user with {reengage_statement}: no interaction has happened in {reengage_timeout} seconds")
                # NOTE(review): chunk_size is derived from static synthesizer
                # config, so this is loop-invariant and could be hoisted out
                # of the loop — left in place to preserve exact behavior.
                self.chunk_size = (
                    get_chunk_size_per_second(
                        self.synthesizer.get_synthesizer_config().audio_encoding,
                        self.synthesizer.get_synthesizer_config().sampling_rate,
                    )
                    * TEXT_TO_SPEECH_CHUNK_SIZE_SECONDS
                )
                message = BaseMessage(text=reengage_statement)
                synthesis_result = await self.synthesizer.create_speech(
                    message,
                    self.chunk_size,
                    bot_sentiment=self.bot_sentiment,
                )
                # Enqueue through the agent-responses worker as interruptible,
                # so broadcast_interrupt() can cancel it if the human speaks.
                self.agent_responses_worker.produce_interruptible_agent_response_event_nonblocking(
                    (message, synthesis_result),
                    is_interruptible=True,
                    agent_response_tracker=asyncio.Event(),
                )
                # Stamp the agent timestamp so a full reengage_timeout must
                # elapse before the next prompt.
                self.mark_last_agent_response()
            await asyncio.sleep(1)
        else:
            await asyncio.sleep(1)
    self.logger.debug("stopped check if human should be prompted")

0 comments on commit fbf1874

Please sign in to comment.