-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
134 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox, QHBoxLayout, QGroupBox | ||
from PySide6.QtCore import Qt | ||
from ct2_logic import VoiceRecorder | ||
|
||
class MyWindow(QWidget): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
layout = QVBoxLayout(self) | ||
|
||
self.status_label = QLabel('', self) | ||
layout.addWidget(self.status_label) | ||
|
||
self.recorder = VoiceRecorder(self) | ||
|
||
for text, callback in [("Record", self.recorder.start_recording), | ||
("Stop and Copy to Clipboard", self.recorder.save_audio)]: | ||
button = QPushButton(text, self) | ||
button.clicked.connect(callback) | ||
layout.addWidget(button) | ||
|
||
settings_group = QGroupBox("Settings") | ||
settings_layout = QVBoxLayout() | ||
|
||
# Model, Quantization, and Device dropdowns | ||
h_layout = QHBoxLayout() | ||
|
||
model_label = QLabel('Model') | ||
h_layout.addWidget(model_label) | ||
self.model_dropdown = QComboBox(self) | ||
self.model_dropdown.addItems(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"]) | ||
h_layout.addWidget(self.model_dropdown) | ||
self.model_dropdown.setCurrentText("base.en") | ||
|
||
quantization_label = QLabel('Quantization') | ||
h_layout.addWidget(quantization_label) | ||
self.quantization_dropdown = QComboBox(self) | ||
self.quantization_dropdown.addItems(["int8", "int8_float32", "int8_float16", "int8_bfloat16", "float16", "bfloat16", "float32"]) | ||
h_layout.addWidget(self.quantization_dropdown) | ||
self.quantization_dropdown.setCurrentText("float32") | ||
|
||
device_label = QLabel('Device') | ||
h_layout.addWidget(device_label) | ||
self.device_dropdown = QComboBox(self) | ||
self.device_dropdown.addItems(["auto", "cpu", "cuda"]) | ||
h_layout.addWidget(self.device_dropdown) | ||
self.device_dropdown.setCurrentText("auto") | ||
|
||
settings_layout.addLayout(h_layout) | ||
|
||
# Update Model button | ||
update_model_btn = QPushButton("Update Settings", self) | ||
update_model_btn.clicked.connect(self.update_model) | ||
settings_layout.addWidget(update_model_btn) | ||
|
||
settings_group.setLayout(settings_layout) | ||
layout.addWidget(settings_group) | ||
|
||
self.setFixedSize(425, 250) | ||
self.setWindowFlag(Qt.WindowStaysOnTopHint) | ||
|
||
def update_model(self): | ||
self.recorder.update_model(self.model_dropdown.currentText(), self.quantization_dropdown.currentText(), self.device_dropdown.currentText()) | ||
|
||
def update_status(self, text): | ||
self.status_label.setText(text) | ||
|
||
if __name__ == "__main__": | ||
app = QApplication([]) | ||
window = MyWindow() | ||
window.show() | ||
app.exec() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import pyaudio | ||
import wave | ||
import os | ||
import tempfile | ||
import threading | ||
import pyperclip | ||
from faster_whisper import WhisperModel | ||
|
||
class VoiceRecorder: | ||
def __init__(self, window, format=pyaudio.paInt16, channels=1, rate=44100, chunk=1024): | ||
self.format, self.channels, self.rate, self.chunk = format, channels, rate, chunk | ||
self.window = window | ||
self.is_recording, self.frames = False, [] | ||
self.update_model("base.en", "float32", "auto") | ||
|
||
def update_model(self, model_name, quantization_type, device_type): | ||
model_str = f"ctranslate2-4you/whisper-{model_name}-ct2-{quantization_type}" | ||
self.model = WhisperModel(model_str, device=device_type, compute_type=quantization_type, cpu_threads=10) | ||
self.window.update_status(f"Model updated to {model_name} with {quantization_type} quantization on {device_type} device") | ||
|
||
def transcribe_audio(self, audio_file): | ||
segments, _ = self.model.transcribe(audio_file) | ||
pyperclip.copy("\n".join([segment.text for segment in segments])) | ||
self.window.update_status("Audio saved and transcribed") | ||
|
||
def record_audio(self): | ||
self.window.update_status("Recording...") | ||
p = pyaudio.PyAudio() | ||
try: | ||
stream = p.open(format=self.format, channels=self.channels, rate=self.rate, input=True, frames_per_buffer=self.chunk) | ||
[self.frames.append(stream.read(self.chunk)) for _ in iter(lambda: self.is_recording, False)] | ||
stream.stop_stream() | ||
stream.close() | ||
finally: | ||
p.terminate() | ||
|
||
def save_audio(self): | ||
self.is_recording = False | ||
temp_filename = tempfile.mktemp(suffix=".wav") | ||
with wave.open(temp_filename, "wb") as wf: | ||
wf.setnchannels(self.channels) | ||
wf.setsampwidth(pyaudio.PyAudio().get_sample_size(self.format)) | ||
wf.setframerate(self.rate) | ||
wf.writeframes(b"".join(self.frames)) | ||
self.transcribe_audio(temp_filename) | ||
os.remove(temp_filename) | ||
self.frames.clear() | ||
|
||
def start_recording(self): | ||
if not self.is_recording: | ||
self.is_recording = True | ||
threading.Thread(target=self.record_audio).start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import sys | ||
from PySide6.QtWidgets import QApplication | ||
from ct2_gui import MyWindow | ||
|
||
if __name__ == "__main__": | ||
app = QApplication(sys.argv) | ||
app.setStyle('Fusion') | ||
window = MyWindow() | ||
window.show() | ||
sys.exit(app.exec()) |