feat: adds SmolLM #78

Merged · 3 commits · Nov 3, 2024
4 changes: 2 additions & 2 deletions android/app/build.gradle
@@ -81,8 +81,8 @@ android {
applicationId "com.pocketpalai"
minSdkVersion rootProject.ext.minSdkVersion
targetSdkVersion rootProject.ext.targetSdkVersion
versionCode 12
versionName "1.4.5"
versionCode 13
versionName "1.4.6"
ndk {
abiFilters "arm64-v8a", "x86_64"
}
8 changes: 4 additions & 4 deletions ios/PocketPal.xcodeproj/project.pbxproj
@@ -479,7 +479,7 @@
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES;
CURRENT_PROJECT_VERSION = 12;
CURRENT_PROJECT_VERSION = 13;
DEVELOPMENT_TEAM = MYXGXY23Y6;
ENABLE_BITCODE = NO;
INFOPLIST_FILE = PocketPal/Info.plist;
@@ -488,7 +488,7 @@
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.4.5;
MARKETING_VERSION = 1.4.6;
OTHER_CPLUSPLUSFLAGS = (
"$(OTHER_CFLAGS)",
"-DFOLLY_NO_CONFIG",
@@ -521,15 +521,15 @@
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES;
CURRENT_PROJECT_VERSION = 12;
CURRENT_PROJECT_VERSION = 13;
DEVELOPMENT_TEAM = MYXGXY23Y6;
INFOPLIST_FILE = PocketPal/Info.plist;
IPHONEOS_DEPLOYMENT_TARGET = 14.0;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.4.5;
MARKETING_VERSION = 1.4.6;
OTHER_CPLUSPLUSFLAGS = (
"$(OTHER_CFLAGS)",
"-DFOLLY_NO_CONFIG",
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "PocketPal",
"version": "1.4.5",
"version": "1.4.6",
"private": true,
"scripts": {
"prepare": "husky",
12 changes: 4 additions & 8 deletions src/hooks/useChatSession.ts
@@ -9,11 +9,7 @@ import {L10nContext} from '../utils';
import {chatSessionStore, modelStore} from '../store';

import {MessageType, User} from '../utils/types';
import {
applyChatTemplate,
chatTemplates,
convertToChatMessages,
} from '../utils/chat';
import {applyChatTemplate, convertToChatMessages} from '../utils/chat';

export const useChatSession = (
context: LlamaContext | undefined,
@@ -113,10 +109,10 @@ export const useChatSession = (
]),
];

const prompt = applyChatTemplate(
modelStore.activeModel?.chatTemplate || chatTemplates.default,
const prompt = await applyChatTemplate(
chatMessages,
1000, // This is not used.
modelStore.activeModel ?? null,
context,
);

const completionParams = toJS(modelStore.activeModel?.completionSettings);
68 changes: 67 additions & 1 deletion src/store/defaultModels.ts
@@ -2,9 +2,10 @@ import {Model} from '../utils/types';
import {chatTemplates, defaultCompletionParams} from '../utils/chat';
import {Platform} from 'react-native';

export const MODEL_LIST_VERSION = 6;
export const MODEL_LIST_VERSION = 8;

export const defaultModels: Model[] = [
// -------- Gemma --------
{
id: 'google/gemma-2-2b-it-GGUF',
name: 'gemma-2-2b-it-GGUF (Q6_K)',
@@ -97,6 +98,7 @@ export const defaultModels: Model[] = [
stop: ['<end_of_turn>'],
},
},
// -------- Danube --------
{
id: 'h2o-danube3-4b-chat-Q4_K_M.gguf',
name: 'H2O.ai Danube 3 (Q4_K_M)',
@@ -193,6 +195,7 @@
penalty_repeat: 1.075,
},
},
// -------- Phi --------
{
id: 'Phi-3.5-mini-instruct.Q4_K_M.gguf',
name: 'Phi-3.5 mini 4k instruct (Q4_K_M)',
@@ -249,6 +252,7 @@
stop: ['<|end|>'],
},
},
// -------- Qwen --------
{
id: 'qwen2-1_5b-instruct-q8_0.gguf',
name: 'Qwen2-1.5B-Instruct (Q8_0)',
@@ -333,6 +337,7 @@
stop: ['<|im_end|>'],
},
},
// -------- Llama --------
...(Platform.OS === 'android'
? [
{
@@ -537,4 +542,65 @@
stop: ['<|eot_id|>'],
},
},
// -------- SmolLM --------
{
id: 'default-bartowski/SmolLM2-1.7B-Instruct-Q8_0.gguf',
name: 'SmolLM2-1.7B-Instruct (Q8_0)',
type: 'SmolLM',
size: '1.82',
params: '1.7',
isDownloaded: false,
downloadUrl:
'https://huggingface.co/bartowski/SmolLM2-1.7B-Instruct-GGUF/resolve/main/SmolLM2-1.7B-Instruct-Q8_0.gguf?download=true',
hfUrl: 'https://huggingface.co/bartowski/SmolLM2-1.7B-Instruct-GGUF',
progress: 0,
filename: 'default-SmolLM2-1.7B-Instruct-Q8_0.gguf',
isLocal: false,
defaultChatTemplate: chatTemplates.smolLM,
chatTemplate: chatTemplates.smolLM,
defaultCompletionSettings: {
...defaultCompletionParams,
n_predict: 500,
temperature: 0.7,
stop: ['<|endoftext|>', '<|im_end|>'],
},
completionSettings: {
...defaultCompletionParams,
n_predict: 500,
temperature: 0.7,
stop: ['<|endoftext|>', '<|im_end|>'],
},
},
...(Platform.OS === 'android'
? [
{
id: 'default-bartowski/SmolLM2-1.7B-Instruct-Q4_0_4_4.gguf',
name: 'SmolLM2-1.7B-Instruct (Q4_0_4_4)',
type: 'SmolLM',
size: '0.99',
params: '1.7',
isDownloaded: false,
downloadUrl:
'https://huggingface.co/bartowski/SmolLM2-1.7B-Instruct-GGUF/resolve/main/SmolLM2-1.7B-Instruct-Q4_0_4_4.gguf?download=true',
hfUrl: 'https://huggingface.co/bartowski/SmolLM2-1.7B-Instruct-GGUF',
progress: 0,
filename: 'default-SmolLM2-1.7B-Instruct-Q4_0_4_4.gguf',
isLocal: false,
defaultChatTemplate: chatTemplates.smolLM,
chatTemplate: chatTemplates.smolLM,
defaultCompletionSettings: {
...defaultCompletionParams,
n_predict: 1000,
temperature: 0.7,
stop: ['<|endoftext|>', '<|im_end|>'],
},
completionSettings: {
...defaultCompletionParams,
n_predict: 1000,
temperature: 0.7,
stop: ['<|endoftext|>', '<|im_end|>'],
},
},
]
: []),
];
4 changes: 3 additions & 1 deletion src/utils/__tests__/chat.test.ts
@@ -1,6 +1,7 @@
import {Templates} from 'chat-formatter';
import {applyChatTemplate} from '../chat';
import {ChatMessage, ChatTemplateConfig} from '../types';
import {createModel} from '../../../jest/fixtures/models';

const conversationWSystem: ChatMessage[] = [
{role: 'system', content: 'System prompt. '},
@@ -18,7 +19,8 @@ describe('Test Danube2 Chat Templates', () => {
addGenerationPrompt: true,
name: 'danube2',
};
const result = applyChatTemplate(chatTemplate, conversationWSystem, -1);
const model = createModel({chatTemplate: chatTemplate});
const result = await applyChatTemplate(conversationWSystem, model, null);
expect(result).toBe(
'System prompt. </s><|prompt|>Hi there!</s><|answer|>Nice to meet you!</s><|prompt|>Can I ask a question?</s><|answer|>',
);
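A possible follow-up test for the new fallback path (a sketch only, not part of this PR): it reuses `conversationWSystem` from the top of this file and assumes the built-in default template yields a non-empty prompt when neither a model nor a llama.rn context is supplied.

```ts
it('falls back to chatTemplates.default when model and context are null', async () => {
  // Neither a custom model template nor a loaded LlamaContext is available,
  // so applyChatTemplate should format the conversation with the default template.
  const result = await applyChatTemplate(conversationWSystem, null, null);
  expect(typeof result).toBe('string');
  expect(result.trim().length).toBeGreaterThan(0);
});
```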
60 changes: 45 additions & 15 deletions src/utils/chat.ts
@@ -1,7 +1,6 @@
import {applyTemplate, Templates} from 'chat-formatter';
import {ChatMessage, ChatTemplateConfig, MessageType} from './types';
//import {assistant} from '../store/ChatSessionStore';
import {CompletionParams} from '@pocketpalai/llama.rn';
import {ChatMessage, ChatTemplateConfig, MessageType, Model} from './types';
import {CompletionParams, LlamaContext} from '@pocketpalai/llama.rn';

export const userId = 'y9d7f8pgn';
export const assistantId = 'h3o3lc5xj';
@@ -22,21 +21,42 @@ export function convertToChatMessages(
.reverse();
}

export function applyChatTemplate(
template: ChatTemplateConfig,
chat: ChatMessage[],
// eslint-disable-next-line @typescript-eslint/no-unused-vars
length: number, //TODO: inforce length of formattedChat to fit the context.
): string {
const formattedChat: string = applyTemplate(chat, {
customTemplate: template,
addGenerationPrompt: template.addGenerationPrompt,
}) as string;
export async function applyChatTemplate(
messages: ChatMessage[],
model: Model | null,
context: LlamaContext | null,
): Promise<string> {
const modelChatTemplate = model?.chatTemplate;
const contextChatTemplate = (context?.model as any)?.metadata?.[
'tokenizer.chat_template'
];

return formattedChat;
let formattedChat: string | undefined;

try {
if (modelChatTemplate?.chatTemplate) {
formattedChat = applyTemplate(messages, {
customTemplate: modelChatTemplate,
addGenerationPrompt: modelChatTemplate.addGenerationPrompt,
}) as string;
} else if (contextChatTemplate) {
formattedChat = await context?.getFormattedChat(messages);
}

if (!formattedChat) {
formattedChat = applyTemplate(messages, {
customTemplate: chatTemplates.default,
addGenerationPrompt: chatTemplates.default.addGenerationPrompt,
}) as string;
}
} catch (error) {
console.error('Error applying chat template:', error); // TODO: handle error
}

return formattedChat || ' ';
}

export const chatTemplates = {
export const chatTemplates: Record<string, ChatTemplateConfig> = {
danube3: {
...Templates.templates.danube2,
name: 'danube3',
@@ -112,6 +132,16 @@ export const chatTemplates = {
systemPrompt:
'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.',
},
smolLM: {
name: 'smolLM',
addGenerationPrompt: true,
systemPrompt: 'You are a helpful assistant.',
bosToken: '<|im_start|>',
eosToken: '<|im_end|>',
addBosToken: false,
addEosToken: false,
chatTemplate: '',
},
};

export const defaultCompletionParams: CompletionParams = {
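For readers of this diff: the rewritten `applyChatTemplate` resolves the template in three steps: (1) the model's custom `chatTemplate` string, if non-empty; (2) the `tokenizer.chat_template` embedded in the loaded GGUF, via `context.getFormattedChat()`; (3) the built-in `chatTemplates.default`. Since the new `smolLM` entry ships with an empty `chatTemplate` string, SmolLM prompts will presumably come from the GGUF metadata once a context is loaded. Below is a minimal call-site sketch mirroring the updated `useChatSession.ts`; the `buildPrompt` helper and its parameters are illustrative, not part of this PR.

```ts
import {LlamaContext} from '@pocketpalai/llama.rn';
import {applyChatTemplate, convertToChatMessages} from '../utils/chat';
import {modelStore} from '../store';

// Hypothetical helper (not in this PR): builds a prompt the same way the
// updated useChatSession.ts does, without assuming the UI message type.
async function buildPrompt(
  uiMessages: Parameters<typeof convertToChatMessages>[0],
  context: LlamaContext | undefined,
): Promise<string> {
  const chatMessages = convertToChatMessages(uiMessages);
  return applyChatTemplate(
    chatMessages,                    // messages in chat order
    modelStore.activeModel ?? null,  // (1) custom template, if its chatTemplate string is non-empty
    context ?? null,                 // (2) otherwise the GGUF's tokenizer.chat_template
  );                                 // (3) otherwise chatTemplates.default
}
```

Falling back to the template embedded in the GGUF keeps per-model prompt formats out of the app while still letting a user-defined template take precedence.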