diff --git a/Cargo.lock b/Cargo.lock index 399b7e5..dc58d15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -692,6 +692,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-complex" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -703,6 +712,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -807,6 +825,15 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "primal-check" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9df7f93fd637f083201473dab4fee2db4c429d32e55e3299980ab3957ab916a0" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro-crate" version = "1.3.1" @@ -841,6 +868,15 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "realfft" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953d9f7e5cdd80963547b456251296efc2626ed4e3cbf36c869d9564e0220571" +dependencies = [ + "rustfft", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -879,6 +915,18 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +[[package]] +name = "rubato" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5d18b486e7d29a408ef3f825bc1327d8f87af091c987ca2f5b734625940e234" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -894,6 +942,21 @@ dependencies = [ "semver", ] +[[package]] +name = "rustfft" +version = "6.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43806561bc506d0c5d160643ad742e3161049ac01027b5e6d7524091fd401d86" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", + "version_check", +] + [[package]] name = "rustix" version = "0.38.32" @@ -1050,6 +1113,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "213701ba3370744dcd1a12960caa4843b3d68b4d1c0a5d575e0d65b2ee9d16c0" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strsim" version = "0.11.1" @@ -1159,6 +1228,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "unicode-ident" version = "1.0.12" @@ -1171,6 +1250,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "voice" version = "0.1.0" @@ -1184,6 +1269,7 @@ dependencies = [ "itertools", "log", "regex", + "rubato", "serde", "serde_json", "sttx", diff --git a/Cargo.toml b/Cargo.toml index 428ff88..26fbb1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ log = "0.4.21" crossbeam = "0.8.4" serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.115" +rubato = "0.15.0" diff --git a/run.sh b/run.sh index 28c3757..6020ea7 100755 --- a/run.sh +++ b/run.sh @@ -25,7 +25,7 @@ model_path="$2" # The Unix socket the program will read and write to. socket_path="$3" -./voice \ +./voice run-daemon \ --socket-path "$socket_path" \ --model "$model_path" \ --device-name "$device_name" diff --git a/src/app/mod.rs b/src/app/mod.rs index 9a5a2e3..9837715 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -12,8 +12,8 @@ use sttx::{IteratorExt, Timing}; use crate::app::input::iter::alpha_only; use crate::audio::input::{controlled_recording, Recording}; -use crate::sync; -use crate::{socket::receive_instructions, whisper, App}; +use crate::whisper::TranscriptionJob; +use crate::{socket::receive_instructions, sync, whisper, DaemonInit}; use self::command::{CmdStream, Command}; use self::input::iter; @@ -21,7 +21,7 @@ use self::response::Response; use self::state::{Chat, Mode}; pub struct Daemon { - app: App, + config: DaemonInit, input_device: Option, state: state::State, } @@ -139,9 +139,9 @@ fn handle_hey_robot(content: &str) -> Option { } impl Daemon { - pub fn new(app: App, input_device: Option) -> Self { + pub fn new(config: DaemonInit, input_device: Option) -> Self { Self { - app, + config, input_device, state: state::State::default(), } @@ -150,23 +150,22 @@ impl Daemon { /// Runs the main application loop. /// pub fn run_loop(&mut self) -> Result { - let model = self.app.model.clone(); + let model = self.config.model.clone(); let device = self .input_device .as_ref() .ok_or_else(|| anyhow!("No input device"))?; let (to_whisper, from_recordings) = unbounded(); - let (whisper_output, tx_worker) = - whisper::transcription_worker(&model, self.app.strategy(), from_recordings)?; + let (whisper_output, tx_worker) = whisper::transcription_worker(&model, from_recordings)?; - let ((rcmds, scmds), resps, listener) = receive_instructions(&self.app.socket_path)?; + let ((rcmds, scmds), resps, listener) = receive_instructions(&self.config.socket_path)?; let mut commands = CmdStream::new(rcmds); #[allow(unused_assignments)] let mut exit_code = 0_u8; - let mut rec: Option>> = None; + let mut rec: Option>> = None; for result in commands.run_state_machine(&mut self.state) { let (ref command, ref new_state) = match result { @@ -204,8 +203,12 @@ impl Daemon { assert!(rec.is_some()); assert!(!new_state.running()); - let audio = rec.take().unwrap().stop()?; - to_whisper.send(audio)?; + let (metadata, audio) = rec.take().unwrap().stop()?; + to_whisper.send(TranscriptionJob::new( + audio, + self.config.strategy(), + metadata.sample_rate.0 as i32, + ))?; let now = std::time::Instant::now(); let transcription = whisper_output @@ -276,7 +279,7 @@ impl Daemon { listener.join().unwrap()?; // remove socket - std::fs::remove_file(&self.app.socket_path)?; + std::fs::remove_file(&self.config.socket_path)?; Ok(false) } } diff --git a/src/audio/input.rs b/src/audio/input.rs index 6e8e0ce..5eef05a 100644 --- a/src/audio/input.rs +++ b/src/audio/input.rs @@ -1,16 +1,17 @@ use std::{ fmt::{Debug, Display}, ops::Deref, - sync::{mpsc, Arc, Condvar, Mutex}, + sync::{Arc, Condvar, Mutex}, thread, }; use anyhow::anyhow; use cpal::{ traits::{DeviceTrait, HostTrait, StreamTrait}, - Sample, + StreamConfig, }; use crossbeam::channel::Sender; +use rubato::{FftFixedOut, Resampler, ResamplerConstructionError}; use whisper_rs::convert_stereo_to_mono_audio; #[derive(Debug, Clone)] @@ -85,8 +86,8 @@ impl Controller { } } -pub trait MySample: Send + hound::Sample + Sample + 'static {} -impl MySample for S where S: Send + hound::Sample + Sample + 'static {} +pub trait MySample: Send + hound::Sample + cpal::Sample + 'static {} +impl MySample for S where S: Send + hound::Sample + cpal::Sample + 'static {} pub struct Recording where @@ -94,7 +95,7 @@ where RS: Send, RE: Send, { - handle: thread::JoinHandle>, + handle: thread::JoinHandle>, controller: Controller, phantom: std::marker::PhantomData, receiving_handle: thread::JoinHandle, @@ -107,6 +108,7 @@ where { #[error("failed to join recording thread")] Sync, + #[error("{0}")] Other(E), } @@ -114,8 +116,8 @@ where impl Recording where S: MySample, - RS: Send, - RE: Send + Sync + Display, + RS: Send + Into>, + RE: Send + Sync + Display + Debug, { #[tracing::instrument(skip(self))] pub fn start(&self) { @@ -124,21 +126,28 @@ where log::info!("Recording started"); } - pub fn stop(self) -> Result> { + pub fn stop(self) -> Result<(StreamConfig, Vec), RecordingError> { self.controller.stop(); - let _ = Self::join_handle(self.handle)?; + let metadata = Self::join_handle(self.handle) + .map_err(|e| { + log::error!("Error joining recording thread: {}", e); + RecordingError::::Sync + }) + .unwrap() + .map_err(|e| { + log::error!("Error joining recording thread: {}", e); + RecordingError::Other(e) + })?; - Self::join_handle(self.receiving_handle) + let audio = Self::join_handle(self.receiving_handle)?.into(); + Ok((metadata, audio)) } - fn join_handle(handle: thread::JoinHandle) -> Result> - where - E: Display + Send + Sync + 'static, - { + fn join_handle(handle: thread::JoinHandle) -> Result> { match handle.join() { Ok(inner) => Ok(inner), Err(e) => { - let inner: Box = e.downcast::().map_err(|_| RecordingError::Sync)?; + let inner: Box = e.downcast::().map_err(|_| RecordingError::Sync)?; Err(RecordingError::Other(*inner)) } } @@ -183,9 +192,9 @@ pub fn record_from_input_device( device_name: String, chan: Sender, controller: Controller, -) -> Result<(), anyhow::Error> +) -> Result where - S: Send + hound::Sample + Sample + 'static, + S: MySample, { let device = host .input_devices()? @@ -203,53 +212,154 @@ where ); c }) - .find_map(|c| { - (c.max_sample_rate().0 >= 16000 && c.min_sample_rate().0 <= 16000) - .then(|| c.with_sample_rate(cpal::SampleRate(16000))) + .map(|c| { + let min_sample_rate = c.min_sample_rate(); + c.with_sample_rate(min_sample_rate) }) + .next() .ok_or(anyhow!("no supported input configuration"))?; controller.wait_for(RecordState::Started); + let cpal::SampleFormat::F32 = supported_config.sample_format() else { + panic!("unsupported sample format"); + }; + + let cfg: cpal::StreamConfig = supported_config.clone().into(); + let cfg_inner = cfg.clone(); + let resampler = Processor::::new(chan.clone(), cfg_inner.clone(), 512) + .expect("failed to create resampler"); + + let resampler = Arc::new(Mutex::new(resampler)); { - let stream = match supported_config.sample_format() { - cpal::SampleFormat::F32 => { - let cfg = supported_config.clone().into(); - let is_mono = supported_config.channels() == 1; - device.build_input_stream( - &cfg, - move |data, _| { - if !is_mono { - let mono_data = convert_stereo_to_mono_audio(data) - .expect("failed to convert stereo to mono"); - write_input_data::(&mono_data, chan.clone()) - } else { - write_input_data::(data, chan.clone()) - } - .expect("failed to write data") - }, - move |err| log::trace!("an error occurred on stream: {}", err), - )? - } - _ => panic!("unsupported sample format"), - }; + let resampler_send = resampler.clone(); + let stream = device.build_input_stream( + &cfg, + move |data, _| { + resampler_send + .lock() + .expect("failed to lock resampler") + .write_input_data(data) + .expect("failed to write data"); + }, + move |err| log::trace!("an error occurred on stream: {}", err), + )?; stream.play()?; controller.recording(); controller.wait_for(RecordState::Stopped); } - Ok(()) + let mut resampler = resampler.lock().expect("failed to lock resampler"); + resampler.flush_to_sink()?; + Ok(cfg) +} + +struct Processor +where + S: cpal::Sample + rubato::Sample, +{ + config: cpal::StreamConfig, + resampler: FftFixedOut, + buffer: Vec, + sink: Sender, } -fn write_input_data(input: &[T], chan: Sender) -> Result<(), mpsc::SendError> +impl Processor where - T: Sample, - U: Sample + hound::Sample, + S: cpal::Sample + rubato::Sample, + O: MySample, { - for &sample in input.iter() { - let sample: U = U::from(&sample); - let _ = chan.send(sample); + fn new( + sink: Sender, + config: cpal::StreamConfig, + chunk_size_out: usize, + ) -> Result { + let input_rate = config.sample_rate.0 as usize; + let channels = config.channels as usize; + log::debug!( + "Creating resampler with input rate: {}, chunk size: {}, channels: {}", + input_rate, + chunk_size_out, + channels + ); + + // Set to next multiple of 512 + let chunk_size_out = (chunk_size_out + 511) & !511; + let resampler = FftFixedOut::new(input_rate, 16_000, chunk_size_out, 1, channels)?; + + Ok(Self { + buffer: Vec::with_capacity(resampler.nbr_channels() * resampler.input_frames_next()), + config, + resampler, + sink, + }) + } +} + +impl Processor { + fn write_input_data(&mut self, input: &[f32]) -> Result<(), anyhow::Error> { + log::trace!( + "Got input data: {}, remaining in buffer: {}", + input.len(), + self.buffer.capacity() - self.buffer.len() + ); + let remaining = self.consume_input_data(input); + if self.buffer.len() == self.buffer.capacity() { + self.flush_to_sink()?; + self.consume_input_data(remaining); + log::trace!( + "Buffer full to its capacity of {}, remaining: {}", + self.buffer.capacity(), + remaining.len() + ); + assert!(self.buffer.len() <= self.buffer.capacity()); + Ok(()) + } else { + Ok(()) + } + } + + fn consume_input_data<'a>(&mut self, input: &'a [f32]) -> &'a [f32] { + let remaining = self.buffer.capacity() - self.buffer.len(); + let cutoff = remaining.min(input.len()); + self.buffer.extend_from_slice(&input[0..cutoff]); + + &input[cutoff..] + } + + fn input_buffer(&self) -> Vec { + let cap = self.resampler.nbr_channels() * self.resampler.input_frames_max(); + log::trace!("Allocating input buffer with capacity: {}", cap); + Vec::with_capacity(cap) } - Ok(()) + fn flush_to_sink(&mut self) -> Result<(), anyhow::Error> { + if self.buffer.len() < self.buffer.capacity() { + log::trace!("Buffer not full, returning"); + return Ok(()); + } + let mut data = self.input_buffer(); + std::mem::swap(&mut self.buffer, &mut data); + + log::trace!("Data length: {}", data.len()); + + if self.config.sample_rate != cpal::SampleRate(16_000) { + let mut output = self.resampler.output_buffer_allocate(true); + self.resampler + .process_into_buffer(&[data], &mut output, None)?; + data = output.first().cloned().expect("no output from resampler"); + log::trace!("Data length after resampling: {}", data.len()); + }; + + if self.config.channels as usize != 1 { + data = convert_stereo_to_mono_audio(&data).expect("failed to convert stereo to mono"); + log::trace!("Data length after stereo to mono: {}", data.len()); + }; + + for sample in data { + let sample: O = O::from(&sample); + let _ = self.sink.send(sample); + } + Ok(()) + } } diff --git a/src/main.rs b/src/main.rs index 1d19b2e..b398fc1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,20 @@ use whisper_rs::install_whisper_log_trampoline; use crate::app::Daemon; #[derive(Debug, clap::Parser)] +#[command(version, about, long_about = None)] struct App { + #[command(subcommand)] + command: Commands, +} + +#[derive(Debug, clap::Subcommand)] +enum Commands { + ListChannels, + RunDaemon(DaemonInit), +} + +#[derive(Debug, clap::Args)] +struct DaemonInit { /// Path to the model file #[clap(short, long)] model: String, @@ -30,7 +43,7 @@ struct App { socket_path: String, } -impl App { +impl DaemonInit { pub fn strategy(&self) -> whisper_rs::SamplingStrategy { self.strategy.clone().unwrap_or_default().into() } @@ -52,23 +65,50 @@ fn main() -> Result<(), anyhow::Error> { env_logger::init(); - let app = App::parse(); - log::info!("Launching with settings: {:?}", app); - - let device = match &app.device_name { - Some(n) => device_matching_name(n)?, - None => cpal::default_host() - .default_input_device() - .ok_or(anyhow!("no input device available"))?, - }; - - log::info!("Found device: {:?}", device.name()?); - - let mut daemon = Daemon::new(app, Some(device)); - let should_reset = daemon.run_loop()?; - if should_reset { - std::process::exit(1); + match App::parse().command { + Commands::ListChannels => { + let host = cpal::default_host(); + for (device_name, config) in host.input_devices()?.flat_map(|d| { + let name = d.name().expect("failed to get device name"); + d.supported_input_configs() + .unwrap() + .map(move |c| (name.clone(), c)) + }) { + let (buf_floor, buf_ceil) = match config.buffer_size() { + cpal::SupportedBufferSize::Range { min, max } => (min, max), + cpal::SupportedBufferSize::Unknown => unimplemented!(), + }; + println!( + "{}: sample_rate:{}-{}, sample_format:{:?}, channels:{}, buffer_size: {}-{}", + device_name, + config.min_sample_rate().0, + config.max_sample_rate().0, + config.sample_format(), + config.channels(), + buf_floor, + buf_ceil + ); + } + return Ok(()); + } + Commands::RunDaemon(app) => { + log::info!("Launching with settings: {:?}", app); + + let device = match &app.device_name { + Some(n) => device_matching_name(n)?, + None => cpal::default_host() + .default_input_device() + .ok_or(anyhow!("no input device available"))?, + }; + + log::info!("Found device: {:?}", device.name()?); + + let mut daemon = Daemon::new(app, Some(device)); + let should_reset = daemon.run_loop()?; + if should_reset { + std::process::exit(1); + } + } } - Ok(()) } diff --git a/src/whisper.rs b/src/whisper.rs index 47c4aa9..133022b 100644 --- a/src/whisper.rs +++ b/src/whisper.rs @@ -1,53 +1,39 @@ use anyhow::anyhow; -use std::thread::JoinHandle; +use std::{thread::JoinHandle, time::Duration}; use crossbeam::channel::{unbounded, Receiver, SendError}; use itertools::Itertools; use sttx::Timing; -use whisper_rs::{convert_integer_to_float_audio, FullParams, WhisperContext, WhisperError}; +use whisper_rs::{FullParams, WhisperContext, WhisperError}; pub struct Whisper { context: WhisperContext, - strategy: whisper_rs::SamplingStrategy, } impl Whisper { - pub fn new( - model_path: &str, - strategy: whisper_rs::SamplingStrategy, - ) -> Result { + pub fn new(model_path: &str) -> Result { let mut params = whisper_rs::WhisperContextParameters::default(); params.use_gpu(true); let context = whisper_rs::WhisperContext::new_with_params(model_path, params)?; - Ok(Self { context, strategy }) + Ok(Self { context }) } pub fn create_state(&self) -> Result { self.context.create_state() } - pub fn transcribe_audio(&self, data: T) -> Result, WhisperError> - where - T: AsRef<[f32]>, - { + pub fn transcribe_audio( + &self, + job: TranscriptionJob, + ) -> Result, WhisperError> { let mut state = self.create_state()?; - let mut params = FullParams::new(self.strategy.clone()); - // params.set_audio_ctx({ - // let blen = data.as_ref().len(); - // let audio_secs = blen as f32 / 16000.0; - // log::debug!("audio_secs: {}", audio_secs); - // if audio_secs > 30.0 { - // 1500 - // } else { - // ((audio_secs / 30.0 * 1500.0) as i32).max(128) - // } - // }); + let mut params = FullParams::new(job.strategy); params.set_token_timestamps(true); params.set_max_len(1); params.set_split_on_word(true); - match state.full(params, data.as_ref()) { + match state.full(params, job.audio.as_slice()) { Ok(0) => {} Ok(n) => return Err(WhisperError::GenericError(n)), Err(e) => return Err(e), @@ -107,22 +93,41 @@ type WorkerHandle = ( JoinHandle>, ); +#[derive(Debug)] +pub struct TranscriptionJob { + audio: Vec, + strategy: whisper_rs::SamplingStrategy, + sample_rate: i32, +} + +impl TranscriptionJob { + pub fn new(audio: Vec, strategy: whisper_rs::SamplingStrategy, sample_rate: i32) -> Self { + Self { + audio, + strategy, + sample_rate, + } + } + + pub fn duration(&self) -> Duration { + Duration::from_secs_f32(self.audio.len() as f32 / self.sample_rate as f32) + } +} + pub fn transcription_worker( model: &str, - strategy: whisper_rs::SamplingStrategy, - jobs: Receiver>, + jobs: Receiver, ) -> Result { let (snd, recv) = unbounded(); - let whisper = Whisper::new(model, strategy)?; + let whisper = Whisper::new(model)?; Ok(( recv, std::thread::spawn(move || { - for audio in jobs.iter() { - let mut audio_fl = vec![0_f32; audio.len()]; - convert_integer_to_float_audio(&audio, &mut audio_fl)?; + for job in jobs.iter() { + log::debug!("Transcribing audio with duration: {:?}", job.duration()); let results = whisper - .transcribe_audio(audio_fl) + .transcribe_audio(job) .map_err(TranscriptionError::from); snd.send(results).map_err(Box::new)?; } @@ -139,7 +144,10 @@ pub enum StrategyOpt { impl Default for StrategyOpt { fn default() -> Self { - StrategyOpt::Greedy { best_of: 1 } + StrategyOpt::Beam { + beam_size: 5, + patience: 0.0, + } } }