Skip to content

Commit

Permalink
feat: load model on init
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jun 21, 2024
1 parent 3ff1263 commit 166b5f5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
25 changes: 19 additions & 6 deletions src-tauri/src/cmd.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
use crate::{downloader::Downloader, unzip};
use crate::{downloader::Downloader, unzip, TranslatorHandle};
use ct2rs::Translator;
use eyre::{eyre, ContextCompat, Result};
use std::{fs, path::PathBuf, time};
use tauri::{AppHandle, Manager};
use tauri::{AppHandle, Manager, State};
use tauri_plugin_shell::ShellExt;

#[tauri::command]
pub async fn translate(language: String, text: String, model_path: String) -> Result<Vec<(String, Option<f32>)>> {
log::debug!("translate with {} {} {}", language, model_path, text);
let cfg = ct2rs::config::Config::default();
let t = Translator::new(&model_path, &cfg).map_err(|e| eyre!("{:?}", e))?;
pub async fn translate(
language: String,
text: String,
translator_handle: State<'_, TranslatorHandle>,
) -> Result<Vec<(String, Option<f32>)>> {
log::debug!("translate with {} {}", language, text);
let sources: Vec<String> = text.lines().map(String::from).collect();
let target_prefixes = vec![vec![language]; sources.len()];
let now = time::Instant::now();
let t = translator_handle.lock().await;
let t = t.as_ref().unwrap();
let res = t
.translate_batch_with_target_prefix(&sources, &target_prefixes, &Default::default(), None)
.map_err(|e| eyre!("{:?}", e))?;
Expand Down Expand Up @@ -77,3 +81,12 @@ pub async fn open_models_folder(app: AppHandle) -> Result<()> {
app.shell().open(local_data.to_str().context("tostr")?, None)?;
Ok(())
}

#[tauri::command]
pub async fn load_model(model_path: String, translator_handle: State<'_, TranslatorHandle>) -> Result<()> {
let cfg = ct2rs::config::Config::default();
let t = Translator::new(model_path, &cfg).map_err(|e| eyre!("{:?}", e))?;
let mut translator_handle = translator_handle.lock().await;
*translator_handle = Some(t);
Ok(())
}
9 changes: 8 additions & 1 deletion src-tauri/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]

use ct2rs::{auto::Tokenizer, Translator};
use tokio::sync::Mutex;

mod cmd;
mod downloader;
mod unzip;

pub type TranslatorHandle = Mutex<Option<Translator<Tokenizer>>>;

fn main() {
env_logger::init();
tauri::Builder::default()
.plugin(tauri_plugin_shell::init())
.manage(Mutex::new(None) as TranslatorHandle)
.invoke_handler(tauri::generate_handler![
cmd::translate,
cmd::download_model,
cmd::get_model_path,
cmd::open_models_folder
cmd::open_models_folder,
cmd::load_model
])
.run(tauri::generate_context!())
.expect("error while running tauri application");
Expand Down
9 changes: 5 additions & 4 deletions src/pages/Home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ export default function Home() {
const [srcText, setSrcText] = useState("");
const [dstText, setDstText] = useState("");
const [isLoading, setIsLoading] = useState(false);
const [language, setLanguage] = useLocalStorage("prefs_language", languages["english"]);
const [language, setLanguage] = useLocalStorage(
"prefs_language",
languages["english"]
);
const navigate = useNavigate();
const [modelPath, setModelPath] = useState<null | string>(null);

async function translate() {
setIsLoading(true);
const resp = await invoke<TranslateResponse>("translate", {
language,
modelPath,
text: srcText,
});
console.log("resp => ", resp);
Expand All @@ -40,7 +41,7 @@ export default function Home() {
if (!modelPathResult) {
navigate("/setup");
}
setModelPath(modelPathResult);
await invoke("load_model", { modelPath: modelPathResult });
}

useEffect(() => {
Expand Down

0 comments on commit 166b5f5

Please sign in to comment.