diff --git a/src/components/chat/Conversation.jsx b/src/components/chat/Conversation.jsx
index 8abdabb..83dd984 100644
--- a/src/components/chat/Conversation.jsx
+++ b/src/components/chat/Conversation.jsx
@@ -1,6 +1,6 @@
 import { useEffect, useRef, useState } from "react";
 import ConversationBubble from "./ConversationBubble";
-import { CheckCircle, FileImageFill, FileTextFill, Paperclip, Send, StopCircleFill, XCircle } from 'react-bootstrap-icons';
+import { CheckCircle, FileImageFill, FileTextFill, Paperclip, PencilFill, Send, StopCircleFill, XCircle } from 'react-bootstrap-icons';
 import useIDB from "../../utils/idb";
 import { isModelLoaded, loadModel } from '../../utils/workers/worker'
 import { getCompletionFunctions } from "../../utils/workers";
@@ -85,7 +85,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateCl
             });
             setPendingMessage('')
         }
-        messages = [user_msg];
+
+        messages =
+            chat_functions.current.continue_chat ?
+            [...conversation, user_msg] : [user_msg];
     } else {
         let user_message = user_msg;
         if(upload_file) {
@@ -167,7 +170,10 @@ export default function Conversation({ uid, title, updateTitle, client, updateCl
                     …
                     <input … onBlur={()=>{setEditedTitle(title); toggleEditTitle(false)}} /> :
-                    <div … onClick={()=>toggleEditTitle(true)}>{ title }</div>
+                    <div className="display-title" onClick={()=>toggleEditTitle(true)}>
+                        <div>{ title }</div>
+                        <PencilFill className="edit-icon" />
+                    </div>
                 }
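
Note: with `continue_chat` set, the branch above sends the stored history plus the new message instead of the new message alone. A minimal sketch of the selection, assuming `conversation` holds the prior { role, content } entries and `chat_functions.current` is the object returned by `getCompletionFunctions()` (the callback name is illustrative):

    // Payload selection for the completion call (sketch only).
    const payload = chat_functions.current.continue_chat
        ? [...conversation, user_msg]   // resend earlier turns so the model keeps context
        : [user_msg];                   // previous behaviour: latest user message only
    chat_functions.current.completions(payload, onUpdate);

Only the Wllama branch of `getCompletionFunctions()` sets `continue_chat` (see `src/utils/workers/index.js` below), so the AWS and OpenAI paths keep sending a single message.
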
diff --git a/src/components/settings/OpenaiSettings.jsx b/src/components/settings/OpenaiSettings.jsx
index 4611506..0c6f2d8 100644
--- a/src/components/settings/OpenaiSettings.jsx
+++ b/src/components/settings/OpenaiSettings.jsx
@@ -36,7 +36,7 @@ export default function OpenaiSettings({ trigger, enabled, updateEnabled }) {
     }, [])

     return (
-        …
+        …
         …
diff --git a/src/components/settings/WllamaSettings.jsx b/src/components/settings/WllamaSettings.jsx
new file mode 100644
--- /dev/null
+++ b/src/components/settings/WllamaSettings.jsx
@@ -0,0 +1,… @@
+import { useEffect, useState } from "react";
+import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
+…
+
+export default function WllamaSettings({ trigger }) {
+
+    const [ threads, setThreads ] = useState(…);
+    const [ batchSize, setBatchSize ] = useState(…);
+    const [ contextLength, setContextLength ] = useState(…);
+    const [ continueConversation, setContinueConversation ] = useState(…);
+
+    function saveSettings() {
+        updatePlatformSettings({
+            wllama_threads: threads,
+            wllama_batch_size: batchSize,
+            wllama_context_length: contextLength,
+            wllama_continue_conv: continueConversation
+        })
+    }
+
+    useEffect(()=>{
+        trigger && saveSettings();
+        // eslint-disable-next-line
+    }, [trigger])
+
+    useEffect(()=>{
+        const {
+            wllama_threads,
+            wllama_batch_size,
+            wllama_context_length,
+            wllama_continue_conv
+        } = getPlatformSettings();
+
+        setThreads(wllama_threads)
+        setBatchSize(wllama_batch_size)
+        setContextLength(wllama_context_length)
+        setContinueConversation(wllama_continue_conv)
+
+    }, [])
+
+    return (
+        <SettingSection title={…}>
+            <ScrollBarComponent cb={setThreads} value={threads} … />
+            <ScrollBarComponent cb={setBatchSize} value={batchSize} … />
+            <ScrollBarComponent cb={setContextLength} value={contextLength} … />
+            …
+        </SettingSection>
+    )
+}
\ No newline at end of file
diff --git a/src/components/settings/components/ScrollBarComponent.jsx b/src/components/settings/components/ScrollBarComponent.jsx
index e9018cf..9de6c8d 100644
--- a/src/components/settings/components/ScrollBarComponent.jsx
+++ b/src/components/settings/components/ScrollBarComponent.jsx
@@ -16,6 +16,7 @@ export default function ScrollBarComponent({ cb, value, disabled, title, descrip
             setScrollValue(value);
         } else {
             if(!isNaN(+value)) {
+                if(value > max) value = max;
                 setScrollValue(times_10 ? value * 10 : value);
             }
             setTextValue(value);
@@ -23,7 +24,7 @@ export default function ScrollBarComponent({ cb, value, disabled, title, descrip
     }

     useEffect(()=>{
-        !isNaN(+textValue) && checkValue() && cb(+textValue);
+        textValue !== value && !isNaN(+textValue) && checkValue() && cb(+textValue);
         // eslint-disable-next-line
     }, [textValue])
diff --git a/src/components/settings/index.jsx b/src/components/settings/index.jsx
index 7cd84b4..9bd1bb1 100644
--- a/src/components/settings/index.jsx
+++ b/src/components/settings/index.jsx
@@ -3,6 +3,7 @@
 import AwsSettings from "./AwsSettings";
 import { getPlatformSettings, updatePlatformSettings } from "../../utils/general_settings";
 import ModelSettings from "./ModelSettings";
 import OpenaiSettings from "./OpenaiSettings";
+import WllamaSettings from "./WllamaSettings";

 export default function Settings() {
@@ -24,6 +25,10 @@ export default function Settings() {
             …
+            <WllamaSettings
+                trigger={…}
+                …
+            />
             …
diff --git a/src/….css b/src/….css
--- a/src/….css
+++ b/src/….css
@@ … @@
+.chat > .conversation-main > .title-bar .display-title {
+    width: fit-content;
+    height: 100%;
+    display: flex;
+    align-items: center;
+}
+
+.chat > .conversation-main > .title-bar .display-title > .edit-icon {
+    margin-left: 10px;
+}
+
 .chat > .conversation-main > .title-bar > form {
     display: flex;
 }
diff --git a/src/utils/general_settings.js b/src/utils/general_settings.js
index b8fc1c5..225bbe2 100644
--- a/src/utils/general_settings.js
+++ b/src/utils/general_settings.js
@@ -4,7 +4,12 @@ const DEFAULT_PLATFORM_SETTINGS = {
     // aws
     aws_model_id: '', aws_region: '',
     // openai
-    openai_model: ''
+    openai_model: '',
+    // wllama
+    wllama_threads: 4,
+    wllama_batch_size: 128,
+    wllama_context_length: 4096,
+    wllama_continue_conv: false
 }

 const MODEL_SETTINGS_KEY = 'general-model-settings'
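
Note: the new `wllama_*` defaults keep previously saved settings objects usable, provided `getPlatformSettings()` falls back to `DEFAULT_PLATFORM_SETTINGS` for keys missing from storage. Its implementation is outside this diff; a sketch of a read path of that shape, with the storage key name invented for illustration:

    // Hypothetical read path: stored values win, defaults fill the gaps.
    function getPlatformSettings() {
        const stored = JSON.parse(localStorage.getItem('general-platform-settings') || '{}');
        return { ...DEFAULT_PLATFORM_SETTINGS, ...stored };
    }

On the input side, `ScrollBarComponent` now clamps typed values above `max` and only calls `cb` when `textValue` differs from the incoming `value` prop, so loading saved settings no longer re-triggers the save callback.
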
diff --git a/src/utils/workers/index.js b/src/utils/workers/index.js
index df30481..1b1d61e 100644
--- a/src/utils/workers/index.js
+++ b/src/utils/workers/index.js
@@ -23,7 +23,10 @@ export function getCompletionFunctions() {
         case 'OpenAI':
             return { completions: OpenaiCompletions, abort: OpenaiAbort, platform: "OpenAI"}
         default:
-            return { completions: WllamaCompletions, abort: WllamaAbort, platform: "Wllama" }
+            return {
+                completions: WllamaCompletions, abort: WllamaAbort,
+                platform: "Wllama", continue_chat: platform_settings.wllama_continue_conv
+            }
     }
 }
\ No newline at end of file
diff --git a/src/utils/workers/worker.js b/src/utils/workers/worker.js
index 4c6325b..c671616 100644
--- a/src/utils/workers/worker.js
+++ b/src/utils/workers/worker.js
@@ -6,6 +6,7 @@ import wllamaSingle from '@wllama/wllama/src/single-thread/wllama.wasm?url';
 import wllamaMultiJS from '@wllama/wllama/src/multi-thread/wllama.js?url';
 import wllamaMulti from '@wllama/wllama/src/multi-thread/wllama.wasm?url';
 import wllamaMultiWorker from '@wllama/wllama/src/multi-thread/wllama.worker.mjs?url';
+import { getModelSettings, getPlatformSettings } from "../general_settings";

 const CONFIG_PATHS = {
     'single-thread/wllama.js': wllamaSingleJS,
@@ -15,6 +16,32 @@
     'multi-thread/wllama.worker.mjs': wllamaMultiWorker,
 };

+let model_sampling_settings = {}
+
+export function loadModelSamplingSettings() {
+    const {
+        wllama_threads,
+        wllama_batch_size,
+        wllama_context_length
+    } = getPlatformSettings();
+
+    const {
+        max_tokens,
+        top_p,
+        temperature
+    } = getModelSettings();
+
+    model_sampling_settings = {
+        n_threads: wllama_threads,
+        n_batch: wllama_batch_size,
+        n_ctx: wllama_context_length,
+        nPredict: max_tokens,
+        temp: temperature,
+        top_p
+    }
+}
+loadModelSamplingSettings();
+
 const DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";

 const engines = {
@@ -74,10 +101,9 @@ export async function loadModel(type = 'completion', cb = null) {
             await downloadModel(type, cb);
         }
         cb && cb('loading')
+        const {n_threads, n_batch, n_ctx} = model_sampling_settings;
         await instance.loadModelFromUrl(model_src, {
-            n_threads: 6,
-            n_ctx: 4096,
-            n_batch: 128,
+            n_threads, n_ctx, n_batch,
         });
     } catch(error) {
         console.error(error)
@@ -104,11 +130,13 @@ export async function formatPrompt(messages) {
 export async function chatCompletions(messages, cb = null) {
     stop_signal = false;
     try {
+        const { nPredict, temp, top_p } = model_sampling_settings;
+
         const prompt = await formatPrompt(messages)
         const result = await engines['completion'].instance.createCompletion(prompt, {
-            nPredict: 256,
+            nPredict,
             sampling: {
-                temp: 0.7
+                temp, top_p
             },
             onNewToken: (token, piece, currentText, optionals) => {
                 cb && cb(currentText, false);
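
Note: the Wllama runtime options are no longer hard-coded (previously `n_threads: 6`, `n_ctx: 4096`, `n_batch: 128`, `nPredict: 256`, `temp: 0.7`). `n_threads`, `n_ctx`, and `n_batch` are captured when `loadModelFromUrl` runs, so they take effect on the next model load; `nPredict`, `temp`, and `top_p` are read on every completion. `loadModelSamplingSettings()` is exported so callers can refresh the cached values after a save. A sketch of applying changed settings, assuming the caller may reload the model (the function name is illustrative):

    import { loadModelSamplingSettings, loadModel } from '../utils/workers/worker';

    async function applyWllamaSettings() {
        loadModelSamplingSettings();    // re-read wllama_* and sampling values into module state
        await loadModel('completion');  // reload so the load-time options (threads/ctx/batch) apply
    }
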