From 32e46d527af8cca4e8787aea008a1f105f9d8717 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Mon, 21 Oct 2024 14:02:08 +0900
Subject: [PATCH] Add support for gemma

---
 .github/workflows/build.yml |  3 +++
 src/models/gemma.ts         | 11 ++++++++++
 src/models/llama.ts         | 40 ++++++++++++++++++++++++++++---------
 3 files changed, 45 insertions(+), 9 deletions(-)
 create mode 100644 src/models/gemma.ts

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 2958ee6..e54cf75 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -31,6 +31,9 @@ jobs:
           huggingface download --silent Xenova/t5-small --hf # fix tokenizer
           yarn tsx src/generate.ts --temperature=0 t5-small 'translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast.'
+          huggingface download --silent mlx-community/codegemma-1.1-2b-8bit
+          yarn tsx src/generate.ts --max-tokens=128 codegemma-1.1-2b-8bit
+
   publish:
     if: startsWith(github.ref, 'refs/tags/')
     needs: [build]
diff --git a/src/models/gemma.ts b/src/models/gemma.ts
new file mode 100644
index 0000000..e4f971c
--- /dev/null
+++ b/src/models/gemma.ts
@@ -0,0 +1,11 @@
+import * as llama from './llama.js';
+
+export class Model extends llama.Model {
+  constructor(json: any) {
+    if (json.force_use_exact_gelu)
+      json.hidden_act = 'gelu';
+    else if (!json.hidden_act || json.hidden_act == 'gelu_pytorch_tanh')
+      json.hidden_act = 'geluApprox';
+    super(json);
+  }
+}
diff --git a/src/models/llama.ts b/src/models/llama.ts
index 2fb30e4..7da6b9e 100644
--- a/src/models/llama.ts
+++ b/src/models/llama.ts
@@ -12,10 +12,11 @@ export interface RopeScaling {
 }
 
 export interface ModelArgs {
-  modelType: 'llama';
+  modelType: 'gemma' | 'llama' | 'qwen2';
   attentionBias: boolean;
   attentionOutProjectionBias: boolean;
   headDim?: number;
+  hiddenAct: 'gelu' | 'geluApprox' | 'silu';
   hiddenSize: number;
   intermediateSize: number;
   maxPositionEmbeddings: number;
@@ -39,12 +40,12 @@ export function modelArgs(json: any): ModelArgs {
     ropeTraditional: false,
     tieWordEmbeddings: true,
   }, baseModelArgs(json));
-  if (args.attentionOutProjectionBias == undefined) {
+  if (args.attentionOutProjectionBias === undefined)
     args.attentionOutProjectionBias = args.attentionBias;
-  }
-  if (!args.numKeyValueHeads) {
+  if (args.hiddenAct === undefined)
+    args.hiddenAct = 'silu';
+  if (args.numKeyValueHeads === undefined)
     args.numKeyValueHeads = args.numAttentionHeads;
-  }
   if (args.ropeScaling) {
     if (!args.ropeScaling.factor)
       throw new Error('rope_scaling must contain "factor"')
@@ -57,6 +58,16 @@ export function modelArgs(json: any): ModelArgs {
   return args;
 }
 
+export class GemmaRMSNorm extends nn.RMSNorm {
+  constructor(dims: number, eps: number) {
+    super(dims, eps);
+  }
+
+  override forward(x: mx.array): mx.array {
+    return mx.fast.rmsNorm(x, mx.add(1, this.weight), this.eps);
+  }
+}
+
 class DynamicNTKScalingRoPE extends nn.Module {
   #freqs?: mx.array;
 
@@ -191,6 +202,7 @@ class MLP extends nn.Module {
   gateProj: nn.Linear;
   downProj: nn.Linear;
   upProj: nn.Linear;
+  activation: (x: mx.array) => mx.array;
 
   constructor(args: ModelArgs) {
     super();
@@ -202,10 +214,11 @@ class MLP extends nn.Module {
     this.gateProj = new nn.Linear(dim, hiddenDim, mlpBias);
     this.downProj = new nn.Linear(hiddenDim, dim, mlpBias);
     this.upProj = new nn.Linear(dim, hiddenDim, mlpBias);
+    this.activation = nn[args.hiddenAct as 'silu'];
   }
 
   forward(x: mx.array) {
-    return this.downProj.forward(mx.multiply(nn.silu(this.gateProj.forward(x)),
+    return this.downProj.forward(mx.multiply(this.activation(this.gateProj.forward(x)),
       this.upProj.forward(x)));
   }
 }
@@ -220,8 +233,9 @@ class TransformerBlock extends nn.Module {
     super();
     this.selfAttn = new Attention(args);
     this.mlp = new MLP(args);
-    this.inputLayernorm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
-    this.postAttentionLayernorm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
+    const norm = args.modelType == 'gemma' ? GemmaRMSNorm : nn.RMSNorm;
+    this.inputLayernorm = new norm(args.hiddenSize, args.rmsNormEps);
+    this.postAttentionLayernorm = new norm(args.hiddenSize, args.rmsNormEps);
   }
 
   forward(x: mx.array, mask: mx.array, cache?: BaseKVCache) {
@@ -236,6 +250,7 @@ class LlamaModel extends nn.Module {
   embedTokens: nn.Embedding;
   layers: TransformerBlock[];
   norm: nn.RMSNorm;
+  gemmaNormalizer?: number;
 
   constructor(args: ModelArgs) {
     super();
@@ -243,11 +258,18 @@ class LlamaModel extends nn.Module {
     this.layers = [];
     for (let i = 0; i < args.numHiddenLayers; ++i)
       this.layers.push(new TransformerBlock(args));
-    this.norm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
+    if (args.modelType == 'gemma') {
+      this.norm = new GemmaRMSNorm(args.hiddenSize, args.rmsNormEps);
+      this.gemmaNormalizer = args.hiddenSize ** 0.5;
+    } else {
+      this.norm = new nn.RMSNorm(args.hiddenSize, args.rmsNormEps);
+    }
   }
 
   forward(embeddings: mx.array, cache?: BaseKVCache[]) {
     let h = embeddings;
+    if (this.gemmaNormalizer !== undefined)
+      h = mx.multiply(h, this.gemmaNormalizer);
     const mask = createAttentionMask(h, cache);
     for (let i in this.layers)
       h = this.layers[i].forward(h, mask, cache ? cache[i] : undefined);
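
For reference, the Gemma-specific arithmetic in this patch comes down to two small changes: the RMS norm scales by (1 + weight) instead of weight (GemmaRMSNorm.forward), and the token embeddings are multiplied by sqrt(hiddenSize) before the first transformer block (gemmaNormalizer in LlamaModel.forward). The snippet below is a minimal, dependency-free TypeScript sketch of that arithmetic on plain number arrays; the gemmaRmsNorm helper and the toy values are illustrative only and are not part of the patch, which operates on mx arrays via mx.fast.rmsNorm.

    // Hypothetical plain-array sketch of the Gemma RMS norm; the patch itself
    // computes mx.fast.rmsNorm(x, mx.add(1, this.weight), this.eps).
    function gemmaRmsNorm(x: number[], weight: number[], eps = 1e-6): number[] {
      const meanSquare = x.reduce((sum, v) => sum + v * v, 0) / x.length;
      const scale = 1 / Math.sqrt(meanSquare + eps);
      // Gemma scales by (1 + weight); the stock llama RMSNorm scales by weight.
      return x.map((v, i) => v * scale * (1 + weight[i]));
    }

    // Toy values, only to show the embedding scaling applied in
    // LlamaModel.forward when gemmaNormalizer (= hiddenSize ** 0.5) is set.
    const hiddenSize = 4;
    const embedding = [0.1, -0.2, 0.3, 0.05];
    const scaled = embedding.map(v => v * Math.sqrt(hiddenSize));

    console.log(gemmaRmsNorm(scaled, new Array(hiddenSize).fill(0)));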