Skip to content

Commit

Permalink
Swap out socket for HTTP server (#5)
Browse files Browse the repository at this point in the history
* refactor(app): generalize the control interface to prepare for web

* add(deps): actix-web and tokio

* feat(web): add HTTP interface

* fix(web): respect sample rate sent in start request
  • Loading branch information
dev-msp committed Apr 18, 2024
1 parent f28d026 commit 49b9ff2
Show file tree
Hide file tree
Showing 10 changed files with 1,230 additions and 172 deletions.
918 changes: 910 additions & 8 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ crossbeam = "0.8.4"
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
rubato = "0.15.0"
actix-web = "4.5.1"
tokio = { version = "1.37.0", features = ["full"] }
17 changes: 8 additions & 9 deletions src/app/command.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use crossbeam::channel::Receiver;
use itertools::Itertools;

use super::{
response::Response,
state::{Mode, State},
state::{Mode, RecordingSession, State},
};

#[derive(Debug, Clone, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum Command {
#[serde(rename = "start")]
Start,
Start(RecordingSession),

#[serde(rename = "stop")]
Stop, // need timestamp?
Expand All @@ -35,22 +34,22 @@ impl Command {
}
}

pub struct CmdStream(Receiver<serde_json::Value>);
pub struct CmdStream(Receiver<Command>);

impl CmdStream {
pub fn new(recv: Receiver<serde_json::Value>) -> Self {
pub fn new(recv: Receiver<Command>) -> Self {
Self(recv)
}

pub fn iter(&mut self) -> impl Iterator<Item = Result<Command, serde_json::Error>> + '_ {
self.0.iter().map(serde_json::from_value)
pub fn iter(&mut self) -> impl Iterator<Item = Command> + '_ {
self.0.iter()
}

pub fn run_state_machine<'a>(
&'a mut self,
state: &'a mut super::state::State,
) -> impl Iterator<Item = Result<(Command, Option<State>), serde_json::Error>> + 'a {
self.iter().map_ok(move |cmd| {
) -> impl Iterator<Item = (Command, Option<State>)> + 'a {
self.iter().map(move |cmd| {
log::debug!("Received command: {:?}", cmd);
log::trace!("Current state: {:?}", state);
let initial = state.clone();
Expand Down
97 changes: 38 additions & 59 deletions src/app/mod.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
pub mod command;
mod response;
mod state;
pub mod response;
pub mod state;

use std::time::SystemTime;

use anyhow::anyhow;
use cpal::Device;
use crossbeam::channel::unbounded;
use crossbeam::channel::{unbounded, Receiver, Sender};
use sttx::IteratorExt;

use crate::audio::input::{controlled_recording, Recording};
use crate::whisper::TranscriptionJob;
use crate::{socket::receive_instructions, sync, whisper, DaemonInit};
use crate::{sync, whisper, DaemonInit};

use self::command::{CmdStream, Command};
use self::response::Response;

pub struct Daemon {
config: DaemonInit,
input_device: Option<Device>,
state: state::State,
}

Expand All @@ -38,63 +38,51 @@ impl Transcription {
}

impl Daemon {
pub fn new(config: DaemonInit, input_device: Option<Device>) -> Self {
pub fn new(config: DaemonInit) -> Self {
Self {
config,
input_device,
state: state::State::default(),
}
}

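/// Runs the main application loop.
///
/// Feeds incoming commands through the state machine and reports the outcome
/// of each one as a `Response`. Once the command stream ends, a final
/// `Response::Exit` carrying the exit code is sent and the transcription
/// worker is joined.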
pub fn run_loop(&mut self) -> Result<bool, anyhow::Error> {
pub fn run_loop(
&mut self,
commands: Receiver<Command>,
responses: Sender<Response>,
) -> Result<bool, anyhow::Error> {
let model = self.config.model.clone();
let device = self
.input_device
.as_ref()
.ok_or_else(|| anyhow!("No input device"))?;

let (to_whisper, from_recordings) = unbounded();
let (whisper_output, tx_worker) = whisper::transcription_worker(&model, from_recordings)?;

let (rcmds, resps, listener) = receive_instructions(&self.config.socket_path)?;
let mut commands = CmdStream::new(rcmds);
let mut commands = CmdStream::new(commands);

#[allow(unused_assignments)]
let mut exit_code = 0_u8;

let mut rec: Option<Recording<_, Vec<f32>>> = None;

for result in commands.run_state_machine(&mut self.state) {
let (ref command, ref new_state) = match result {
Ok((c, s)) => (c, s),
Err(e) => {
log::error!("{e}");
resps.send(Response::Error(e.to_string()).as_json())?;
continue;
}
};
for (ref command, ref new_state) in commands.run_state_machine(&mut self.state) {
let Some(new_state) = new_state else {
resps.send(Response::Nil.as_json())?;
responses.send(Response::Nil)?;
continue;
};

// This handles the state condition where rec must exist.
if new_state.running() && rec.is_none() {
rec = Some(controlled_recording(
device,
sync::ProcessNode::new(|it| it.collect::<Vec<_>>()),
));
}

match command {
Command::Start => {
assert!(rec.is_some());
Command::Start(session) => {
assert!(new_state.running());

rec = Some(controlled_recording(
session.clone(),
sync::ProcessNode::new(|it| it.collect::<Vec<_>>()),
));

rec.as_mut().unwrap().start();
resps.send(Response::Ack.as_json())?;
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis();
responses.send(Response::Ack(now))?;
log::debug!("Successfully sent ACK");
}

Expand Down Expand Up @@ -126,23 +114,17 @@ impl Daemon {
log::info!("No transcription");
}

resps.send(
Response::Transcription {
content: t.map(|t| t.content().to_string()),
mode: new_state.mode(),
}
.as_json(),
)?;
responses.send(Response::Transcription {
content: t.map(|t| t.content().to_string()),
mode: new_state.mode(),
})?;
}
Err(e) => {
log::error!("{e}");
resps.send(
Response::Transcription {
content: None,
mode: new_state.mode(),
}
.as_json(),
)?;
responses.send(Response::Transcription {
content: None,
mode: new_state.mode(),
})?;
exit_code = 1;
}
}
Expand All @@ -153,23 +135,20 @@ impl Daemon {
}
c @ Command::Mode(_) => {
assert!(!new_state.running());
resps.send(c.as_response().unwrap_or(Response::Ack).as_json())?;
responses.send(c.as_response().unwrap_or_else(Response::ack))?;
}
Command::Respond(response) => {
log::info!("Responding with: {:?}", response);
resps.send(response.as_json())?;
responses.send(response.clone())?;
}
}
}

resps.send(
serde_json::to_value(Response::Exit(exit_code)).expect("Failed to serialize response"),
)?;
responses.send(Response::Exit(exit_code))?;
// Done responding
drop(resps);
drop(responses);

tx_worker.join().unwrap()?;
listener.join().unwrap()?;

// remove socket
std::fs::remove_file(&self.config.socket_path)?;
Expand Down
20 changes: 12 additions & 8 deletions src/app/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::state::{Chat, Mode};
#[serde(tag = "type", content = "data")]
pub enum Response {
#[serde(rename = "ack")]
Ack,
Ack(u128),

#[serde(rename = "nil")]
Nil,
Expand All @@ -22,6 +22,16 @@ pub enum Response {
Transcription { content: Option<String>, mode: Mode },
}

impl Response {
    /// Creates a [`Response::Ack`] stamped with the current wall-clock time,
    /// expressed as milliseconds elapsed since the Unix epoch.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the Unix epoch.
    pub fn ack() -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};

        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
        Self::Ack(since_epoch.as_millis())
    }
}

impl From<sttx::Timing> for Response {
fn from(t: sttx::Timing) -> Self {
Self::Transcription {
Expand All @@ -34,7 +44,7 @@ impl From<sttx::Timing> for Response {
impl std::fmt::Display for Response {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Ack => write!(f, "ACK"),
Self::Ack(n) => write!(f, "ACK {}", n),
Self::Nil => write!(f, "NIL"),
Self::Error(s) => write!(f, "ERROR {}", s),
Self::Exit(code) => write!(f, "EXIT {}", code),
Expand All @@ -54,9 +64,3 @@ impl std::fmt::Display for Response {
}
}
}

impl Response {
/// Serializes this response into a [`serde_json::Value`].
///
/// # Panics
///
/// Panics if `serde_json::to_value` returns an error for this value.
pub fn as_json(&self) -> serde_json::Value {
serde_json::to_value(self).unwrap()
}
}
Loading

0 comments on commit 49b9ff2

Please sign in to comment.