Skip to content

Commit

Permalink
fix wllama load bugs (#41)
Browse files Browse the repository at this point in the history
* remove output

Signed-off-by: cbh778899 <[email protected]>

* rename worker.js to wllama-worker.js

Signed-off-by: cbh778899 <[email protected]>

* remove force load platform to fix bug

Signed-off-by: cbh778899 <[email protected]>

* add download protector call to wllama settings

Signed-off-by: cbh778899 <[email protected]>

* change import path

Signed-off-by: cbh778899 <[email protected]>

---------

Signed-off-by: cbh778899 <[email protected]>
  • Loading branch information
cbh778899 authored Oct 14, 2024
1 parent cd48db7 commit 5b3b1fa
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 21 deletions.
7 changes: 1 addition & 6 deletions src/components/chat/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export default function Chat() {

const idb = useIDB();
// const settings = useRef(getCompletionFunctions());
const settings = useRef(getCompletionFunctions('Llama'));
const settings = useRef(getCompletionFunctions());

async function sendMessage(message, files) {
// save user messages
Expand Down Expand Up @@ -143,11 +143,6 @@ export default function Chat() {
setHistory={setTickets} history={tickets}
deleteHistory={requestDelete} platform={settings.current.platform}
/>
{/* <Conversation
uid={chat.uid}
title={chat.title} updateTitle={updateTitle}
client={chat.client} updateClient={updateChatClient}
/> */}
<ChatPage
updateTitle={updateTitle}
chat={chat} chat_history={chat_history}
Expand Down
1 change: 0 additions & 1 deletion src/components/settings/LlamaSettings.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ export default function LlamaSettings({ trigger, enabled, updateEnabled, openDow
if(enabled) {
// check if model with this url already downloaded
let stored_model = await idb.getOne("downloaded-models", {where: [{ platform: 'Llama', url }]})
console.log(stored_model)
// if no model record, means not downloaded
if(!stored_model) {
await openDownloadProtector(
Expand Down
29 changes: 25 additions & 4 deletions src/components/settings/WllamaSettings.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ import SettingSection from "./SettingSection";
import TrueFalseComponent from "./components/TrueFalseComponent";
import ScrollBarComponent from "./components/ScrollBarComponent";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import { loadModel, loadModelSamplingSettings } from "../../utils/workers/worker";
import { downloadModel, isModelDownloaded, loadModel, loadModelSamplingSettings } from "../../utils/workers/wllama-worker";
import useIDB from "../../utils/idb";

export default function WllamaSettings({ trigger, enabled, updateEnabled }) {
export default function WllamaSettings({ trigger, enabled, updateEnabled, openDownloadProtector }) {

const [ threads, setThreads ] = useState(1);
const [ batch_size, setBatchSize ] = useState(256);
const [ context_length, setContextLength ] = useState(4096);
const [ continue_conv, setContinueConversation ] = useState(false);

function saveSettings() {
const idb = useIDB();

async function saveSettings() {
updatePlatformSettings({
wllama_threads: threads,
wllama_batch_size: batch_size,
Expand All @@ -22,7 +25,25 @@ export default function WllamaSettings({ trigger, enabled, updateEnabled }) {

if(enabled) {
loadModelSamplingSettings();
loadModel('completion');
const model = 'completion'
if(!(await isModelDownloaded(model))) {
await openDownloadProtector(
"Please wait while we downloading the model...",
`Downloading model smollm-360m-instruct-add-basics-q8_0.gguf`,
async callback => {
const download_result = await downloadModel(model, callback);
if(!download_result) return;
const { model_name, url, finish_time, size } = download_result;
await idb.insert("downloaded-models", {
'model-name': model_name,
url, size, createdAt: finish_time,
platform: 'Wllama'
})
}
)
}

await loadModel('completion');
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/utils/start_loader.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { getPlatformSettings } from "./general_settings";
import { instance } from "./idb";
import { loadModel } from "./workers/worker";
import { loadModel } from "./workers/wllama-worker";

export default async function loader() {
localStorage.setItem('not-first-time', '1');
Expand Down
10 changes: 3 additions & 7 deletions src/utils/workers/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { getModelSettings, getPlatformSettings } from "../general_settings";
import { chatCompletions as WllamaCompletions, abortCompletion as WllamaAbort } from "./worker";
import { chatCompletions as WllamaCompletions, abortCompletion as WllamaAbort, setClient as WllamaSetClient } from "./wllama-worker";
import { chatCompletions as AwsCompletions, abortCompletion as AwsAbort, setClient as AwsSetClient, formator as AwsFormator } from "./aws-worker"
import { chatCompletions as OpenaiCompletions, abortCompletion as OpenaiAbort, setClient as OpenAISetClient } from "./openai-worker";

Expand Down Expand Up @@ -45,13 +45,9 @@ export function getCompletionFunctions(platform = null) {
return {
completions: WllamaCompletions,
abort: WllamaAbort,
platform: "Wllama"
platform: "Wllama",
initClient: WllamaSetClient
}
// default:
// return {
// completions: WllamaCompletions, abort: WllamaAbort,
// platform: "Wllama"
// }
}

}
16 changes: 14 additions & 2 deletions src/utils/workers/worker.js → src/utils/workers/wllama-worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ loadModelSamplingSettings();

const DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";


const engines = {
completion: {
model_src: "https://huggingface.co/HuggingFaceTB/smollm-360M-instruct-v0.2-Q8_0-GGUF/resolve/main/smollm-360m-instruct-add-basics-q8_0.gguf",
Expand All @@ -52,6 +53,11 @@ const engines = {
}
}

/**
 * No-op client initializer for the Wllama backend.
 *
 * The shared completion-function interface (see getCompletionFunctions)
 * expects every platform to expose an `initClient`/`setClient` entry point.
 * Wllama manages its model engine internally, so there is no client object
 * to (re)initialize here; a placeholder string is returned purely to satisfy
 * callers that check for a truthy client value.
 *
 * @returns {string} the constant placeholder 'fake-client'
 */
export function setClient() {
  const placeholderClient = 'fake-client';
  return placeholderClient;
}

let stop_signal = false;

export async function ReloadEngine(type = 'completion') {
Expand Down Expand Up @@ -79,10 +85,16 @@ export async function downloadModel(type = 'completion', cb = null) {
allowOffline: true,
embeddings: type === 'embedding',
progressCallback: ({loaded, total})=>{
cb && cb((loaded / total) * 100);
cb && cb((loaded / total) * 100, false);
}
})
cb && cb(100);
cb && cb(100, true);
return {
model_name: model_src.split('/').pop(),
url: model_src,
size: 0,
finish_time: Date.now()
}
}

export async function deleteModel(type = 'completion') {
Expand Down

0 comments on commit 5b3b1fa

Please sign in to comment.