Commit

Support llama 3.1 model
zcbenz committed Aug 4, 2024
1 parent bb24028 commit 7676b9f
Showing 4 changed files with 95 additions and 12 deletions.
5 changes: 3 additions & 2 deletions llm.js
@@ -53,10 +53,11 @@ export class KVCache {

// Convert snake_case args into camelCase args.
export function baseModelArgs(args) {
const isObject = (obj) => typeof obj === 'object' && !Array.isArray(obj) && obj !== null
const newArgs = {}
for (const key in args) {
const newKey = key.replace(/(\_\w)/g, (s) => s[1].toUpperCase())
newArgs[newKey] = args[key]
newArgs[newKey] = isObject(args[key]) ? baseModelArgs(args[key]) : args[key]
}
return newArgs
}
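// A minimal sketch of the recursive conversion (the names exampleConfig and
// exampleArgs and the config fragment are illustrative, not taken from a real
// model). Nested sections such as rope_scaling and quantization now come out
// camelCased too, which is what lets loadModel destructure {groupSize, bits}
// from config.quantization further down:
const exampleConfig = {rope_scaling: {rope_type: 'llama3', low_freq_factor: 1.0}, quantization: {group_size: 64, bits: 4}}
const exampleArgs = baseModelArgs(exampleConfig)
// exampleArgs.ropeScaling.ropeType      === 'llama3'
// exampleArgs.ropeScaling.lowFreqFactor === 1.0
// exampleArgs.quantization.groupSize    === 64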
@@ -107,7 +108,7 @@ export async function loadModel(dir) {
return (`${p}.scales` in weights) &&
((m instanceof nn.Linear) || (m instanceof nn.Embedding))
}
const {group_size: groupSize, bits} = config.quantization
const {groupSize, bits} = config.quantization
nn.quantize(model, groupSize, bits, predicate)
}

98 changes: 90 additions & 8 deletions models/llama.js
@@ -13,15 +13,99 @@ function modelArgs(args) {
args.numKeyValueHeads = args.numAttentionHeads
}
if (args.ropeScaling) {
const requiredKeys = [ 'factor', 'type' ]
if (!Object.keys(args.ropeScaling).every(key => requiredKeys.includes(key)))
throw Error(`rope_scaling must contain keys ${requiredKeys}`)
if (args.ropeScaling.type != 'linear')
throw Error("rope_scaling 'type' currently only supports 'linear'")
if (!args.ropeScaling.factor)
throw new Error('rope_scaling must contain "factor"')
const ropeType = args.ropeScaling.type || args.ropeScaling.ropeType
if (!ropeType)
throw new Error('rope_scaling must contain either "type" or "rope_type"')
if (!['linear', 'dynamic', 'llama3'].includes(ropeType))
throw new Error('rope_scaling "type" currently only supports "linear", "dynamic" or "llama3"')
}
return args
}
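// For reference, a typical Llama 3.1 style rope_scaling block that satisfies
// the checks above, shown after the camelCase conversion done by baseModelArgs
// (the exact values vary per model and are only illustrative here):
//
//   ropeScaling: {
//     ropeType: 'llama3',
//     factor: 8.0,
//     lowFreqFactor: 1.0,
//     highFreqFactor: 4.0,
//     originalMaxPositionEmbeddings: 8192,
//   }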

class DynamicNTKScalingRoPE extends nn.Module {
constructor(dims,
maxPositionEmbeddings = 2048,
traditional = false,
base = 10000,
scale = 1.0,
ropeType = 'default',
ropeScaling = null) {
super()
this.dims = dims
this.maxPositionEmbeddings = maxPositionEmbeddings
this.traditional = traditional
this.originalBase = base
this.scale = scale
this.ropeType = ropeType
this.ropeScaling = ropeScaling
this.base = this.computeBaseFreq()
}

computeBaseFreq() {
if (this.ropeType === 'llama3')
return this.computeLlama3BaseFreq()
return this.originalBase
}

computeLlama3BaseFreq() {
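// Llama 3 style RoPE frequency adjustment: frequencies whose wavelength is
// shorter than highFreqWavelen are kept unchanged, frequencies whose
// wavelength is longer than lowFreqWavelen are scaled by `factor`, and the
// band in between is blended via `smooths`; the mean of the adjusted
// frequencies becomes the effective RoPE base.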
const factor = this.ropeScaling.factor
const lowFreqFactor = this.ropeScaling.lowFreqFactor ?? 1.0
const highFreqFactor = this.ropeScaling.highFreqFactor ?? 4.0
const oldContextLen = this.ropeScaling.originalMaxPositionEmbeddings ?? 8192

const lowFreqWavelen = oldContextLen / lowFreqFactor
const highFreqWavelen = oldContextLen / highFreqFactor

const freqs = mx.power(this.originalBase, mx.divide(mx.arange(0, this.dims, 2), this.dims))
const wavelens = mx.multiply(2 * mx.pi, freqs)

const smooths = mx.divide(mx.subtract(wavelens, highFreqWavelen),
mx.subtract(lowFreqWavelen, highFreqWavelen))
let newBaseFreqs = mx.add(mx.multiply(mx.multiply(freqs,
mx.subtract(1, smooths)),
factor),
smooths)
newBaseFreqs = mx.where(mx.less(wavelens, highFreqWavelen), freqs, newBaseFreqs)
newBaseFreqs = mx.where(mx.greater(wavelens, lowFreqWavelen), mx.multiply(freqs, factor), newBaseFreqs)
return mx.mean(newBaseFreqs).item()
}

forward(x, offset = 0) {
const seqLen = x.shape[1] + offset
let base = this.base
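// Dynamic NTK scaling: once the sequence grows past maxPositionEmbeddings,
// the RoPE base is enlarged on the fly so that positions beyond the trained
// context remain distinguishable.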
if (this.maxPositionEmbeddings && seqLen > this.maxPositionEmbeddings) {
base *= ((this.scale * seqLen / this.maxPositionEmbeddings) - (this.scale - 1)) ** (this.dims / (this.dims - 2))
}
return mx.fast.rope(x, this.dims, this.traditional, base, this.scale, offset)
}
}

function initializeRoPE(args) {
const headDim = args.headDim ?? Math.floor(args.hiddenSize / args.numAttentionHeads)

const ropeScaling = args.ropeScaling
let ropeType = 'default'
let ropeScale = 1.0

if (ropeScaling) {
ropeType = (ropeScaling.type ?? ropeScaling.ropeType) ?? 'default'
if (ropeType === 'linear')
ropeScale = 1 / ropeScaling.factor
else if (ropeType === 'llama3')
ropeScale = 1.0
}

return new DynamicNTKScalingRoPE(headDim,
args.maxPositionEmbeddings,
args.ropeTraditional,
args.ropeTheta,
ropeScale,
ropeType,
ropeScaling)
}
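// Note: only 'linear' scaling maps the factor onto ropeScale (1 / factor).
// For 'llama3' the scale stays 1.0 and the adjustment is folded into the base
// computed by computeLlama3BaseFreq(); for 'dynamic' the base is adjusted at
// call time in forward() once the sequence exceeds maxPositionEmbeddings.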

class Attention extends nn.Module {
constructor(args) {
super()
@@ -37,9 +121,7 @@ class Attention extends nn.Module {
this.vProj = new nn.Linear(dim, this.nKVHeads * headDim, args.attentionBias)
this.oProj = new nn.Linear(this.nHeads * headDim, dim, args.attentionBias)

const ropeScale = args.ropeScaling?.type == 'linear' ? 1 / args.ropeScaling.factor
: 1
this.rope = new nn.RoPE(headDim, args.ropeTraditional, args.ropeTheta, ropeScale)
this.rope = initializeRoPE(args)
}

forward(x, mask, cache) {
2 changes: 1 addition & 1 deletion models/qwen2.js
@@ -1,5 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx'
import {KVCache, baseModelArgs, createAdditiveCausalMask} from '../llm.js'
import {baseModelArgs, createAdditiveCausalMask} from '../llm.js'

function modelArgs(args) {
args = Object.assign({
2 changes: 1 addition & 1 deletion package.json
@@ -17,7 +17,7 @@
"url": "https://github.com/frost-beta/llm.js/issues"
},
"dependencies": {
"@frost-beta/mlx": "0.0.14",
"@frost-beta/mlx": "0.0.16",
"@lenml/tokenizers": "1.0.9",
"tick-promise": "1.0.0"
},
