Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve the client by adding more options and fix a bug in server cleanup #101

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import uuid
import time

from typing import Callable


def resample(file: str, sr: int = 16000):
"""
Expand Down Expand Up @@ -49,6 +51,11 @@ class Client:
"""
INSTANCES = {}

__messages__ : list[frozenset[str]] = []

# Accessor for the stored messages: returns the list held in
# `self.__messages__` -- the same list object, not a copy -- so changes a
# caller makes to the returned list are visible through this attribute.
# Per the annotation, each entry is a frozenset of strings.
def total_messages(self) -> list[frozenset[str]]:
return self.__messages__

def __init__(
self,
host=None,
Expand All @@ -57,7 +64,9 @@ def __init__(
lang=None,
translate=False,
model_size="small",
use_custom_model=False
use_custom_model=False,
callback: Callable[[frozenset[str]], any]=None,
replay_playback: bool=False,
):
"""
Initializes a Client instance for audio recording and streaming to a server.
Expand Down Expand Up @@ -91,6 +100,8 @@ def __init__(
self.model_size = model_size
self.server_error = False
self.use_custom_model = use_custom_model
self.callback = callback
self.replay_playback = replay_playback

if translate:
self.task = "translate"
Expand Down Expand Up @@ -147,11 +158,13 @@ def on_message(self, ws, message):
self.last_response_recieved = time.time()
message = json.loads(message)

keys = message.keys()

if self.uid != message.get("uid"):
print("[ERROR]: invalid client uid")
return

if "status" in message.keys():
if "status" in keys:
if message["status"] == "WAIT":
self.waiting = True
print(
Expand All @@ -162,23 +175,23 @@ def on_message(self, ws, message):
self.server_error = True
return

if "message" in message.keys() and message["message"] == "DISCONNECT":
if "message" in keys and message["message"] == "DISCONNECT":
print("[INFO]: Server overtime disconnected.")
self.recording = False

if "message" in message.keys() and message["message"] == "SERVER_READY":
if "message" in keys and message["message"] == "SERVER_READY":
self.recording = True
return

if "language" in message.keys():
if "language" in keys:
self.language = message.get("language")
lang_prob = message.get("language_prob")
print(
f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
)
return

if "segments" not in message.keys():
if "segments" not in keys:
return

message = message["segments"]
Expand All @@ -194,13 +207,15 @@ def on_message(self, ws, message):
text = text[-3:]
wrapper = textwrap.TextWrapper(width=60)
word_list = wrapper.wrap(text="".join(text))
# Print each line.
if os.name == "nt":
os.system("cls")
else:
os.system("clear")
for element in word_list:
print(element)
immutable_word_list = frozenset(word_list)

latest_word_list_hash = hash(self.__messages__[-1]) if len(self.__messages__) > 0 else None
current_word_list_hash = hash(immutable_word_list)

if latest_word_list_hash != current_word_list_hash:
self.callback(immutable_word_list)

self.__messages__.append(immutable_word_list)

def on_error(self, ws, error):
print(error)
Expand All @@ -222,6 +237,8 @@ def on_open(self, ws):
print(self.multilingual, self.language, self.task)

print("[INFO]: Opened connection")

self.__messages__.clear()
ws.send(
json.dumps(
{
Expand Down Expand Up @@ -298,7 +315,9 @@ def play_file(self, filename):

audio_array = self.bytes_to_float_array(data)
self.send_packet_to_server(audio_array.tobytes())
self.stream.write(data)

if self.replay_playback:
self.stream.write(data)

wavfile.close()

Expand Down Expand Up @@ -515,16 +534,19 @@ class TranscriptionClient:
transcription_client()
```
"""
def __init__(self,
def __init__(
self,
host,
port,
is_multilingual=False,
lang=None,
translate=False,
model_size="small",
use_custom_model=False
use_custom_model=False,
callback: Callable[[frozenset[str]], any]=None,
replay_playback: bool=False,
):
self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model)
self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model, callback, replay_playback)

def __call__(self, audio=None, hls_url=None):
"""
Expand All @@ -551,4 +573,7 @@ def __call__(self, audio=None, hls_url=None):
resampled_file = resample(audio)
self.client.play_file(resampled_file)
else:
self.client.record()
self.client.record()

# Thin pass-through to the wrapped client: returns exactly what
# `self.client.total_messages()` returns (annotated as a list of
# frozensets of strings).
def transcribed_messages(self) -> list[frozenset[str]]:
return self.client.total_messages()
4 changes: 3 additions & 1 deletion whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def recv_audio(self, websocket, custom_model_path=None):
options = websocket.recv()
options = json.loads(options)

logging.info(f"with options {options}")

if len(self.clients) >= self.max_clients:
logging.warning("Client Queue Full. Asking client to wait ...")
wait_time = self.get_wait_time()
Expand Down Expand Up @@ -139,7 +141,7 @@ def recv_audio(self, websocket, custom_model_path=None):

except Exception as e:
logging.info(f"[ERROR]: Client with uid '{self.clients[websocket].client_uid}' Disconnected.")
if self.clients[websocket].model_size is not None:
if self.clients[websocket].model_size_or_path is not None:
self.clients[websocket].cleanup()
self.clients.pop(websocket)
self.clients_start_time.pop(websocket)
Expand Down
Loading