diff --git a/package.json b/package.json index 7e9d0ee..c99efc6 100644 --- a/package.json +++ b/package.json @@ -4,7 +4,7 @@ "name": "Bohan Cheng", "email": "cbh778899@outlook.com" }, - "version": "0.1.11", + "version": "0.1.12", "main": "electron.js", "scripts": { "dev": "npm run start & npm run electron", diff --git a/preloader/node-llama-cpp-preloader.js b/preloader/node-llama-cpp-preloader.js index e8e3466..7c4e7e3 100644 --- a/preloader/node-llama-cpp-preloader.js +++ b/preloader/node-llama-cpp-preloader.js @@ -1,4 +1,4 @@ -const { createWriteStream } = require("fs"); +const { createWriteStream, existsSync, statSync } = require("fs"); const path = require("path"); let llama, getLlama, LlamaChatSession, current_model; @@ -150,6 +150,17 @@ function downloadModel(url, cb=null) { return new Promise(resolve=>{ (async function() { const model_name = url.split('/').pop(); + const model_save_path = path.join(model_path, model_name) + + if(existsSync(model_save_path)) { + cb && cb(100, true); + const size = statSync(model_save_path).size + resolve({ + model_name, url, finish_time: Date.now(), + size + }) + return; + } const download_req = await fetch(url); if(!download_req.ok) { @@ -166,7 +177,7 @@ function downloadModel(url, cb=null) { } let downloaded = 0; - const write_stream = createWriteStream(path.join(model_path, model_name)) + const write_stream = createWriteStream(model_save_path) const middle_write_stream = new WritableStream({ write(chunk) { write_stream.write(chunk); diff --git a/src/components/chat/index.jsx b/src/components/chat/index.jsx index b9a3424..956fb3a 100644 --- a/src/components/chat/index.jsx +++ b/src/components/chat/index.jsx @@ -137,6 +137,7 @@ export default function Chat() { }, [chat]) return ( + settings.current ?
-
+ : + <> ) } \ No newline at end of file diff --git a/src/components/settings/AwsSettings.jsx b/src/components/settings/AwsSettings.jsx index 3d9f55f..9222e48 100644 --- a/src/components/settings/AwsSettings.jsx +++ b/src/components/settings/AwsSettings.jsx @@ -6,7 +6,7 @@ import PasswordComponent from "./components/PasswordComponent"; import { getJSONCredentials, storeCredentials } from "../../utils/workers/aws-worker"; import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings"; -export default function AwsSettings({ trigger, enabled, updateEnabled }) { +export default function AwsSettings({ trigger, enabled, updateEnabled, updateState }) { const [ aws_region, setAwsRegion ] = useState(''); const [ aws_key_id, setAwsKeyID ] = useState(''); @@ -14,20 +14,21 @@ export default function AwsSettings({ trigger, enabled, updateEnabled }) { const [ aws_session_token, setAwsSessionToken ] = useState(''); const [ aws_model_id, setAwsModelID ] = useState(''); - function saveSettings() { + async function saveSettings() { const credentials = { key_id: aws_key_id, secret_key: aws_secret_key } if(aws_session_token) { credentials.session_token = aws_session_token } - storeCredentials( - credentials, aws_key_id && aws_secret_key, - enabled - ) updatePlatformSettings({ aws_model_id, aws_region }) + await storeCredentials( + credentials, aws_key_id && aws_secret_key, + enabled + ) + updateState(); } // get aws credentials from db diff --git a/src/components/settings/LlamaSettings.jsx b/src/components/settings/LlamaSettings.jsx index d6d163c..ddcc04b 100644 --- a/src/components/settings/LlamaSettings.jsx +++ b/src/components/settings/LlamaSettings.jsx @@ -6,7 +6,7 @@ import useIDB from "../../utils/idb"; import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings"; import { DEFAULT_LLAMA_CPP_MODEL_URL } from "../../utils/types"; -export default function LlamaSettings({ trigger, enabled, updateEnabled, openDownloadProtector }) { +export default function LlamaSettings({ trigger, enabled, updateEnabled, openDownloadProtector, updateState }) { const [model_download_link, setModelDownloadLink] = useState(''); const idb = useIDB(); @@ -44,6 +44,7 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow // load model using the model name retrieved await window['node-llama-cpp'].loadModel(stored_model['model-name']) } + updateState(); } useEffect(()=>{ diff --git a/src/components/settings/ModelSettings.jsx b/src/components/settings/ModelSettings.jsx index 8c35c9a..63015d8 100644 --- a/src/components/settings/ModelSettings.jsx +++ b/src/components/settings/ModelSettings.jsx @@ -3,7 +3,7 @@ import ScrollBarComponent from "./components/ScrollBarComponent"; import SettingSection from "./SettingSection"; import { getModelSettings, updateModelSettings } from "../../utils/general_settings"; -export default function ModelSettings({ trigger }) { +export default function ModelSettings({ trigger, updateState }) { const [max_tokens, setMaxTokens] = useState(0); const [top_p, setTopP] = useState(0); @@ -13,6 +13,7 @@ export default function ModelSettings({ trigger }) { updateModelSettings({ max_tokens, top_p, temperature }) + updateState() } useEffect(()=>{ diff --git a/src/components/settings/OpenaiSettings.jsx b/src/components/settings/OpenaiSettings.jsx index 0c6f2d8..f8b8d37 100644 --- a/src/components/settings/OpenaiSettings.jsx +++ b/src/components/settings/OpenaiSettings.jsx @@ -6,16 +6,17 @@ import PasswordComponent from "./components/PasswordComponent"; import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings"; import { getCredentials, storeCredentials } from "../../utils/workers/openai-worker"; -export default function OpenaiSettings({ trigger, enabled, updateEnabled }) { +export default function OpenaiSettings({ trigger, enabled, updateEnabled, updateState }) { const [api_key, setAPIKey] = useState(''); const [model_name, setModelName] = useState(''); - function saveSettings() { + async function saveSettings() { updatePlatformSettings({ openai_model: model_name }) - storeCredentials({api_key}) + await storeCredentials({api_key}) + updateState() } useEffect(()=>{ diff --git a/src/components/settings/WllamaSettings.jsx b/src/components/settings/WllamaSettings.jsx index 83de0b9..3e525ab 100644 --- a/src/components/settings/WllamaSettings.jsx +++ b/src/components/settings/WllamaSettings.jsx @@ -6,7 +6,7 @@ import { getPlatformSettings, updatePlatformSettings } from "../../utils/general import { downloadModel, isModelDownloaded, loadModel, loadModelSamplingSettings } from "../../utils/workers/wllama-worker"; import useIDB from "../../utils/idb"; -export default function WllamaSettings({ trigger, enabled, updateEnabled, openDownloadProtector }) { +export default function WllamaSettings({ trigger, enabled, updateEnabled, openDownloadProtector, updateState }) { const [ threads, setThreads ] = useState(1); const [ batch_size, setBatchSize ] = useState(256); @@ -45,6 +45,7 @@ export default function WllamaSettings({ trigger, enabled, updateEnabled, openDo await loadModel('completion'); } + updateState(); } useEffect(()=>{ diff --git a/src/components/settings/index.jsx b/src/components/settings/index.jsx index d5a8da9..4721377 100644 --- a/src/components/settings/index.jsx +++ b/src/components/settings/index.jsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useRef, useState } from "react"; import AwsSettings from "./AwsSettings"; import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings"; import ModelSettings from "./ModelSettings"; @@ -10,7 +10,10 @@ import LlamaSettings from "./LlamaSettings"; export default function Settings({ complete }) { const [enabled_platform, setEnabledPlatform] = useState(getPlatformSettings().enabled_platform) - const [ saveSettingTrigger, toggleSaveSetting ] = useState(false); + const [saveSettingTrigger, toggleSaveSetting] = useState(false); + + const setting_sections = useRef(5); + const [updatedSettings, setUpdatedSettings] = useState(0); // Download model dialog const [download_title, setDownloadTitle] = useState(''); @@ -26,7 +29,6 @@ export default function Settings({ complete }) { function save() { toggleSaveSetting(true); setTimeout(()=>toggleSaveSetting(false), 1000); - complete && complete(); } async function openDownloadProtector(title, description, downloader) { @@ -42,32 +44,49 @@ export default function Settings({ complete }) { }) } + function updateSaveState() { + setUpdatedSettings((prev)=>prev+1) + } + + useEffect(()=>{ + if(updatedSettings === setting_sections.current) { + complete && complete(); + setUpdatedSettings(0); + } + // eslint-disable-next-line + }, [updatedSettings]) + return (
updatePlatform(set ? "Llama" : null)} openDownloadProtector={openDownloadProtector} + updateState={updateSaveState} /> updatePlatform(set ? "Wllama" : null)} openDownloadProtector={openDownloadProtector} + updateState={updateSaveState} /> updatePlatform(set ? "AWS" : null)} + updateState={updateSaveState} /> updatePlatform(set ? "OpenAI" : null)} + updateState={updateSaveState} />
{ saveSettingTrigger ? "Settings Saved!" : "Save Settings" } diff --git a/src/styles/entry.css b/src/styles/entry.css index e0ca66c..fbd35d3 100644 --- a/src/styles/entry.css +++ b/src/styles/entry.css @@ -47,34 +47,4 @@ margin-top: 10px; text-decoration: underline; color: gray; -} - -.load-page > .first-time > .download-progress { - margin: auto; - margin-top: 20px; - width: 500px; - background-color: lightgray; - border-radius: 20px; - height: 20px; - align-content: center; - font-size: 12px; - color: white; - position: relative; - overflow: hidden; -} - -.load-page > .first-time > .download-progress > .progress-bar { - position: absolute; - width: 100%; - height: 100%; - top: 0; - left: 0; - background-color: limegreen; - transform: translateX(-100%); - z-index: 1; -} - -.load-page > .first-time > .download-progress > .progress-num { - z-index: 2; - position: relative; } \ No newline at end of file diff --git a/src/utils/workers/wllama-worker.js b/src/utils/workers/wllama-worker.js index 427c8e5..1440835 100644 --- a/src/utils/workers/wllama-worker.js +++ b/src/utils/workers/wllama-worker.js @@ -85,7 +85,7 @@ export async function downloadModel(type = 'completion', cb = null) { allowOffline: true, embeddings: type === 'embedding', progressCallback: ({loaded, total})=>{ - cb && cb((loaded / total) * 100, false); + cb && cb(+((loaded / total) * 100).toFixed(2), false); } }) cb && cb(100, true);