From 93e69e35a314c60b136c571c3cfdd9de99a8a2fb Mon Sep 17 00:00:00 2001 From: "Chanvut.B" Date: Tue, 16 Jan 2024 13:06:32 +0700 Subject: [PATCH 1/3] Added a callback option for TranscriptionClient --- whisper_live/client.py | 55 ++++++++++++++++++++++++++++++------------ whisper_live/server.py | 4 ++- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index 979e3254..6bfb9977 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -12,6 +12,8 @@ import uuid import time +from typing import Callable + def resample(file: str, sr: int = 16000): """ @@ -49,6 +51,11 @@ class Client: """ INSTANCES = {} + __messages__ : list[list[str]] = [] + + def total_messages(self) -> list[list[str]]: + return self.__messages__ + def __init__( self, host=None, @@ -57,7 +64,8 @@ def __init__( lang=None, translate=False, model_size="small", - use_custom_model=False + use_custom_model=False, + callback: Callable[[list[str]], any]=None ): """ Initializes a Client instance for audio recording and streaming to a server. @@ -91,6 +99,7 @@ def __init__( self.model_size = model_size self.server_error = False self.use_custom_model = use_custom_model + self.callback = callback if translate: self.task = "translate" @@ -147,11 +156,13 @@ def on_message(self, ws, message): self.last_response_recieved = time.time() message = json.loads(message) + keys = message.keys() + if self.uid != message.get("uid"): print("[ERROR]: invalid client uid") return - if "status" in message.keys(): + if "status" in keys: if message["status"] == "WAIT": self.waiting = True print( @@ -162,15 +173,15 @@ def on_message(self, ws, message): self.server_error = True return - if "message" in message.keys() and message["message"] == "DISCONNECT": + if "message" in keys and message["message"] == "DISCONNECT": print("[INFO]: Server overtime disconnected.") self.recording = False - if "message" in message.keys() and message["message"] == "SERVER_READY": + if "message" in keys and message["message"] == "SERVER_READY": self.recording = True return - if "language" in message.keys(): + if "language" in keys: self.language = message.get("language") lang_prob = message.get("language_prob") print( @@ -178,7 +189,7 @@ def on_message(self, ws, message): ) return - if "segments" not in message.keys(): + if "segments" not in keys: return message = message["segments"] @@ -194,13 +205,18 @@ def on_message(self, ws, message): text = text[-3:] wrapper = textwrap.TextWrapper(width=60) word_list = wrapper.wrap(text="".join(text)) + + self.callback(word_list) + # Print each line. - if os.name == "nt": - os.system("cls") - else: - os.system("clear") - for element in word_list: - print(element) + # if os.name == "nt": + # os.system("cls") + # else: + # os.system("clear") + # for element in word_list: + # print(element) + + self.__messages__.append(word_list) def on_error(self, ws, error): print(error) @@ -222,6 +238,8 @@ def on_open(self, ws): print(self.multilingual, self.language, self.task) print("[INFO]: Opened connection") + + self.__messages__.clear() ws.send( json.dumps( { @@ -515,16 +533,18 @@ class TranscriptionClient: transcription_client() ``` """ - def __init__(self, + def __init__( + self, host, port, is_multilingual=False, lang=None, translate=False, model_size="small", - use_custom_model=False + use_custom_model=False, + callback: Callable[[list[str]], any]=None ): - self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model) + self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model, callback) def __call__(self, audio=None, hls_url=None): """ @@ -551,4 +571,7 @@ def __call__(self, audio=None, hls_url=None): resampled_file = resample(audio) self.client.play_file(resampled_file) else: - self.client.record() \ No newline at end of file + self.client.record() + + def transcribed_messages(self) -> list[list[str]]: + return self.client.total_messages() \ No newline at end of file diff --git a/whisper_live/server.py b/whisper_live/server.py index ddeaee3e..0106839e 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -86,6 +86,8 @@ def recv_audio(self, websocket, custom_model_path=None): options = websocket.recv() options = json.loads(options) + logging.info(f"with options {options}") + if len(self.clients) >= self.max_clients: logging.warning("Client Queue Full. Asking client to wait ...") wait_time = self.get_wait_time() @@ -139,7 +141,7 @@ def recv_audio(self, websocket, custom_model_path=None): except Exception as e: logging.info(f"[ERROR]: Client with uid '{self.clients[websocket].client_uid}' Disconnected.") - if self.clients[websocket].model_size is not None: + if self.clients[websocket].model_size_or_path is not None: self.clients[websocket].cleanup() self.clients.pop(websocket) self.clients_start_time.pop(websocket) From f376375239806c989d1b7a2971cf3f6e3b78c81d Mon Sep 17 00:00:00 2001 From: "Chanvut.B" Date: Tue, 16 Jan 2024 14:20:11 +0700 Subject: [PATCH 2/3] Added a code to filter out duplicated transcript data --- whisper_live/client.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index 6bfb9977..11e6bca5 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -51,9 +51,9 @@ class Client: """ INSTANCES = {} - __messages__ : list[list[str]] = [] + __messages__ : list[frozenset[str]] = [] - def total_messages(self) -> list[list[str]]: + def total_messages(self) -> list[frozenset[str]]: return self.__messages__ def __init__( @@ -65,7 +65,7 @@ def __init__( translate=False, model_size="small", use_custom_model=False, - callback: Callable[[list[str]], any]=None + callback: Callable[[frozenset[str]], any]=None ): """ Initializes a Client instance for audio recording and streaming to a server. @@ -205,18 +205,15 @@ def on_message(self, ws, message): text = text[-3:] wrapper = textwrap.TextWrapper(width=60) word_list = wrapper.wrap(text="".join(text)) + immutable_word_list = frozenset(word_list) - self.callback(word_list) - - # Print each line. - # if os.name == "nt": - # os.system("cls") - # else: - # os.system("clear") - # for element in word_list: - # print(element) + latest_word_list_hash = hash(self.__messages__[-1]) if len(self.__messages__) > 0 else None + current_word_list_hash = hash(immutable_word_list) + + if latest_word_list_hash != current_word_list_hash: + self.callback(immutable_word_list) - self.__messages__.append(word_list) + self.__messages__.append(immutable_word_list) def on_error(self, ws, error): print(error) @@ -316,7 +313,7 @@ def play_file(self, filename): audio_array = self.bytes_to_float_array(data) self.send_packet_to_server(audio_array.tobytes()) - self.stream.write(data) + # self.stream.write(data) wavfile.close() @@ -542,7 +539,7 @@ def __init__( translate=False, model_size="small", use_custom_model=False, - callback: Callable[[list[str]], any]=None + callback: Callable[[frozenset[str]], any]=None ): self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model, callback) From 4a83f46adab8fcdbdf01f16c29fc1ad2a35fce84 Mon Sep 17 00:00:00 2001 From: "Chanvut.B" Date: Tue, 16 Jan 2024 14:32:21 +0700 Subject: [PATCH 3/3] Added an option to replay playback if needed --- whisper_live/client.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index 11e6bca5..982f648a 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -65,7 +65,8 @@ def __init__( translate=False, model_size="small", use_custom_model=False, - callback: Callable[[frozenset[str]], any]=None + callback: Callable[[frozenset[str]], any]=None, + replay_playback: bool=False, ): """ Initializes a Client instance for audio recording and streaming to a server. @@ -100,6 +101,7 @@ def __init__( self.server_error = False self.use_custom_model = use_custom_model self.callback = callback + self.replay_playback = replay_playback if translate: self.task = "translate" @@ -313,7 +315,9 @@ def play_file(self, filename): audio_array = self.bytes_to_float_array(data) self.send_packet_to_server(audio_array.tobytes()) - # self.stream.write(data) + + if self.replay_playback: + self.stream.write(data) wavfile.close() @@ -539,9 +543,10 @@ def __init__( translate=False, model_size="small", use_custom_model=False, - callback: Callable[[frozenset[str]], any]=None + callback: Callable[[frozenset[str]], any]=None, + replay_playback: bool=False, ): - self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model, callback) + self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model, callback, replay_playback) def __call__(self, audio=None, hls_url=None): """ @@ -570,5 +575,5 @@ def __call__(self, audio=None, hls_url=None): else: self.client.record() - def transcribed_messages(self) -> list[list[str]]: + def transcribed_messages(self) -> list[frozenset[str]]: return self.client.total_messages() \ No newline at end of file