Skip to content

Commit

Permalink
refactor(docs,cli): update support code around the HTTP interface (#6)
Browse files Browse the repository at this point in the history
- fail if we don't have exactly one of the socket path/serve options
  • Loading branch information
dev-msp committed Apr 19, 2024
1 parent 49b9ff2 commit edb59d4
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 45 deletions.
29 changes: 15 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,29 @@

Uses [`whisper.cpp`](https://github.com/ggerganov/whisper.cpp) by way of [`whisper-rs`](https://github.com/tazz4843/whisper-rs).

> ⚠️ DISCLAIMER - this probably does not work with MBP mics. Whisper expects 16k
> sample rate, but Macbook mics don't have that option and I'm not handling
> downsampling yet. Use a mic that directly supports 16k in the meantime.
Bind to hotkeys as you prefer. I'm running the server as a daemon with `launchd`.
Bind requests to hotkeys as you prefer. I'm running the server as a daemon with `launchd`.

If you don't already have a `whisper.cpp`-compatible model, follow that project's [quick-start instructions](https://github.com/ggerganov/whisper.cpp#quick-start) to get one.

Start the server:
`./run.sh macbook ggml-base.en.bin localhost:8088`

Start recording:

`$ echo -n "{\"type\": \"start\"}" | socat -t2 - /tmp/whisper.sock`
Start recording: `curl -X POST -H "Content-Type: application/json" -d "$body" "http://127.0.0.1:8088/voice/$1"`

Output: `{ "type": "ack" }`

Stop recording:
Example request body:
```json
{
// partial name matches OK
"input_device": "MacBook Pro Microphone",

`echo -n "{\"type\": \"stop\"}" | socat -t2 - /tmp/whisper.sock`
// optional
"sample_rate": 44100
}
```
Response: `{ "type": "ack" }`

Output:
Stop recording: `curl -X POST http://localhost:8088/voice/stop`
Example response:
```json
{
"data": {
Expand All @@ -36,4 +37,4 @@ Output:
}
```

Note: the modes that you get back in the output are just metadata. Your client application that handles reading from the socket should also handle processing the transcription differently based on the mode.
Note: the modes that you get back in the output are just metadata. Your client application that talks to the server should also handle processing the transcription differently based on the mode.
13 changes: 4 additions & 9 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,13 @@ fi
# So you'll still see those in the log output, even if you're suppressing.
export RUST_LOG=whisper_sys_log=error,voice=debug

# Can be a substring of the input device name you want to use. For instance,
# "macbook" for "MacBook Pro Microphone".
device_name="$1"

# See the whisper.cpp repo for details on how to get a model. I recommend using
# base or small for best results.
model_path="$2"

# The Unix socket the program will read and write to.
socket_path="$3"
# The address on which the HTTP server should listen (e.g. localhost:PORT)
addr="$3"

./voice run-daemon \
--socket-path "$socket_path" \
--model "$model_path" \
--device-name "$device_name"
--serve "$addr"
--model "$model_path"
4 changes: 3 additions & 1 deletion src/app/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ impl Daemon {
tx_worker.join().unwrap()?;

// remove socket
std::fs::remove_file(&self.config.socket_path)?;
if let Some(ref p) = self.config.socket_path {
std::fs::remove_file(p)?;
}
Ok(false)
}
}
4 changes: 2 additions & 2 deletions src/app/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use super::command::Command;
#[derive(Debug, Default, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct RecordingSession {
input_device: String,
sample_rate: u32,
sample_rate: Option<u32>,
}

impl RecordingSession {
pub fn device_name(&self) -> &str {
&self.input_device
}

pub fn sample_rate(&self) -> u32 {
pub fn sample_rate(&self) -> Option<u32> {
self.sample_rate
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/audio/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ where
c
})
.find_map(|c| {
if c.min_sample_rate() > SampleRate(session.sample_rate())
|| c.max_sample_rate() < SampleRate(session.sample_rate())
{
let sample_rate = session
.sample_rate()
.map(SampleRate)
.unwrap_or_else(|| c.min_sample_rate().max(SampleRate(16000)));
if c.min_sample_rate() > sample_rate || c.max_sample_rate() < sample_rate {
None
} else {
let min_sample_rate = c.min_sample_rate();
Expand Down
37 changes: 23 additions & 14 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ struct DaemonInit {

/// Socket path
#[clap(long)]
socket_path: String,
socket_path: Option<String>,

#[clap(long)]
serve: bool,
#[clap(long, value_parser = web::parse_addr_option)]
serve: Option<(String, u16)>,
}

impl DaemonInit {
Expand All @@ -62,12 +62,12 @@ fn run_daemon(
daemon.run_loop(commands, responses)
}

async fn run_web_server(app: DaemonInit) -> std::io::Result<bool> {
async fn run_web_server(addr: (String, u16), app: DaemonInit) -> std::io::Result<bool> {
let (commands_out, commands_in) = crossbeam::channel::bounded(1);
let (responses_out, responses_in) = crossbeam::channel::bounded(1);
let handle = spawn_blocking(|| run_daemon(app, commands_in, responses_out));

let server = web::run(commands_out, responses_in);
let server = web::run(addr, commands_out, responses_in);

tokio::select! {
app_finished = handle => {
Expand Down Expand Up @@ -114,16 +114,25 @@ async fn main() -> Result<(), anyhow::Error> {
Ok(())
}
Commands::RunDaemon(app) => {
let _ = if app.serve {
run_web_server(app).await?
} else {
let (rcmds, resps, listener) =
socket::receive_instructions(app.socket_path.clone())?;
let outcome = run_daemon(app, rcmds, resps)?;
listener.join().expect("failed to join listener thread")?;
outcome
let should_reset = match (&app.socket_path, &app.serve) {
(None, Some(a)) => run_web_server(a.clone(), app).await?,
(Some(p), None) => {
let (rcmds, resps, listener) = socket::receive_instructions(p.clone())?;
let outcome = run_daemon(app, rcmds, resps)?;
listener.join().expect("failed to join listener thread")?;
outcome
}
_ => {
log::error!(
"Invalid arguments: socket path and serve flag cannot be used together but at least one must be provided"
);
true
}
};
std::process::exit(1);
if should_reset {
std::process::exit(1);
}
Ok(())
}
}
}
16 changes: 14 additions & 2 deletions src/web/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use actix_web::{
web::{self, Data},
App, HttpRequest, HttpResponse, HttpServer, Responder,
};
use anyhow::anyhow;
use crossbeam::channel::{Receiver, Sender};
use serde::Serialize;

Expand Down Expand Up @@ -67,7 +68,11 @@ async fn set_mode(app: AppChannel, mode: web::Json<Mode>) -> impl Responder {
ApiResponder { content: response }
}

pub async fn run(commands: Sender<Command>, responses: Receiver<Response>) -> std::io::Result<()> {
pub async fn run<A: std::net::ToSocketAddrs>(
addr: A,
commands: Sender<Command>,
responses: Receiver<Response>,
) -> std::io::Result<()> {
let server = HttpServer::new(move || {
let voice = web::scope("/voice")
.service(start)
Expand All @@ -77,9 +82,16 @@ pub async fn run(commands: Sender<Command>, responses: Receiver<Response>) -> st

App::new().wrap(Logger::default()).service(voice)
})
.bind(("localhost", 8088))?;
.bind(addr)?;

let handle = server.run().await;
log::warn!("Server finished?");
handle
}

/// Parses a `HOST:PORT` command-line value into a `(host, port)` pair.
///
/// The input is split at its *last* `:` so that a host which itself contains
/// colons (an IPv6 literal such as `::1`) is kept intact, and so that trailing
/// junk after the port is rejected instead of being silently ignored. Square
/// brackets around the host (`[::1]:8088`) are stripped.
///
/// # Errors
///
/// Returns an error when the value contains no `:` separator, when the host
/// part is empty, or when the port part is not a valid `u16`.
pub fn parse_addr_option(s: &str) -> Result<(String, u16), anyhow::Error> {
    let (host, port) = s
        .rsplit_once(':')
        .ok_or_else(|| anyhow!("expected HOST:PORT, got {s:?}"))?;

    // Accept the conventional bracketed form for IPv6 literals.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    if host.is_empty() {
        return Err(anyhow!("no host provided"));
    }

    let port = port
        .parse::<u16>()
        .map_err(|e| anyhow!("invalid port {port:?}: {e}"))?;
    Ok((host.to_string(), port))
}

0 comments on commit edb59d4

Please sign in to comment.