Skip to content

Commit

Permalink
fix: add empty systemPrompt for hf and local models (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ghorbani authored Nov 29, 2024
1 parent 8d05741 commit b2d6386
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 3 deletions.
14 changes: 12 additions & 2 deletions __mocks__/stores/modelStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@ export const mockModelStore = {
useAutoRelease: true,
useMetal: false,
n_gpu_layers: 50,
activeModel: null,
activeModelId: undefined as string | undefined,
setNContext: jest.fn(),
updateUseAutoRelease: jest.fn(),
updateUseMetal: jest.fn(),
setNGPULayers: jest.fn(),
refreshDownloadStatuses: jest.fn(),
addLocalModel: jest.fn(),
resetModels: jest.fn(),
//initContext: jest.fn().mockResolvedValue(undefined),
initContext: jest.fn().mockResolvedValue(Promise.resolve()),
checkSpaceAndDownload: jest.fn(),
getDownloadProgress: jest.fn(),
manualReleaseContext: jest.fn(),
// Test double for modelStore.setActiveModel: records the chosen model id on
// the mock so the 'activeModel' accessor (installed further down with
// Object.defineProperty) can resolve it against mockModelStore.models.
setActiveModel(modelId: string) {
this.activeModelId = modelId;
},
};
Object.defineProperty(mockModelStore, 'lastUsedModel', {
get: jest.fn(() => undefined),
Expand All @@ -29,3 +31,11 @@ Object.defineProperty(mockModelStore, 'isDownloading', {
get: jest.fn(() => () => false),
configurable: true,
});
// Derived 'activeModel' accessor for the mock store: each read resolves the
// model whose id matches the currently selected activeModelId, yielding
// undefined when no model has been activated. Wrapped in jest.fn so tests
// can still inspect getter invocations; configurable so suites may override.
const lookupActiveModel = () => {
  const {models, activeModelId} = mockModelStore;
  return models.find(m => m.id === activeModelId);
};
Object.defineProperty(mockModelStore, 'activeModel', {
  get: jest.fn(lookupActiveModel),
  configurable: true,
});
70 changes: 70 additions & 0 deletions src/hooks/__tests__/useChatSession.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ import {LlamaContext} from '@pocketpalai/llama.rn';
import {renderHook, act} from '@testing-library/react-native';

import {textMessage} from '../../../jest/fixtures';
import {mockBasicModel, modelsList} from '../../../jest/fixtures/models';

import {useChatSession} from '../useChatSession';

import {chatSessionStore, modelStore} from '../../store';

import {l10n} from '../../utils/l10n';
import {assistant} from '../../utils/chat';
import {ChatMessage} from '../../utils/types';

const mockL10n = l10n.en;

Expand All @@ -25,8 +27,17 @@ beforeEach(() => {
model: {},
});
});
modelStore.models = modelsList;

const applyChatTemplateSpy = jest
.spyOn(require('../../utils/chat'), 'applyChatTemplate')
.mockImplementation(async () => 'mocked prompt');

describe('useChatSession', () => {
beforeEach(() => {
applyChatTemplateSpy.mockClear();
});

it('should send a message and update the chat session', async () => {
const {result} = renderHook(() =>
useChatSession(
Expand Down Expand Up @@ -234,4 +245,63 @@ describe('useChatSession', () => {
});
expect(result.current.inferencing).toBe(false);
});

// Parametrized check of how useChatSession treats the active model's
// chatTemplate.systemPrompt: falsy or whitespace-only prompts must NOT
// produce a system message, while non-empty prompts (even with surrounding
// whitespace) must be forwarded verbatim to applyChatTemplate.
test.each([
{systemPrompt: undefined, shouldInclude: false, description: 'undefined'},
{systemPrompt: '', shouldInclude: false, description: 'empty string'},
{systemPrompt: ' ', shouldInclude: false, description: 'whitespace-only'},
{
systemPrompt: 'You are a helpful assistant',
shouldInclude: true,
description: 'valid prompt',
},
{
systemPrompt: ' Trimmed prompt ',
shouldInclude: true,
description: 'prompt with whitespace',
},
])(
'should handle system prompt for $description',
async ({systemPrompt, shouldInclude}) => {
// Build a model whose template carries the case's systemPrompt and make
// it the store's active model so the hook picks it up.
const testModel = {
...mockBasicModel,
id: 'test-model',
chatTemplate: {...mockBasicModel.chatTemplate, systemPrompt},
};

modelStore.models = [testModel];
modelStore.setActiveModel(testModel.id);

const {result} = renderHook(() =>
useChatSession(
modelStore.context,
{current: null},
[],
textMessage.author,
mockAssistant,
),
);

// Sending a message triggers prompt assembly, which calls the spied
// applyChatTemplate with the assembled chat messages as first argument.
await act(async () => {
await result.current.handleSendPress(textMessage);
});

if (shouldInclude && systemPrompt) {
// Expect a system-role message whose content is the ORIGINAL (untrimmed)
// prompt — trimming is only used to decide inclusion, not to rewrite it.
expect(applyChatTemplateSpy).toHaveBeenCalledWith(
expect.arrayContaining([
expect.objectContaining({
role: 'system',
content: systemPrompt,
}),
]),
expect.any(Object),
expect.any(Object),
);
} else {
// Inspect the first recorded call directly: no message in the array may
// carry the 'system' role when the prompt is absent/blank.
const call = applyChatTemplateSpy.mock.calls[0];
const messages = call[0] as ChatMessage[];
expect(messages.some(msg => msg.role === 'system')).toBe(false);
}
},
);
});
2 changes: 1 addition & 1 deletion src/hooks/useChatSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export const useChatSession = (
currentMessageInfo.current = {createdAt, id};

const chatMessages = [
...(modelStore.activeModel?.chatTemplate?.systemPrompt
...(modelStore.activeModel?.chatTemplate?.systemPrompt?.trim()
? [
{
role: 'system' as 'system',
Expand Down
2 changes: 2 additions & 0 deletions src/utils/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ export const chatTemplates: Record<string, ChatTemplateConfig> = {
bosToken: '',
eosToken: '',
chatTemplate: '',
systemPrompt: '',
},
danube3: {
...Templates.templates.danube2,
Expand Down Expand Up @@ -196,6 +197,7 @@ export function getHFDefaultSettings(hfModel: HuggingFaceModel): {
//chatTemplate: hfModel.specs?.gguf?.chat_template ?? '',
chatTemplate: '', // At the moment chatTemplate needs to be nunjucks, not jinja2. So by using empty string we force the use of gguf's chat template.
addGenerationPrompt: true,
systemPrompt: '',
name: 'custom',
};

Expand Down

0 comments on commit b2d6386

Please sign in to comment.