Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add wllama settings #37

Merged
merged 4 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/components/chat/Conversation.jsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { useEffect, useRef, useState } from "react";
import ConversationBubble from "./ConversationBubble";
import { CheckCircle, FileImageFill, FileTextFill, Paperclip, Send, StopCircleFill, XCircle } from 'react-bootstrap-icons';
import { CheckCircle, FileImageFill, FileTextFill, Paperclip, PencilFill, Send, StopCircleFill, XCircle } from 'react-bootstrap-icons';
import useIDB from "../../utils/idb";
import { isModelLoaded, loadModel } from '../../utils/workers/worker'
import { getCompletionFunctions } from "../../utils/workers";
Expand Down Expand Up @@ -85,7 +85,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateCl
});
setPendingMessage('')
}
messages = [user_msg];

messages =
chat_functions.current.continue_chat ?
[...conversation, user_msg] : [user_msg];
} else {
let user_message = user_msg;
if(upload_file) {
Expand Down Expand Up @@ -167,7 +170,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateCl
<CheckCircle className="btn clickable" onClick={submitUpdateTitle} />
<XCircle className="btn clickable" onClick={()=>{setEditedTitle(title); toggleEditTitle(false)}} />
</form>:
<div className="text" onClick={()=>toggleEditTitle(true)}>{ title }</div>
<div className="display-title clickable" onClick={()=>toggleEditTitle(true)}>
<div className="text">{ title }</div>
<PencilFill className="edit-icon" />
</div>
}
</div>
<div className="bubbles" ref={bubblesRef}>
Expand Down
2 changes: 1 addition & 1 deletion src/components/settings/OpenaiSettings.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export default function OpenaiSettings({ trigger, enabled, updateEnabled }) {
}, [])

return (
<SettingSection title={'General Model Settings'}>
<SettingSection title={'OpenAI Settings'}>
<TrueFalseComponent
title={"Use OpenAI For Completion"}
value={enabled} cb={updateEnabled}
Expand Down
76 changes: 76 additions & 0 deletions src/components/settings/WllamaSettings.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { useEffect, useState } from "react";
import SettingSection from "./SettingSection";
import TrueFalseComponent from "./components/TrueFalseComponent";
import ScrollBarComponent from "./components/ScrollBarComponent";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import { loadModel, loadModelSamplingSettings } from "../../utils/workers/worker";

export default function WllamaSettings({ trigger, enabled }) {

const [ threads, setThreads ] = useState(1);
const [ batch_size, setBatchSize ] = useState(256);
const [ context_length, setContextLength ] = useState(4096);
const [ continue_conv, setContinueConversation ] = useState(false);

function saveSettings() {
updatePlatformSettings({
wllama_threads: threads,
wllama_batch_size: batch_size,
wllama_context_length: context_length,
wllama_continue_conv: continue_conv
})

if(enabled) {
loadModelSamplingSettings();
loadModel('completion');
}
}

useEffect(()=>{
trigger && saveSettings();
// eslint-disable-next-line
}, [trigger])

useEffect(()=>{
const {
wllama_threads,
wllama_batch_size,
wllama_context_length,
wllama_continue_conv
} = getPlatformSettings();

setThreads(wllama_threads)
setBatchSize(wllama_batch_size)
setContextLength(wllama_context_length)
setContinueConversation(wllama_continue_conv)

}, [])

return (
<SettingSection title={'Local Model Settings'}>
<ScrollBarComponent
title={"Set Threads to use"}
value={threads} cb={setThreads}
description={'Please set how many threads you want to use, max is your CPU cores.'}
min={1} max={navigator.hardwareConcurrency}
/>
<TrueFalseComponent
title={"Enable Continue Conversation"}
description={"Open to continue conversation instead treate any question as a new conversation. This can cause the response speed becomes extreamly slow."}
value={continue_conv} cb={setContinueConversation}
/>
<ScrollBarComponent
title={"Set Batch Size"}
value={batch_size} cb={setBatchSize}
description={'Adjust batch size to balance the performance and cost.'}
min={1} max={512}
/>
<ScrollBarComponent
title={"Set Context Length"}
value={context_length} cb={setContextLength}
description={'Adjust the max tokens of a conversation, over this size would reset the conversation.'}
min={1024} max={4096}
/>
</SettingSection>
)
}
3 changes: 2 additions & 1 deletion src/components/settings/components/ScrollBarComponent.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ export default function ScrollBarComponent({ cb, value, disabled, title, descrip
setScrollValue(value);
} else {
if(!isNaN(+value)) {
if(value > max) value = max;
setScrollValue(times_10 ? value * 10 : value);
}
setTextValue(value);
}
}

useEffect(()=>{
!isNaN(+textValue) && checkValue() && cb(+textValue);
textValue !== value && !isNaN(+textValue) && checkValue() && cb(+textValue);
// eslint-disable-next-line
}, [textValue])

Expand Down
5 changes: 5 additions & 0 deletions src/components/settings/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import AwsSettings from "./AwsSettings";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import ModelSettings from "./ModelSettings";
import OpenaiSettings from "./OpenaiSettings";
import WllamaSettings from "./WllamaSettings";

export default function Settings() {

Expand All @@ -24,6 +25,10 @@ export default function Settings() {
<ModelSettings
trigger={saveSettingTrigger}
/>
<WllamaSettings
trigger={saveSettingTrigger}
enabled={!enabled_platform}
/>
<AwsSettings
trigger={saveSettingTrigger}
enabled={enabled_platform === 'AWS'}
Expand Down
12 changes: 12 additions & 0 deletions src/styles/chat.css
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,22 @@
font-size: 15px;
font-weight: bold;
color: rgb(50, 50, 50);
align-content: center;

--elem-height: calc(var(--title-bar-height) - 14px);
}

/* Clickable title wrapper: keeps the title text and the pencil edit
   icon on one line, vertically centered within the title bar. */
.chat > .conversation-main > .title-bar .display-title {
width: fit-content;
height: 100%;
display: flex;
align-items: center;
}

/* Small gap between the title text and the pencil icon. */
.chat > .conversation-main > .title-bar .display-title > .edit-icon {
margin-left: 10px;
}

/* Inline edit form (shown while renaming the conversation). */
.chat > .conversation-main > .title-bar > form {
display: flex;
}
Expand Down
7 changes: 6 additions & 1 deletion src/utils/general_settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ const DEFAULT_PLATFORM_SETTINGS = {
// aws
aws_model_id: '', aws_region: '',
// openai
openai_model: ''
openai_model: '',
// wllama
wllama_threads: 4,
wllama_batch_size: 128,
wllama_context_length: 4096,
wllama_continue_conv: false
}

const MODEL_SETTINGS_KEY = 'general-model-settings'
Expand Down
5 changes: 4 additions & 1 deletion src/utils/workers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ export function getCompletionFunctions() {
case 'OpenAI':
return { completions: OpenaiCompletions, abort: OpenaiAbort, platform: "OpenAI"}
default:
return { completions: WllamaCompletions, abort: WllamaAbort, platform: "Wllama" }
return {
completions: WllamaCompletions, abort: WllamaAbort,
platform: "Wllama", continue_chat: platform_settings.wllama_continue_conv
}
}

}
38 changes: 33 additions & 5 deletions src/utils/workers/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import wllamaSingle from '@wllama/wllama/src/single-thread/wllama.wasm?url';
import wllamaMultiJS from '@wllama/wllama/src/multi-thread/wllama.js?url';
import wllamaMulti from '@wllama/wllama/src/multi-thread/wllama.wasm?url';
import wllamaMultiWorker from '@wllama/wllama/src/multi-thread/wllama.worker.mjs?url';
import { getModelSettings, getPlatformSettings } from "../general_settings";

const CONFIG_PATHS = {
'single-thread/wllama.js': wllamaSingleJS,
Expand All @@ -15,6 +16,32 @@ const CONFIG_PATHS = {
'multi-thread/wllama.worker.mjs': wllamaMultiWorker,
};

let model_sampling_settings = {}

/**
 * Rebuild the module-level sampling/runtime settings cache from the stored
 * platform and model settings. Runs once at module load and again whenever
 * the user saves new Wllama settings.
 */
export function loadModelSamplingSettings() {
    const platform = getPlatformSettings();
    const model = getModelSettings();

    model_sampling_settings = {
        // llama.cpp runtime options, consumed when the model is loaded
        n_threads: platform.wllama_threads,
        n_batch: platform.wllama_batch_size,
        n_ctx: platform.wllama_context_length,
        // sampling options, consumed per completion request
        nPredict: model.max_tokens,
        temp: model.temperature,
        top_p: model.top_p
    };
}
loadModelSamplingSettings();

const DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

const engines = {
Expand Down Expand Up @@ -74,10 +101,9 @@ export async function loadModel(type = 'completion', cb = null) {
await downloadModel(type, cb);
}
cb && cb('loading')
const {n_threads, n_batch, n_ctx} = model_sampling_settings;
await instance.loadModelFromUrl(model_src, {
n_threads: 6,
n_ctx: 4096,
n_batch: 128,
n_threads, n_ctx, n_batch,
});
} catch(error) {
console.error(error)
Expand All @@ -104,11 +130,13 @@ export async function formatPrompt(messages) {
export async function chatCompletions(messages, cb = null) {
stop_signal = false;
try {
const { nPredict, temp, top_p } = model_sampling_settings;

const prompt = await formatPrompt(messages)
const result = await engines['completion'].instance.createCompletion(prompt, {
nPredict: 256,
nPredict,
sampling: {
temp: 0.7
temp, top_p
},
onNewToken: (token, piece, currentText, optionals) => {
cb && cb(currentText, false);
Expand Down