Commit
add Wllama related settings
Signed-off-by: cbh778899 <[email protected]>
cbh778899 committed Sep 24, 2024
1 parent 1b89334 commit ae2c7df
Showing 6 changed files with 128 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/components/chat/Conversation.jsx
@@ -85,7 +85,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateClient
        });
        setPendingMessage('')
    }
-   messages = [user_msg];
+
+   messages =
+       chat_functions.current.continue_chat ?
+       [...conversation, user_msg] : [user_msg];
} else {
    let user_message = user_msg;
    if(upload_file) {
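With this change, Conversation.jsx replays the full history only when the active platform reports continue_chat; otherwise each request still carries just the newest user message. A minimal sketch of the same selection logic (buildMessages is a hypothetical name; `conversation` is the component's existing message array):

    // Hypothetical helper mirroring the diff's branch, not part of the commit.
    function buildMessages(conversation, user_msg, continue_chat) {
        // Replay prior messages only when the platform opts in.
        return continue_chat ? [...conversation, user_msg] : [user_msg];
    }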
76 changes: 76 additions & 0 deletions src/components/settings/WllamaSettings.jsx
@@ -0,0 +1,76 @@
import { useEffect, useState } from "react";
import SettingSection from "./SettingSection";
import TrueFalseComponent from "./components/TrueFalseComponent";
import ScrollBarComponent from "./components/ScrollBarComponent";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import { loadModel, loadModelSamplingSettings } from "../../utils/workers/worker";

export default function WllamaSettings({ trigger, enabled }) {

    const [ threads, setThreads ] = useState(1);
    const [ batch_size, setBatchSize ] = useState(256);
    const [ context_length, setContextLength ] = useState(4096);
    const [ continue_conv, setContinueConversation ] = useState(false);

    function saveSettings() {
        updatePlatformSettings({
            wllama_threads: threads,
            wllama_batch_size: batch_size,
            wllama_context_length: context_length,
            wllama_continue_conv: continue_conv
        })

        if(enabled) {
            loadModelSamplingSettings();
            loadModel('completion');
        }
    }

    useEffect(()=>{
        trigger && saveSettings();
        // eslint-disable-next-line
    }, [trigger])

    useEffect(()=>{
        const {
            wllama_threads,
            wllama_batch_size,
            wllama_context_length,
            wllama_continue_conv
        } = getPlatformSettings();

        setThreads(wllama_threads)
        setBatchSize(wllama_batch_size)
        setContextLength(wllama_context_length)
        setContinueConversation(wllama_continue_conv)

    }, [])

    return (
        <SettingSection title={'Local Model Settings'}>
            <ScrollBarComponent
                title={"Set Threads to use"}
                value={threads} cb={setThreads}
                description={'Set how many threads to use; the maximum is your CPU core count.'}
                min={1} max={navigator.hardwareConcurrency}
            />
            <TrueFalseComponent
                title={"Enable Continue Conversation"}
                description={"Enable to continue the conversation instead of treating every question as a new one. This can make responses extremely slow."}
                value={continue_conv} cb={setContinueConversation}
            />
            <ScrollBarComponent
                title={"Set Batch Size"}
                value={batch_size} cb={setBatchSize}
                description={'Adjust the batch size to balance performance and cost.'}
                min={1} max={512}
            />
            <ScrollBarComponent
                title={"Set Context Length"}
                value={context_length} cb={setContextLength}
                description={'Adjust the maximum number of tokens in a conversation; exceeding this size resets the conversation.'}
                min={1024} max={4096}
            />
        </SettingSection>
    )
}
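The component reads its initial values from getPlatformSettings() on mount and persists them with updatePlatformSettings() when the save trigger fires, reloading the model only if Wllama is the active platform. A plausible localStorage-backed shape for that helper pair (a sketch only; the storage key and implementation are assumptions, and DEFAULT_PLATFORM_SETTINGS comes from general_settings.js below):

    // Hypothetical sketch of the settings helpers this component relies on.
    const PLATFORM_SETTINGS_KEY = 'platform-settings'; // assumed key name

    export function getPlatformSettings() {
        const saved = JSON.parse(localStorage.getItem(PLATFORM_SETTINGS_KEY) || '{}');
        // Keys the user never saved fall back to the declared defaults.
        return { ...DEFAULT_PLATFORM_SETTINGS, ...saved };
    }

    export function updatePlatformSettings(settings) {
        const merged = { ...getPlatformSettings(), ...settings };
        localStorage.setItem(PLATFORM_SETTINGS_KEY, JSON.stringify(merged));
    }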
5 changes: 5 additions & 0 deletions src/components/settings/index.jsx
@@ -3,6 +3,7 @@ import AwsSettings from "./AwsSettings";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import ModelSettings from "./ModelSettings";
import OpenaiSettings from "./OpenaiSettings";
+ import WllamaSettings from "./WllamaSettings";

export default function Settings() {

@@ -24,6 +25,10 @@ export default function Settings() {
            <ModelSettings
                trigger={saveSettingTrigger}
            />
+           <WllamaSettings
+               trigger={saveSettingTrigger}
+               enabled={!enabled_platform}
+           />
            <AwsSettings
                trigger={saveSettingTrigger}
                enabled={enabled_platform === 'AWS'}
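Note that enabled={!enabled_platform} makes the Wllama section reload the local model on save only when no cloud platform is selected, matching the default branch of getCompletionFunctions() in workers/index.js below.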
7 changes: 6 additions & 1 deletion src/utils/general_settings.js
@@ -4,7 +4,12 @@ const DEFAULT_PLATFORM_SETTINGS = {
    // aws
    aws_model_id: '', aws_region: '',
    // openai
-   openai_model: ''
+   openai_model: '',
+   // wllama
+   wllama_threads: 4,
+   wllama_batch_size: 128,
+   wllama_context_length: 4096,
+   wllama_continue_conv: false
}

const MODEL_SETTINGS_KEY = 'general-model-settings'
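These defaults (4 threads, batch size 128, context length 4096) are what a fresh install actually uses: the useState initial values in WllamaSettings.jsx (1 thread, batch size 256) are overwritten on mount by the effect that reads getPlatformSettings().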
5 changes: 4 additions & 1 deletion src/utils/workers/index.js
@@ -23,7 +23,10 @@ export function getCompletionFunctions() {
        case 'OpenAI':
            return { completions: OpenaiCompletions, abort: OpenaiAbort, platform: "OpenAI"}
        default:
-           return { completions: WllamaCompletions, abort: WllamaAbort, platform: "Wllama" }
+           return {
+               completions: WllamaCompletions, abort: WllamaAbort,
+               platform: "Wllama", continue_chat: platform_settings.wllama_continue_conv
+           }
    }

}
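Only the default (Wllama) branch sets continue_chat, so for AWS and OpenAI it stays undefined and the Conversation.jsx check above falls back to single-message requests. A hypothetical call site (import path assumed):

    import { getCompletionFunctions } from './utils/workers/index';

    // platform tells the UI which backend answered; continue_chat is
    // truthy only when the Wllama continue-conversation setting is on.
    const { completions, abort, platform, continue_chat } = getCompletionFunctions();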
38 changes: 33 additions & 5 deletions src/utils/workers/worker.js
@@ -6,6 +6,7 @@ import wllamaSingle from '@wllama/wllama/src/single-thread/wllama.wasm?url';
import wllamaMultiJS from '@wllama/wllama/src/multi-thread/wllama.js?url';
import wllamaMulti from '@wllama/wllama/src/multi-thread/wllama.wasm?url';
import wllamaMultiWorker from '@wllama/wllama/src/multi-thread/wllama.worker.mjs?url';
+ import { getModelSettings, getPlatformSettings } from "../general_settings";

const CONFIG_PATHS = {
'single-thread/wllama.js': wllamaSingleJS,
@@ -15,6 +16,32 @@ const CONFIG_PATHS = {
'multi-thread/wllama.worker.mjs': wllamaMultiWorker,
};

+ let model_sampling_settings = {}
+
+ export function loadModelSamplingSettings() {
+     const {
+         wllama_threads,
+         wllama_batch_size,
+         wllama_context_length
+     } = getPlatformSettings();
+
+     const {
+         max_tokens,
+         top_p,
+         temperature
+     } = getModelSettings();
+
+     model_sampling_settings = {
+         n_threads: wllama_threads,
+         n_batch: wllama_batch_size,
+         n_ctx: wllama_context_length,
+         nPredict: max_tokens,
+         temp: temperature,
+         top_p
+     }
+ }
+ loadModelSamplingSettings();
+
const DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

const engines = {
@@ -74,10 +101,9 @@ export async function loadModel(type = 'completion', cb = null) {
        await downloadModel(type, cb);
    }
    cb && cb('loading')
+   const {n_threads, n_batch, n_ctx} = model_sampling_settings;
    await instance.loadModelFromUrl(model_src, {
-       n_threads: 6,
-       n_ctx: 4096,
-       n_batch: 128,
+       n_threads, n_ctx, n_batch,
    });
} catch(error) {
console.error(error)
@@ -104,11 +130,13 @@ export async function formatPrompt(messages) {
export async function chatCompletions(messages, cb = null) {
    stop_signal = false;
    try {
+       const { nPredict, temp, top_p } = model_sampling_settings;
+
        const prompt = await formatPrompt(messages)
        const result = await engines['completion'].instance.createCompletion(prompt, {
-           nPredict: 256,
+           nPredict,
            sampling: {
-               temp: 0.7
+               temp, top_p
            },
            onNewToken: (token, piece, currentText, optionals) => {
                cb && cb(currentText, false);
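Taken together, the worker now draws every tunable from saved settings at two points: n_threads, n_ctx, and n_batch at model-load time, and nPredict, temp, and top_p at sampling time. A hypothetical end-to-end use of the exported functions (runs inside an async context; the import path and the final-callback shape are assumptions inferred from the handlers above):

    import { loadModelSamplingSettings, loadModel, chatCompletions } from './utils/workers/worker';

    loadModelSamplingSettings();      // re-read saved platform/model settings
    await loadModel('completion');    // applies n_threads, n_ctx, n_batch
    await chatCompletions(
        [{ role: 'user', content: 'Hello!' }],
        (text, finished) => { if(finished) console.log(text); }
    );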
