Skip to content

Commit

Permalink
Merge pull request #37 from SkywardAI/improve
Browse files Browse the repository at this point in the history
add wllama settings
  • Loading branch information
cbh778899 authored Sep 24, 2024
2 parents 345c309 + ecda9f5 commit b8335c2
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 12 deletions.
12 changes: 9 additions & 3 deletions src/components/chat/Conversation.jsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { useEffect, useRef, useState } from "react";
import ConversationBubble from "./ConversationBubble";
import { CheckCircle, FileImageFill, FileTextFill, Paperclip, Send, StopCircleFill, XCircle } from 'react-bootstrap-icons';
import { CheckCircle, FileImageFill, FileTextFill, Paperclip, PencilFill, Send, StopCircleFill, XCircle } from 'react-bootstrap-icons';
import useIDB from "../../utils/idb";
import { isModelLoaded, loadModel } from '../../utils/workers/worker'
import { getCompletionFunctions } from "../../utils/workers";
Expand Down Expand Up @@ -85,7 +85,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateCl
});
setPendingMessage('')
}
messages = [user_msg];

messages =
chat_functions.current.continue_chat ?
[...conversation, user_msg] : [user_msg];
} else {
let user_message = user_msg;
if(upload_file) {
Expand Down Expand Up @@ -167,7 +170,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateCl
<CheckCircle className="btn clickable" onClick={submitUpdateTitle} />
<XCircle className="btn clickable" onClick={()=>{setEditedTitle(title); toggleEditTitle(false)}} />
</form>:
<div className="text" onClick={()=>toggleEditTitle(true)}>{ title }</div>
<div className="display-title clickable" onClick={()=>toggleEditTitle(true)}>
<div className="text">{ title }</div>
<PencilFill className="edit-icon" />
</div>
}
</div>
<div className="bubbles" ref={bubblesRef}>
Expand Down
2 changes: 1 addition & 1 deletion src/components/settings/OpenaiSettings.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export default function OpenaiSettings({ trigger, enabled, updateEnabled }) {
}, [])

return (
<SettingSection title={'General Model Settings'}>
<SettingSection title={'OpenAI Settings'}>
<TrueFalseComponent
title={"Use OpenAI For Completion"}
value={enabled} cb={updateEnabled}
Expand Down
76 changes: 76 additions & 0 deletions src/components/settings/WllamaSettings.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { useEffect, useState } from "react";
import SettingSection from "./SettingSection";
import TrueFalseComponent from "./components/TrueFalseComponent";
import ScrollBarComponent from "./components/ScrollBarComponent";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import { loadModel, loadModelSamplingSettings } from "../../utils/workers/worker";

export default function WllamaSettings({ trigger, enabled }) {

const [ threads, setThreads ] = useState(1);
const [ batch_size, setBatchSize ] = useState(256);
const [ context_length, setContextLength ] = useState(4096);
const [ continue_conv, setContinueConversation ] = useState(false);

function saveSettings() {
updatePlatformSettings({
wllama_threads: threads,
wllama_batch_size: batch_size,
wllama_context_length: context_length,
wllama_continue_conv: continue_conv
})

if(enabled) {
loadModelSamplingSettings();
loadModel('completion');
}
}

useEffect(()=>{
trigger && saveSettings();
// eslint-disable-next-line
}, [trigger])

useEffect(()=>{
const {
wllama_threads,
wllama_batch_size,
wllama_context_length,
wllama_continue_conv
} = getPlatformSettings();

setThreads(wllama_threads)
setBatchSize(wllama_batch_size)
setContextLength(wllama_context_length)
setContinueConversation(wllama_continue_conv)

}, [])

return (
<SettingSection title={'Local Model Settings'}>
<ScrollBarComponent
title={"Set Threads to use"}
value={threads} cb={setThreads}
description={'Please set how many threads you want to use, max is your CPU cores.'}
min={1} max={navigator.hardwareConcurrency}
/>
<TrueFalseComponent
title={"Enable Continue Conversation"}
description={"Open to continue conversation instead treate any question as a new conversation. This can cause the response speed becomes extreamly slow."}
value={continue_conv} cb={setContinueConversation}
/>
<ScrollBarComponent
title={"Set Batch Size"}
value={batch_size} cb={setBatchSize}
description={'Adjust batch size to balance the performance and cost.'}
min={1} max={512}
/>
<ScrollBarComponent
title={"Set Context Length"}
value={context_length} cb={setContextLength}
description={'Adjust the max tokens of a conversation, over this size would reset the conversation.'}
min={1024} max={4096}
/>
</SettingSection>
)
}
3 changes: 2 additions & 1 deletion src/components/settings/components/ScrollBarComponent.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ export default function ScrollBarComponent({ cb, value, disabled, title, descrip
setScrollValue(value);
} else {
if(!isNaN(+value)) {
if(value > max) value = max;
setScrollValue(times_10 ? value * 10 : value);
}
setTextValue(value);
}
}

useEffect(()=>{
!isNaN(+textValue) && checkValue() && cb(+textValue);
textValue !== value && !isNaN(+textValue) && checkValue() && cb(+textValue);
// eslint-disable-next-line
}, [textValue])

Expand Down
5 changes: 5 additions & 0 deletions src/components/settings/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import AwsSettings from "./AwsSettings";
import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
import ModelSettings from "./ModelSettings";
import OpenaiSettings from "./OpenaiSettings";
import WllamaSettings from "./WllamaSettings";

export default function Settings() {

Expand All @@ -24,6 +25,10 @@ export default function Settings() {
<ModelSettings
trigger={saveSettingTrigger}
/>
<WllamaSettings
trigger={saveSettingTrigger}
enabled={!enabled_platform}
/>
<AwsSettings
trigger={saveSettingTrigger}
enabled={enabled_platform === 'AWS'}
Expand Down
12 changes: 12 additions & 0 deletions src/styles/chat.css
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,22 @@
font-size: 15px;
font-weight: bold;
color: rgb(50, 50, 50);
align-content: center;

--elem-height: calc(var(--title-bar-height) - 14px);
}

.chat > .conversation-main > .title-bar .display-title {
width: fit-content;
height: 100%;
display: flex;
align-items: center;
}

.chat > .conversation-main > .title-bar .display-title > .edit-icon {
margin-left: 10px;
}

.chat > .conversation-main > .title-bar > form {
display: flex;
}
Expand Down
7 changes: 6 additions & 1 deletion src/utils/general_settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ const DEFAULT_PLATFORM_SETTINGS = {
// aws
aws_model_id: '', aws_region: '',
// openai
openai_model: ''
openai_model: '',
// wllama
wllama_threads: 4,
wllama_batch_size: 128,
wllama_context_length: 4096,
wllama_continue_conv: false
}

const MODEL_SETTINGS_KEY = 'general-model-settings'
Expand Down
5 changes: 4 additions & 1 deletion src/utils/workers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ export function getCompletionFunctions() {
case 'OpenAI':
return { completions: OpenaiCompletions, abort: OpenaiAbort, platform: "OpenAI"}
default:
return { completions: WllamaCompletions, abort: WllamaAbort, platform: "Wllama" }
return {
completions: WllamaCompletions, abort: WllamaAbort,
platform: "Wllama", continue_chat: platform_settings.wllama_continue_conv
}
}

}
38 changes: 33 additions & 5 deletions src/utils/workers/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import wllamaSingle from '@wllama/wllama/src/single-thread/wllama.wasm?url';
import wllamaMultiJS from '@wllama/wllama/src/multi-thread/wllama.js?url';
import wllamaMulti from '@wllama/wllama/src/multi-thread/wllama.wasm?url';
import wllamaMultiWorker from '@wllama/wllama/src/multi-thread/wllama.worker.mjs?url';
import { getModelSettings, getPlatformSettings } from "../general_settings";

const CONFIG_PATHS = {
'single-thread/wllama.js': wllamaSingleJS,
Expand All @@ -15,6 +16,32 @@ const CONFIG_PATHS = {
'multi-thread/wllama.worker.mjs': wllamaMultiWorker,
};

let model_sampling_settings = {}

/**
 * Rebuild the module-level sampling/loading configuration from the
 * persisted platform settings (threads, batch size, context length)
 * and model settings (max tokens, temperature, top_p).
 */
export function loadModelSamplingSettings() {
    const platform = getPlatformSettings();
    const model = getModelSettings();

    model_sampling_settings = {
        n_threads: platform.wllama_threads,
        n_batch: platform.wllama_batch_size,
        n_ctx: platform.wllama_context_length,
        nPredict: model.max_tokens,
        temp: model.temperature,
        top_p: model.top_p
    };
}
// apply the stored settings once at module load
loadModelSamplingSettings();

const DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

const engines = {
Expand Down Expand Up @@ -74,10 +101,9 @@ export async function loadModel(type = 'completion', cb = null) {
await downloadModel(type, cb);
}
cb && cb('loading')
const {n_threads, n_batch, n_ctx} = model_sampling_settings;
await instance.loadModelFromUrl(model_src, {
n_threads: 6,
n_ctx: 4096,
n_batch: 128,
n_threads, n_ctx, n_batch,
});
} catch(error) {
console.error(error)
Expand All @@ -104,11 +130,13 @@ export async function formatPrompt(messages) {
export async function chatCompletions(messages, cb = null) {
stop_signal = false;
try {
const { nPredict, temp, top_p } = model_sampling_settings;

const prompt = await formatPrompt(messages)
const result = await engines['completion'].instance.createCompletion(prompt, {
nPredict: 256,
nPredict,
sampling: {
temp: 0.7
temp, top_p
},
onNewToken: (token, piece, currentText, optionals) => {
cb && cb(currentText, false);
Expand Down

0 comments on commit b8335c2

Please sign in to comment.