Skip to content

Commit

Permalink
fix: add stop tokens that are in the chat template (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ghorbani authored Nov 29, 2024
1 parent d0d5ee2 commit 600918b
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 38 deletions.
1 change: 0 additions & 1 deletion jest/fixtures/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ export const mockDefaultCompletionParams: CompletionParams = {
temperature: 0.7,
top_k: 40,
top_p: 0.95,
tfs_z: 1.0,
min_p: 0.05,
xtc_threshold: 0.1,
xtc_probability: 0.01,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ export const CompletionSettings: React.FC<Props> = ({settings, onChange}) => {
{renderSlider('temperature', 0, 1)}
{renderSlider('top_k', 1, 128, 1)}
{renderSlider('top_p', 0, 1)}
{renderSlider('tfs_z', 0, 2)}
{renderSlider('min_p', 0, 1)}
{renderSlider('xtc_threshold', 0, 1)}
{renderSlider('xtc_probability', 0, 1)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ describe('CompletionSettings', () => {
const topPSlider = getByTestId('top_p-slider');
expect(topPSlider.props.value).toBe(0.95);

expect(getByTestId('tfs_z-slider')).toBeTruthy();
const tfsZSlider = getByTestId('tfs_z-slider');
expect(tfsZSlider.props.value).toBe(1);

expect(getByTestId('min_p-slider')).toBeTruthy();
const minPSlider = getByTestId('min_p-slider');
expect(minPSlider.props.value).toBe(0.05);
Expand Down
105 changes: 74 additions & 31 deletions src/store/ModelStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {deepMerge, formatBytes, hasEnoughSpace, hfAsModel} from '../utils';
import {
getHFDefaultSettings,
getLocalModelDefaultSettings,
stops,
} from '../utils/chat';
import {
ChatTemplateConfig,
Expand Down Expand Up @@ -565,37 +566,7 @@ class ModelStore {
},
);

// Get stop token from the model and add to the list of stop tokens.
const eos_token_id = Number(
(ctx.model as any)?.metadata?.['tokenizer.ggml.eos_token_id'],
);

if (!isNaN(eos_token_id)) {
const detokenized = await ctx.detokenize([eos_token_id]);
const storeModel = this.models.find(m => m.id === model.id);
if (detokenized && storeModel) {
runInAction(() => {
// Helper function to check and update stop tokens
const updateStopTokens = (settings: CompletionParams) => {
if (!settings.stop) {
settings.stop = [detokenized];
} else if (!settings.stop.includes(detokenized)) {
settings.stop = [...settings.stop, detokenized];
}
// Create new object reference to ensure MobX picks up the change
return {...settings};
};

// Update both default and current completion settings
storeModel.defaultCompletionSettings = updateStopTokens(
storeModel.defaultCompletionSettings,
);
storeModel.completionSettings = updateStopTokens(
storeModel.completionSettings,
);
});
}
}
await this.updateModelStopTokens(ctx, model);

runInAction(() => {
this.context = ctx;
Expand Down Expand Up @@ -808,6 +779,78 @@ class ModelStore {
(this.context?.model as any)?.metadata?.['general.name'] ?? 'Chat Page'
);
}

/**
 * Collects stop tokens for a model and merges them into its completion
 * settings.
 *
 * Candidates are gathered in this order:
 *  1. the detokenized EOS token referenced by `tokenizer.ggml.eos_token_id`
 *     in the context's model metadata;
 *  2. every entry of `stops` that occurs in the store model's custom chat
 *     template;
 *  3. every entry of `stops` that occurs in the `tokenizer.chat_template`
 *     entry of the context's model metadata.
 *
 * The candidates are appended (de-duplicated, falsy values dropped) to the
 * existing `stop` list of both `defaultCompletionSettings` and
 * `completionSettings` of the matching store model. Nothing is written when
 * no candidate was found or when the model is not present in `this.models`.
 *
 * Failures are logged and swallowed: updating stop tokens is best-effort and
 * never rejects.
 *
 * @param ctx - The LlamaContext instance
 * @param model - App model to update stop tokens for
 */
private async updateModelStopTokens(
  ctx: LlamaContext,
  model: Model,
): Promise<void> {
  const storeModel = this.models.find(m => m.id === model.id);
  if (!storeModel) {
    return;
  }

  const stopTokens: string[] = [];

  try {
    // Metadata is untyped on the context model, so read it once and treat
    // every value pulled from it as unknown until checked.
    const metadata = (ctx.model as any)?.metadata;

    // Get EOS token from model metadata. A missing entry becomes NaN and is
    // skipped.
    const eos_token_id = Number(metadata?.['tokenizer.ggml.eos_token_id']);
    if (!isNaN(eos_token_id)) {
      const detokenized = await ctx.detokenize([eos_token_id]);
      if (detokenized) {
        stopTokens.push(detokenized);
      }
    }

    // Add relevant stop tokens from chat templates: first the model's custom
    // chat template, then the template embedded in the context's metadata.
    const templates = [
      storeModel.chatTemplate?.chatTemplate,
      metadata?.['tokenizer.chat_template'],
    ];
    for (const template of templates) {
      // Only scan real, non-empty strings. A non-string metadata value would
      // otherwise throw on `.includes` and discard the EOS token collected
      // above.
      if (typeof template === 'string' && template) {
        stopTokens.push(...stops.filter(stop => template.includes(stop)));
      }
    }

    // Only update if we found stop tokens
    if (stopTokens.length > 0) {
      // Returns a new settings object so MobX observes the change.
      const mergeStops = (settings: CompletionParams) => {
        const uniqueStops = Array.from(
          new Set([...(settings.stop ?? []), ...stopTokens]),
        ).filter(Boolean); // Remove any null/undefined/empty values
        return {...settings, stop: uniqueStops};
      };

      runInAction(() => {
        // Update both default and current completion settings
        storeModel.defaultCompletionSettings = mergeStops(
          storeModel.defaultCompletionSettings,
        );
        storeModel.completionSettings = mergeStops(
          storeModel.completionSettings,
        );
      });
    }
  } catch (error) {
    console.error('Error updating model stop tokens:', error);
    // Continue execution - stop token update is not critical
  }
}
}

// Module-level ModelStore instance, exported for use by other modules.
export const modelStore = new ModelStore();
13 changes: 12 additions & 1 deletion src/utils/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ export const defaultCompletionParams: CompletionParams = {
temperature: 0.7, // The randomness of the generated text.
top_k: 40, // Limit the next token selection to the K most probable tokens.
top_p: 0.95, // Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.
tfs_z: 1.0, //Enable tail free sampling with parameter z. Default: `1.0`, which is disabled.
min_p: 0.05, //The minimum probability for a token to be considered, relative to the probability of the most likely token.
xtc_threshold: 0.1, // Sets a minimum probability threshold for tokens to be removed.
xtc_probability: 0.0, // Sets the chance for token removal (checked once on sampler start)
Expand All @@ -236,3 +235,15 @@ export const defaultCompletionParams: CompletionParams = {
stop: ['</s>'],
// emit_partial_completion: true, // This is not used in the current version of llama.rn
};

/**
 * Candidate stop strings (end-of-sequence / end-of-turn markers).
 *
 * ModelStore scans a model's chat template for each entry using a plain
 * substring match and adds the ones it finds to that model's completion
 * `stop` list. Matching is case-sensitive, so `<|EOT|>` and `<|eot_id|>` are
 * distinct entries.
 */
export const stops: string[] = [
  '</s>',
  '<|end|>',
  '<|eot_id|>',
  '<|end_of_text|>',
  '<|im_end|>',
  '<|EOT|>',
  '<|END_OF_TURN_TOKEN|>',
  '<|end_of_turn|>',
  '<|endoftext|>',
];

0 comments on commit 600918b

Please sign in to comment.