diff --git a/tests/test_server.py b/tests/test_server.py index f836be70..284572a9 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,7 +8,7 @@ import evaluate from websockets.exceptions import ConnectionClosed -from whisper_live.server import TranscriptionServer +from whisper_live.server import TranscriptionServer, BackendType from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient from whisper.normalizers import EnglishTextNormalizer @@ -49,7 +49,7 @@ def test_connection(self, mock_websocket): 'task': 'transcribe', 'model': 'tiny.en' }) - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType.FASTER_WHISPER) @mock.patch('websockets.WebSocketCommonProtocol') def test_recv_audio_exception_handling(self, mock_websocket): @@ -61,7 +61,7 @@ def test_recv_audio_exception_handling(self, mock_websocket): }), np.array([1, 2, 3]).tobytes()] with self.assertLogs(level="ERROR"): - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType.FASTER_WHISPER) self.assertNotIn(mock_websocket, self.server.client_manager.clients) @@ -127,7 +127,7 @@ def test_connection_closed_exception(self, mock_websocket): mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed") with self.assertLogs(level="INFO") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType.FASTER_WHISPER) self.assertTrue(any("Connection closed by client" in message for message in log.output)) @mock.patch('websockets.WebSocketCommonProtocol') @@ -135,7 +135,7 @@ def test_json_decode_exception(self, mock_websocket): mock_websocket.recv.return_value = "invalid json" with self.assertLogs(level="ERROR") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType.FASTER_WHISPER) self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output)) @mock.patch('websockets.WebSocketCommonProtocol') @@ -143,7 +143,7 @@ def test_unexpected_exception_handling(self, mock_websocket): mock_websocket.recv.side_effect = RuntimeError("Unexpected error") with self.assertLogs(level="ERROR") as log: - self.server.recv_audio(mock_websocket, "faster_whisper") + self.server.recv_audio(mock_websocket, BackendType.FASTER_WHISPER) for message in log.output: print(message) print() diff --git a/whisper_live/client.py b/whisper_live/client.py index a4bd0f67..be0ef470 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -28,7 +28,8 @@ def __init__( translate=False, model="small", srt_file_path="output.srt", - use_vad=True + use_vad=True, + secure_connection=False ): """ Initializes a Client instance for audio recording and streaming to a server. @@ -63,8 +64,10 @@ def __init__( self.timestamp_offset = 0.0 self.audio_bytes = None - if host is not None and port is not None: - socket_url = f"ws://{host}:{port}" + if host is not None: + protocol = "wss" if secure_connection else "ws" + port = f":{port}" if port else "" + socket_url = f"{protocol}{host}{port}" self.client_socket = websocket.WebSocketApp( socket_url, on_open=lambda ws: self.on_open(ws), @@ -75,7 +78,7 @@ def __init__( ), ) else: - print("[ERROR]: No host or port specified.") + print("[ERROR]: No host specified.") return Client.INSTANCES[self.uid] = self