From b2d6386cc2c438a517e0647fb70a5407c8a4be21 Mon Sep 17 00:00:00 2001 From: Asghar Ghorbani Date: Fri, 29 Nov 2024 18:57:52 +0100 Subject: [PATCH] fix: add empty systemPrompt for hf and local models (#111) --- __mocks__/stores/modelStore.ts | 14 ++++- src/hooks/__tests__/useChatSession.test.ts | 70 ++++++++++++++++++++++ src/hooks/useChatSession.ts | 2 +- src/utils/chat.ts | 2 + 4 files changed, 85 insertions(+), 3 deletions(-) diff --git a/__mocks__/stores/modelStore.ts b/__mocks__/stores/modelStore.ts index 0de626e..aea647a 100644 --- a/__mocks__/stores/modelStore.ts +++ b/__mocks__/stores/modelStore.ts @@ -7,7 +7,7 @@ export const mockModelStore = { useAutoRelease: true, useMetal: false, n_gpu_layers: 50, - activeModel: null, + activeModelId: undefined as string | undefined, setNContext: jest.fn(), updateUseAutoRelease: jest.fn(), updateUseMetal: jest.fn(), @@ -15,11 +15,13 @@ export const mockModelStore = { refreshDownloadStatuses: jest.fn(), addLocalModel: jest.fn(), resetModels: jest.fn(), - //initContext: jest.fn().mockResolvedValue(undefined), initContext: jest.fn().mockResolvedValue(Promise.resolve()), checkSpaceAndDownload: jest.fn(), getDownloadProgress: jest.fn(), manualReleaseContext: jest.fn(), + setActiveModel(modelId: string) { + this.activeModelId = modelId; + }, }; Object.defineProperty(mockModelStore, 'lastUsedModel', { get: jest.fn(() => undefined), @@ -29,3 +31,11 @@ Object.defineProperty(mockModelStore, 'isDownloading', { get: jest.fn(() => () => false), configurable: true, }); +Object.defineProperty(mockModelStore, 'activeModel', { + get: jest.fn(() => + mockModelStore.models.find( + model => model.id === mockModelStore.activeModelId, + ), + ), + configurable: true, +}); diff --git a/src/hooks/__tests__/useChatSession.test.ts b/src/hooks/__tests__/useChatSession.test.ts index 45b6f87..5fa6efd 100644 --- a/src/hooks/__tests__/useChatSession.test.ts +++ b/src/hooks/__tests__/useChatSession.test.ts @@ -2,6 +2,7 @@ import {LlamaContext} from '@pocketpalai/llama.rn'; import {renderHook, act} from '@testing-library/react-native'; import {textMessage} from '../../../jest/fixtures'; +import {mockBasicModel, modelsList} from '../../../jest/fixtures/models'; import {useChatSession} from '../useChatSession'; @@ -9,6 +10,7 @@ import {chatSessionStore, modelStore} from '../../store'; import {l10n} from '../../utils/l10n'; import {assistant} from '../../utils/chat'; +import {ChatMessage} from '../../utils/types'; const mockL10n = l10n.en; @@ -25,8 +27,17 @@ beforeEach(() => { model: {}, }); }); +modelStore.models = modelsList; + +const applyChatTemplateSpy = jest + .spyOn(require('../../utils/chat'), 'applyChatTemplate') + .mockImplementation(async () => 'mocked prompt'); describe('useChatSession', () => { + beforeEach(() => { + applyChatTemplateSpy.mockClear(); + }); + it('should send a message and update the chat session', async () => { const {result} = renderHook(() => useChatSession( @@ -234,4 +245,63 @@ describe('useChatSession', () => { }); expect(result.current.inferencing).toBe(false); }); + + test.each([ + {systemPrompt: undefined, shouldInclude: false, description: 'undefined'}, + {systemPrompt: '', shouldInclude: false, description: 'empty string'}, + {systemPrompt: ' ', shouldInclude: false, description: 'whitespace-only'}, + { + systemPrompt: 'You are a helpful assistant', + shouldInclude: true, + description: 'valid prompt', + }, + { + systemPrompt: ' Trimmed prompt ', + shouldInclude: true, + description: 'prompt with whitespace', + }, + ])( + 'should handle system prompt for $description', + async ({systemPrompt, shouldInclude}) => { + const testModel = { + ...mockBasicModel, + id: 'test-model', + chatTemplate: {...mockBasicModel.chatTemplate, systemPrompt}, + }; + + modelStore.models = [testModel]; + modelStore.setActiveModel(testModel.id); + + const {result} = renderHook(() => + useChatSession( + modelStore.context, + {current: null}, + [], + textMessage.author, + mockAssistant, + ), + ); + + await act(async () => { + await result.current.handleSendPress(textMessage); + }); + + if (shouldInclude && systemPrompt) { + expect(applyChatTemplateSpy).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ + role: 'system', + content: systemPrompt, + }), + ]), + expect.any(Object), + expect.any(Object), + ); + } else { + const call = applyChatTemplateSpy.mock.calls[0]; + const messages = call[0] as ChatMessage[]; + expect(messages.some(msg => msg.role === 'system')).toBe(false); + } + }, + ); }); diff --git a/src/hooks/useChatSession.ts b/src/hooks/useChatSession.ts index b405740..181bfe2 100644 --- a/src/hooks/useChatSession.ts +++ b/src/hooks/useChatSession.ts @@ -97,7 +97,7 @@ export const useChatSession = ( currentMessageInfo.current = {createdAt, id}; const chatMessages = [ - ...(modelStore.activeModel?.chatTemplate?.systemPrompt + ...(modelStore.activeModel?.chatTemplate?.systemPrompt?.trim() ? [ { role: 'system' as 'system', diff --git a/src/utils/chat.ts b/src/utils/chat.ts index c3f26e1..c807fb8 100644 --- a/src/utils/chat.ts +++ b/src/utils/chat.ts @@ -86,6 +86,7 @@ export const chatTemplates: Record = { bosToken: '', eosToken: '', chatTemplate: '', + systemPrompt: '', }, danube3: { ...Templates.templates.danube2, @@ -196,6 +197,7 @@ export function getHFDefaultSettings(hfModel: HuggingFaceModel): { //chatTemplate: hfModel.specs?.gguf?.chat_template ?? '', chatTemplate: '', // At the moment chatTemplate needs to be nunjucks, not jinja2. So by using empty string we force the use of gguf's chat template. addGenerationPrompt: true, + systemPrompt: '', name: 'custom', };