
Commit ae2c7df

add Wllama related settings

Signed-off-by: cbh778899 <[email protected]>

1 parent 1b89334

6 files changed: +128 -8 lines changed

src/components/chat/Conversation.jsx

Lines changed: 4 additions & 1 deletion

@@ -85,7 +85,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateCl
             });
             setPendingMessage('')
         }
-        messages = [user_msg];
+
+        messages =
+            chat_functions.current.continue_chat ?
+            [...conversation, user_msg] : [user_msg];
     } else {
         let user_message = user_msg;
         if(upload_file) {
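In effect, the local (Wllama) path can now carry history: when continue_chat is on, each request resends every prior turn along with the new one, which preserves context but makes the prompt grow with conversation length. A sketch of the two payload shapes, for illustration only (conversation and user_msg are the component's existing state):

// Illustrative only: the two message arrays the branch above can produce.
// continue_chat on: full history, so the model keeps context across turns.
const messages_with_history = [...conversation, user_msg];
// continue_chat off: a single-turn, stateless request.
const messages_stateless = [user_msg];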
src/components/settings/WllamaSettings.jsx (new file)

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+import { useEffect, useState } from "react";
+import SettingSection from "./SettingSection";
+import TrueFalseComponent from "./components/TrueFalseComponent";
+import ScrollBarComponent from "./components/ScrollBarComponent";
+import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
+import { loadModel, loadModelSamplingSettings } from "../../utils/workers/worker";
+
+export default function WllamaSettings({ trigger, enabled }) {
+
+    const [ threads, setThreads ] = useState(1);
+    const [ batch_size, setBatchSize ] = useState(256);
+    const [ context_length, setContextLength ] = useState(4096);
+    const [ continue_conv, setContinueConversation ] = useState(false);
+
+    function saveSettings() {
+        updatePlatformSettings({
+            wllama_threads: threads,
+            wllama_batch_size: batch_size,
+            wllama_context_length: context_length,
+            wllama_continue_conv: continue_conv
+        })
+
+        if(enabled) {
+            loadModelSamplingSettings();
+            loadModel('completion');
+        }
+    }
+
+    useEffect(()=>{
+        trigger && saveSettings();
+        // eslint-disable-next-line
+    }, [trigger])
+
+    useEffect(()=>{
+        const {
+            wllama_threads,
+            wllama_batch_size,
+            wllama_context_length,
+            wllama_continue_conv
+        } = getPlatformSettings();
+
+        setThreads(wllama_threads)
+        setBatchSize(wllama_batch_size)
+        setContextLength(wllama_context_length)
+        setContinueConversation(wllama_continue_conv)
+
+    }, [])
+
+    return (
+        <SettingSection title={'Local Model Settings'}>
+            <ScrollBarComponent
+                title={"Set Threads to use"}
+                value={threads} cb={setThreads}
+                description={'Set how many threads to use; the maximum is your CPU core count.'}
+                min={1} max={navigator.hardwareConcurrency}
+            />
+            <TrueFalseComponent
+                title={"Enable Continue Conversation"}
+                description={"Enable to continue the conversation instead of treating every question as a new one. This can make responses extremely slow."}
+                value={continue_conv} cb={setContinueConversation}
+            />
+            <ScrollBarComponent
+                title={"Set Batch Size"}
+                value={batch_size} cb={setBatchSize}
+                description={'Adjust the batch size to balance performance and cost.'}
+                min={1} max={512}
+            />
+            <ScrollBarComponent
+                title={"Set Context Length"}
+                value={context_length} cb={setContextLength}
+                description={'Adjust the max tokens of a conversation; exceeding this size resets the conversation.'}
+                min={1024} max={4096}
+            />
+        </SettingSection>
+    )
+}
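The panel leans on two shared controls, TrueFalseComponent and ScrollBarComponent, passing title, description, value, cb, and (for the scrollbar) min/max. Those components are not part of this commit; a hypothetical minimal ScrollBarComponent matching the props used above might look like:

// Hypothetical sketch only; the real component lives in ./components/
// and is not shown in this diff. It mirrors the props used above.
export default function ScrollBarComponent({ title, description, value, cb, min, max }) {
    return (
        <div>
            <label>{title}: {value}</label>
            <input
                type="range" min={min} max={max} value={value}
                onChange={evt => cb(+evt.target.value)}
            />
            <p>{description}</p>
        </div>
    );
}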

src/components/settings/index.jsx

Lines changed: 5 additions & 0 deletions

@@ -3,6 +3,7 @@ import AwsSettings from "./AwsSettings";
 import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
 import ModelSettings from "./ModelSettings";
 import OpenaiSettings from "./OpenaiSettings";
+import WllamaSettings from "./WllamaSettings";

 export default function Settings() {

@@ -24,6 +25,10 @@ export default function Settings() {
             <ModelSettings
                 trigger={saveSettingTrigger}
             />
+            <WllamaSettings
+                trigger={saveSettingTrigger}
+                enabled={!enabled_platform}
+            />
             <AwsSettings
                 trigger={saveSettingTrigger}
                 enabled={enabled_platform === 'AWS'}

src/utils/general_settings.js

Lines changed: 6 additions & 1 deletion

@@ -4,7 +4,12 @@ const DEFAULT_PLATFORM_SETTINGS = {
     // aws
     aws_model_id: '', aws_region: '',
     // openai
-    openai_model: ''
+    openai_model: '',
+    // wllama
+    wllama_threads: 4,
+    wllama_batch_size: 128,
+    wllama_context_length: 4096,
+    wllama_continue_conv: false
 }

 const MODEL_SETTINGS_KEY = 'general-model-settings'
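With these defaults, a fresh install gets 4 threads, a batch size of 128, and a 4096-token context before the user ever opens the settings panel. The diff does not show how getPlatformSettings applies the defaults; a plausible sketch, assuming settings are persisted in localStorage under a key named in the style of MODEL_SETTINGS_KEY above:

// Hypothetical sketch; the real getPlatformSettings is not in this diff.
// PLATFORM_SETTINGS_KEY is an assumed name following the pattern above.
const PLATFORM_SETTINGS_KEY = 'general-platform-settings'

export function getPlatformSettings() {
    const stored = JSON.parse(localStorage.getItem(PLATFORM_SETTINGS_KEY) || '{}');
    // Stored values win; the new wllama_* keys fall back to the defaults above.
    return { ...DEFAULT_PLATFORM_SETTINGS, ...stored };
}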

src/utils/workers/index.js

Lines changed: 4 additions & 1 deletion

@@ -23,7 +23,10 @@ export function getCompletionFunctions() {
         case 'OpenAI':
             return { completions: OpenaiCompletions, abort: OpenaiAbort, platform: "OpenAI"}
         default:
-            return { completions: WllamaCompletions, abort: WllamaAbort, platform: "Wllama" }
+            return {
+                completions: WllamaCompletions, abort: WllamaAbort,
+                platform: "Wllama", continue_chat: platform_settings.wllama_continue_conv
+            }
     }

}
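Callers receive the flag alongside the completion handles, which is how Conversation.jsx decides whether to resend history. A short usage sketch; the completions call shape is an assumption based on chatCompletions(messages, cb) in worker.js, everything else is from this diff:

// Usage sketch based on the fields returned above.
const chat_functions = getCompletionFunctions();

if (chat_functions.platform === "Wllama" && chat_functions.continue_chat) {
    // "Continue Conversation" is enabled: include prior turns in messages.
}
await chat_functions.completions(messages, (text, is_done) => { /* stream tokens */ });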

src/utils/workers/worker.js

Lines changed: 33 additions & 5 deletions

@@ -6,6 +6,7 @@ import wllamaSingle from '@wllama/wllama/src/single-thread/wllama.wasm?url';
 import wllamaMultiJS from '@wllama/wllama/src/multi-thread/wllama.js?url';
 import wllamaMulti from '@wllama/wllama/src/multi-thread/wllama.wasm?url';
 import wllamaMultiWorker from '@wllama/wllama/src/multi-thread/wllama.worker.mjs?url';
+import { getModelSettings, getPlatformSettings } from "../general_settings";

 const CONFIG_PATHS = {
     'single-thread/wllama.js': wllamaSingleJS,
@@ -15,6 +16,32 @@ const CONFIG_PATHS = {
     'multi-thread/wllama.worker.mjs': wllamaMultiWorker,
 };

+let model_sampling_settings = {}
+
+export function loadModelSamplingSettings() {
+    const {
+        wllama_threads,
+        wllama_batch_size,
+        wllama_context_length
+    } = getPlatformSettings();
+
+    const {
+        max_tokens,
+        top_p,
+        temperature
+    } = getModelSettings();
+
+    model_sampling_settings = {
+        n_threads: wllama_threads,
+        n_batch: wllama_batch_size,
+        n_ctx: wllama_context_length,
+        nPredict: max_tokens,
+        temp: temperature,
+        top_p
+    }
+}
+loadModelSamplingSettings();
+
 const DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

 const engines = {
@@ -74,10 +101,9 @@ export async function loadModel(type = 'completion', cb = null) {
         await downloadModel(type, cb);
     }
     cb && cb('loading')
+    const { n_threads, n_batch, n_ctx } = model_sampling_settings;
     await instance.loadModelFromUrl(model_src, {
-        n_threads: 6,
-        n_ctx: 4096,
-        n_batch: 128,
+        n_threads, n_ctx, n_batch,
     });
 } catch(error) {
     console.error(error)
@@ -104,11 +130,13 @@ export async function formatPrompt(messages) {
 export async function chatCompletions(messages, cb = null) {
     stop_signal = false;
     try {
+        const { nPredict, temp, top_p } = model_sampling_settings;
+
         const prompt = await formatPrompt(messages)
         const result = await engines['completion'].instance.createCompletion(prompt, {
-            nPredict: 256,
+            nPredict,
             sampling: {
-                temp: 0.7
+                temp, top_p
             },
             onNewToken: (token, piece, currentText, optionals) => {
                 cb && cb(currentText, false);