Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: add stop tokens that are in the chat template and remove tfs_z #113

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion jest/fixtures/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ export const mockDefaultCompletionParams: CompletionParams = {
temperature: 0.7,
top_k: 40,
top_p: 0.95,
tfs_z: 1.0,
min_p: 0.05,
xtc_threshold: 0.1,
xtc_probability: 0.01,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ export const CompletionSettings: React.FC<Props> = ({settings, onChange}) => {
{renderSlider('temperature', 0, 1)}
{renderSlider('top_k', 1, 128, 1)}
{renderSlider('top_p', 0, 1)}
{renderSlider('tfs_z', 0, 2)}
{renderSlider('min_p', 0, 1)}
{renderSlider('xtc_threshold', 0, 1)}
{renderSlider('xtc_probability', 0, 1)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ describe('CompletionSettings', () => {
const topPSlider = getByTestId('top_p-slider');
expect(topPSlider.props.value).toBe(0.95);

expect(getByTestId('tfs_z-slider')).toBeTruthy();
const tfsZSlider = getByTestId('tfs_z-slider');
expect(tfsZSlider.props.value).toBe(1);

expect(getByTestId('min_p-slider')).toBeTruthy();
const minPSlider = getByTestId('min_p-slider');
expect(minPSlider.props.value).toBe(0.05);
Expand Down
105 changes: 74 additions & 31 deletions src/store/ModelStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {deepMerge, formatBytes, hasEnoughSpace, hfAsModel} from '../utils';
import {
getHFDefaultSettings,
getLocalModelDefaultSettings,
stops,
} from '../utils/chat';
import {
ChatTemplateConfig,
Expand Down Expand Up @@ -565,37 +566,7 @@ class ModelStore {
},
);

// Get stop token from the model and add to the list of stop tokens.
const eos_token_id = Number(
(ctx.model as any)?.metadata?.['tokenizer.ggml.eos_token_id'],
);

if (!isNaN(eos_token_id)) {
const detokenized = await ctx.detokenize([eos_token_id]);
const storeModel = this.models.find(m => m.id === model.id);
if (detokenized && storeModel) {
runInAction(() => {
// Helper function to check and update stop tokens
const updateStopTokens = (settings: CompletionParams) => {
if (!settings.stop) {
settings.stop = [detokenized];
} else if (!settings.stop.includes(detokenized)) {
settings.stop = [...settings.stop, detokenized];
}
// Create new object reference to ensure MobX picks up the change
return {...settings};
};

// Update both default and current completion settings
storeModel.defaultCompletionSettings = updateStopTokens(
storeModel.defaultCompletionSettings,
);
storeModel.completionSettings = updateStopTokens(
storeModel.completionSettings,
);
});
}
}
await this.updateModelStopTokens(ctx, model);

runInAction(() => {
this.context = ctx;
Expand Down Expand Up @@ -808,6 +779,78 @@ class ModelStore {
(this.context?.model as any)?.metadata?.['general.name'] ?? 'Chat Page'
);
}

/**
* Updates stop tokens for a model based on its context and chat template
* @param ctx - The LlamaContext instance
* @param model - App model to update stop tokens for
*/
/**
 * Updates stop tokens for a model based on its EOS-token metadata and any
 * chat templates (the model's custom template, and the template embedded
 * in the loaded context's GGUF metadata).
 *
 * Found tokens are merged — deduplicated, empties dropped — into both the
 * model's default and current completion settings. Failures are logged and
 * swallowed: refining stop tokens is not critical to loading the model.
 *
 * @param ctx - The LlamaContext instance for the loaded model
 * @param model - App model whose completion settings should be updated
 */
private async updateModelStopTokens(ctx: LlamaContext, model: Model) {
  const storeModel = this.models.find(m => m.id === model.id);
  if (!storeModel) {
    return;
  }

  const stopTokens: string[] = [];

  try {
    // Get the EOS token id from model metadata and detokenize it to text.
    const eos_token_id = Number(
      (ctx.model as any)?.metadata?.['tokenizer.ggml.eos_token_id'],
    );

    if (!isNaN(eos_token_id)) {
      const detokenized = await ctx.detokenize([eos_token_id]);
      if (detokenized) {
        stopTokens.push(detokenized);
      }
    }

    // Add known stop markers that appear in the model's custom chat template.
    const template = storeModel.chatTemplate?.chatTemplate;
    if (template) {
      stopTokens.push(...stops.filter(stop => template.includes(stop)));
    }

    // Then check the chat template embedded in the model file itself.
    const ctxtTemplate = (ctx.model as any)?.metadata?.[
      'tokenizer.chat_template'
    ];
    if (ctxtTemplate) {
      stopTokens.push(...stops.filter(stop => ctxtTemplate.includes(stop)));
    }

    // Only touch settings when we actually found stop tokens to add.
    if (stopTokens.length > 0) {
      runInAction(() => {
        // Merge new tokens into any existing ones, dropping duplicates and
        // empty values; return a fresh object reference so MobX observes
        // the change.
        const updateStopTokens = (settings: CompletionParams) => {
          const uniqueStops = Array.from(
            new Set([...(settings.stop || []), ...stopTokens]),
          ).filter(Boolean);
          return {...settings, stop: uniqueStops};
        };

        // Update both default and current completion settings.
        storeModel.defaultCompletionSettings = updateStopTokens(
          storeModel.defaultCompletionSettings,
        );
        storeModel.completionSettings = updateStopTokens(
          storeModel.completionSettings,
        );
      });
    }
  } catch (error) {
    console.error('Error updating model stop tokens:', error);
    // Non-fatal: continue — stop-token refinement is best-effort.
  }
}
}

export const modelStore = new ModelStore();
13 changes: 12 additions & 1 deletion src/utils/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ export const defaultCompletionParams: CompletionParams = {
temperature: 0.7, // The randomness of the generated text.
top_k: 40, // Limit the next token selection to the K most probable tokens.
top_p: 0.95, // Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.
tfs_z: 1.0, //Enable tail free sampling with parameter z. Default: `1.0`, which is disabled.
min_p: 0.05, //The minimum probability for a token to be considered, relative to the probability of the most likely token.
xtc_threshold: 0.1, // Sets a minimum probability threshold for tokens to be removed.
xtc_probability: 0.0, // Sets the chance for token removal (checked once on sampler start)
Expand All @@ -236,3 +235,15 @@ export const defaultCompletionParams: CompletionParams = {
stop: ['</s>'],
// emit_partial_completion: true, // This is not used in the current version of llama.rn
};

// Stop-token strings commonly found in chat templates. Consumers (see
// ModelStore.updateModelStopTokens) filter this list against a model's
// chat template text and add any matches to the completion stop list.
// NOTE: order is preserved when matches are collected downstream.
export const stops = [
  '</s>',
  '<|end|>',
  '<|eot_id|>',
  '<|end_of_text|>',
  '<|im_end|>',
  '<|EOT|>',
  '<|END_OF_TURN_TOKEN|>',
  '<|end_of_turn|>',
  '<|endoftext|>',
];