From 600918b54f3ad6f79ca5fc30653d9d7a2acaa9fc Mon Sep 17 00:00:00 2001 From: Asghar Ghorbani Date: Fri, 29 Nov 2024 23:11:08 +0100 Subject: [PATCH] fix: add stop tokens that are in the chat template (#113) --- jest/fixtures/models.ts | 1 - .../CompletionSettings/CompletionSettings.tsx | 1 - .../__tests__/CompletionSettings.test.tsx | 4 - src/store/ModelStore.ts | 105 ++++++++++++------ src/utils/chat.ts | 13 ++- 5 files changed, 86 insertions(+), 38 deletions(-) diff --git a/jest/fixtures/models.ts b/jest/fixtures/models.ts index 0bfa39d..c710db6 100644 --- a/jest/fixtures/models.ts +++ b/jest/fixtures/models.ts @@ -16,7 +16,6 @@ export const mockDefaultCompletionParams: CompletionParams = { temperature: 0.7, top_k: 40, top_p: 0.95, - tfs_z: 1.0, min_p: 0.05, xtc_threshold: 0.1, xtc_probability: 0.01, diff --git a/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx b/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx index d37f011..6f3b95d 100644 --- a/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx +++ b/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx @@ -96,7 +96,6 @@ export const CompletionSettings: React.FC = ({settings, onChange}) => { {renderSlider('temperature', 0, 1)} {renderSlider('top_k', 1, 128, 1)} {renderSlider('top_p', 0, 1)} - {renderSlider('tfs_z', 0, 2)} {renderSlider('min_p', 0, 1)} {renderSlider('xtc_threshold', 0, 1)} {renderSlider('xtc_probability', 0, 1)} diff --git a/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx b/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx index 9e483bf..f8574b0 100644 --- a/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx +++ b/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx @@ -29,10 +29,6 @@ describe('CompletionSettings', () => { const topPSlider = getByTestId('top_p-slider'); expect(topPSlider.props.value).toBe(0.95); - expect(getByTestId('tfs_z-slider')).toBeTruthy(); - const tfsZSlider = getByTestId('tfs_z-slider'); - expect(tfsZSlider.props.value).toBe(1); - expect(getByTestId('min_p-slider')).toBeTruthy(); const minPSlider = getByTestId('min_p-slider'); expect(minPSlider.props.value).toBe(0.05); diff --git a/src/store/ModelStore.ts b/src/store/ModelStore.ts index 27cbf8b..99df493 100644 --- a/src/store/ModelStore.ts +++ b/src/store/ModelStore.ts @@ -15,6 +15,7 @@ import {deepMerge, formatBytes, hasEnoughSpace, hfAsModel} from '../utils'; import { getHFDefaultSettings, getLocalModelDefaultSettings, + stops, } from '../utils/chat'; import { ChatTemplateConfig, @@ -565,37 +566,7 @@ class ModelStore { }, ); - // Get stop token from the model and add to the list of stop tokens. - const eos_token_id = Number( - (ctx.model as any)?.metadata?.['tokenizer.ggml.eos_token_id'], - ); - - if (!isNaN(eos_token_id)) { - const detokenized = await ctx.detokenize([eos_token_id]); - const storeModel = this.models.find(m => m.id === model.id); - if (detokenized && storeModel) { - runInAction(() => { - // Helper function to check and update stop tokens - const updateStopTokens = (settings: CompletionParams) => { - if (!settings.stop) { - settings.stop = [detokenized]; - } else if (!settings.stop.includes(detokenized)) { - settings.stop = [...settings.stop, detokenized]; - } - // Create new object reference to ensure MobX picks up the change - return {...settings}; - }; - - // Update both default and current completion settings - storeModel.defaultCompletionSettings = updateStopTokens( - storeModel.defaultCompletionSettings, - ); - storeModel.completionSettings = updateStopTokens( - storeModel.completionSettings, - ); - }); - } - } + await this.updateModelStopTokens(ctx, model); runInAction(() => { this.context = ctx; @@ -808,6 +779,78 @@ class ModelStore { (this.context?.model as any)?.metadata?.['general.name'] ?? 'Chat Page' ); } + + /** + * Updates stop tokens for a model based on its context and chat template + * @param ctx - The LlamaContext instance + * @param model - App model to update stop tokens for + */ + private async updateModelStopTokens(ctx: LlamaContext, model: Model) { + const storeModel = this.models.find(m => m.id === model.id); + if (!storeModel) { + return; + } + + const stopTokens: string[] = []; + + try { + // Get EOS token from model metadata + const eos_token_id = Number( + (ctx.model as any)?.metadata?.['tokenizer.ggml.eos_token_id'], + ); + + if (!isNaN(eos_token_id)) { + const detokenized = await ctx.detokenize([eos_token_id]); + if (detokenized) { + stopTokens.push(detokenized); + } + } + + // Add relevant stop tokens from chat templates + // First check model's custom chat template. + const template = storeModel.chatTemplate?.chatTemplate; + console.log('template: ', template); + if (template) { + const templateStops = stops.filter(stop => template.includes(stop)); + stopTokens.push(...templateStops); + } + + // Then check context's chat template + const ctxtTemplate = (ctx.model as any)?.metadata?.[ + 'tokenizer.chat_template' + ]; + console.log('ctxtTemplate: ', ctxtTemplate); + if (ctxtTemplate) { + const contextStops = stops.filter(stop => ctxtTemplate.includes(stop)); + stopTokens.push(...contextStops); + } + + console.log('stopTokens: ', stopTokens); + // Only update if we found stop tokens + if (stopTokens.length > 0) { + runInAction(() => { + // Helper function to check and update stop tokens + const updateStopTokens = (settings: CompletionParams) => { + const uniqueStops = Array.from( + new Set([...(settings.stop || []), ...stopTokens]), + ).filter(Boolean); // Remove any null/undefined/empty values + return {...settings, stop: uniqueStops}; + }; + + // Update both default and current completion settings + storeModel.defaultCompletionSettings = updateStopTokens( + storeModel.defaultCompletionSettings, + ); + storeModel.completionSettings = updateStopTokens( + storeModel.completionSettings, + ); + }); + } + } catch (error) { + console.error('Error updating model stop tokens:', error); + // Continue execution - stop token update is not critical + } + } } export const modelStore = new ModelStore(); diff --git a/src/utils/chat.ts b/src/utils/chat.ts index c807fb8..f382abf 100644 --- a/src/utils/chat.ts +++ b/src/utils/chat.ts @@ -218,7 +218,6 @@ export const defaultCompletionParams: CompletionParams = { temperature: 0.7, // The randomness of the generated text. top_k: 40, // Limit the next token selection to the K most probable tokens. top_p: 0.95, // Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. - tfs_z: 1.0, //Enable tail free sampling with parameter z. Default: `1.0`, which is disabled. min_p: 0.05, //The minimum probability for a token to be considered, relative to the probability of the most likely token. xtc_threshold: 0.1, // Sets a minimum probability threshold for tokens to be removed. xtc_probability: 0.0, // Sets the chance for token removal (checked once on sampler start) @@ -236,3 +235,15 @@ export const defaultCompletionParams: CompletionParams = { stop: [''], // emit_partial_completion: true, // This is not used in the current version of llama.rn }; + +export const stops = [ + '', + '<|end|>', + '<|eot_id|>', + '<|end_of_text|>', + '<|im_end|>', + '<|EOT|>', + '<|END_OF_TURN_TOKEN|>', + '<|end_of_turn|>', + '<|endoftext|>', +];