Get ready for multiple models
zcbenz committed Jul 17, 2024
1 parent 8856f3a commit 7366ad3
Showing 7 changed files with 77 additions and 77 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/build.yml
@@ -15,9 +15,11 @@ jobs:

- name: Test
run: |
huggingface download --to weights-4bit mlx-community/Meta-Llama-3-8B-Instruct-4bit
yarn
echo Hello | node chat.js weights-4bit
huggingface download mlx-community/Meta-Llama-3-8B-Instruct-4bit
echo Hello | node chat.js Meta-Llama-3-8B-Instruct-4bit
huggingface download frost-beta/Llama3-33.5M-Japanese
node generate.js Llama3-33.5M-Japanese
publish:
if: startsWith(github.ref, 'refs/tags/')
22 changes: 8 additions & 14 deletions README.md
@@ -1,14 +1,15 @@
# llama3.js
# llm.js

A JavaScript implementation of Llama 3 using [node-mlx](https://github.com/frost-beta/node-mlx),
code modified from [mlx-examples](https://github.com/ml-explore/mlx-examples).
Load language models locally with JavaScript, using
[node-mlx](https://github.com/frost-beta/node-mlx), code modified from
[mlx-examples](https://github.com/ml-explore/mlx-examples).

__Quantized models can only run on Macs with Apple Silicon.__

## Usage

Download weights
(more can be found at [mlx-community](https://huggingface.co/collections/mlx-community/llama-3-662156b069a5d33b3328603c)):
(more can be found at [mlx-community](https://huggingface.co/collections/mlx-community/)):

```sh
npm install -g @frost-beta/huggingface
@@ -18,19 +19,12 @@ huggingface download --to weights mlx-community/Meta-Llama-3-8B-Instruct-8bit
Start chatting:

```sh
npm install -g llama3
llama3-chat ./weights
npm install -g @frost-beta/llm
llm-chat ./weights
```

Or do text generation:

```sh
llama3-generate ./weights 'Write a short story'
llm-generate ./weights 'Write a short story'
```

## Development

This project serves as a demo of node-mlx, and the code is intended to be kept
as simple as possible.

For general purpose LLM modules, please visit :construction:.
2 changes: 1 addition & 1 deletion chat.js
@@ -5,7 +5,7 @@ import {core as mx} from '@frost-beta/mlx'
import {loadTokenizer, loadModel, step} from './llm.js'

if (process.argv.length < 3) {
console.error('Usage: llama3-chat /path/to/weights/dir')
console.error('Usage: llm-chat /path/to/weights/dir')
process.exit(0)
}

2 changes: 1 addition & 1 deletion generate.js
@@ -13,7 +13,7 @@ const argv = process.argv.slice(2).filter((arg) => {
})

if (argv.length < 1) {
console.error('Usage: llama3-generate /path/to/weights/dir [--max-tokens=1024] [prompt]')
console.error('Usage: llm-generate /path/to/weights/dir [--max-tokens=1024] [prompt]')
process.exit(0)
}

63 changes: 57 additions & 6 deletions llm.js
@@ -1,17 +1,58 @@
import fs from 'node:fs/promises'
import path from 'node:path'
import nextTick from 'tick-promise'
import {existsSync} from 'node:fs'
import {TokenizerLoader} from '@lenml/tokenizers'
import {fromPreTrained} from '@lenml/tokenizer-llama3'
import {core as mx, nn} from '@frost-beta/mlx'

import {KVCache, Model} from './model.js'
// A design of KV cache friendly to MLX's memory cache design, which allocates
// arrays in same shapes.
// See also https://github.com/ml-explore/mlx-examples/issues/724.
export class KVCache {
constructor(headDim, nKVHeads) {
this.nKVHeads = nKVHeads
this.headDim = headDim
this.keys = null
this.values = null
this.offset = 0
this.step = 256
}

updateAndFetch(keys, values) {
const prev = this.offset
if (!this.keys || (prev + keys.shape[2] > this.keys.shape[2])) {
const nSteps = Math.floor((this.step + keys.shape[2] - 1) / this.step)
const shape = [1, this.nKVHeads, nSteps * this.step, this.headDim]
const newK = mx.zeros(shape, keys.dtype)
const newV = mx.zeros(shape, values.dtype)
if (this.keys) {
const old = [this.keys, this.values]
if (prev % this.step != 0) {
const get = ['...', mx.Slice(null, prev), mx.Slice()]
this.keys = this.keys.index(...get)
this.values = this.values.index(...get)
}
this.keys = mx.concatenate([this.keys, newK], 2)
this.values = mx.concatenate([this.values, newV], 2)
mx.dispose(old)
} else {
this.keys = newK
this.values = newV
}
}

this.offset += keys.shape[2]

const insert = ['...', mx.Slice(prev, this.offset), mx.Slice()]
this.keys.indexPut_(insert, keys)
this.values.indexPut_(insert, values)

const get = ['...', mx.Slice(null, this.offset), mx.Slice()]
return [this.keys.index(...get), this.values.index(...get)]
}
}
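
For readers following the refactor, here is a minimal sketch of how a cache like this is typically driven during autoregressive decoding: one `KVCache` per transformer layer, created from the model's config, with each layer appending its newly projected keys/values (shape `[1, nKVHeads, seqLen, headDim]`) and attending over the full returned history. The config field names and the helper below are illustrative assumptions, not code from this commit.

```js
// Illustrative only: one KVCache per layer, reused across decoding steps.
// Field names (num_hidden_layers, num_key_value_heads, ...) follow common
// config.json conventions and are assumptions for this sketch.
function createKVCaches(config) {
  const headDim = config.hidden_size / config.num_attention_heads
  return Array.from({length: config.num_hidden_layers},
                    () => new KVCache(headDim, config.num_key_value_heads))
}

// Inside an attention layer, the keys/values for the current tokens are
// appended and the whole history comes back for attention:
//   const [keys, values] = cache.updateAndFetch(newKeys, newValues)
// Because the backing arrays grow in fixed 256-token steps, successive
// decoding steps keep reusing the same array shapes, which is what makes
// this layout friendly to MLX's memory cache.
```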

// Return a tokenizer.
export async function loadTokenizer(dir) {
if (!existsSync(path.join(dir, 'tokenizer_config.json')))
return fromPreTrained()
return TokenizerLoader.fromPreTrained({
tokenizerJSON: JSON.parse(await fs.readFile(path.join(dir, 'tokenizer.json'))),
tokenizerConfig: JSON.parse(await fs.readFile(path.join(dir, 'tokenizer_config.json'))),
@@ -29,7 +70,17 @@ export async function loadModel(dir) {
}

// Create llama3 model.
const model = new Model(config)
let model
try {
const {Model} = await import(`./models/${config.model_type}.js`)
model = new Model(config)
} catch (error) {
if (error.code == 'ERR_MODULE_NOT_FOUND') {
console.error('Unsupported model type:', config.model_type)
process.exit(1)
}
throw error
}

// Quantization.
if (config.quantization) {
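
The dynamic `import('./models/${config.model_type}.js')` above is what opens the door to multiple architectures: each supported `model_type` gets its own module under `models/` that exports a `Model` class constructed from the parsed `config.json`. Below is a hypothetical skeleton of such a module; only the exported `Model(config)` shape is implied by the loader, while the file name, layer names, and config fields are placeholders.

```js
// models/example.js — hypothetical skeleton for a future architecture.
// The loader in llm.js only requires an exported Model class whose
// constructor takes the parsed config.json; everything else here is a
// placeholder and not part of this commit.
import {nn} from '@frost-beta/mlx'

export class Model extends nn.Module {
  constructor(config) {
    super()
    // Layer names and config fields below are illustrative.
    this.embedding = new nn.Embedding(config.vocab_size, config.hidden_size)
    this.head = new nn.Linear(config.hidden_size, config.vocab_size, false)
  }

  forward(inputs, cache) {
    // A real model would run its transformer blocks here, threading one
    // KVCache per layer; this stub only maps token ids to logits.
    return this.head.forward(this.embedding.forward(inputs))
  }
}
```

The `// Quantization.` branch just after the construction above suggests that quantized checkpoints are handled once the model is built, so a new module only has to describe the unquantized layers.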
48 changes: 1 addition & 47 deletions model.js → models/llama.js
@@ -1,51 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx'

// A design of KV cache friendly to MLX's memory cache design, which allocates
// arrays in same shapes.
// See also https://github.com/ml-explore/mlx-examples/issues/724.
export class KVCache {
constructor(headDim, nKVHeads) {
this.nKVHeads = nKVHeads
this.headDim = headDim
this.keys = null
this.values = null
this.offset = 0
this.step = 256
}

updateAndFetch(keys, values) {
const prev = this.offset
if (!this.keys || (prev + keys.shape[2] > this.keys.shape[2])) {
const nSteps = Math.floor((this.step + keys.shape[2] - 1) / this.step)
const shape = [1, this.nKVHeads, nSteps * this.step, this.headDim]
const newK = mx.zeros(shape, keys.dtype)
const newV = mx.zeros(shape, values.dtype)
if (this.keys) {
const old = [this.keys, this.values]
if (prev % this.step != 0) {
const get = ['...', mx.Slice(null, prev), mx.Slice()]
this.keys = this.keys.index(...get)
this.values = this.values.index(...get)
}
this.keys = mx.concatenate([this.keys, newK], 2)
this.values = mx.concatenate([this.values, newV], 2)
mx.dispose(old)
} else {
this.keys = newK
this.values = newV
}
}

this.offset += keys.shape[2]

const insert = ['...', mx.Slice(prev, this.offset), mx.Slice()]
this.keys.indexPut_(insert, keys)
this.values.indexPut_(insert, values)

const get = ['...', mx.Slice(null, this.offset), mx.Slice()]
return [this.keys.index(...get), this.values.index(...get)]
}
}
import {KVCache} from '../llm.js'

class Attention extends nn.Module {
constructor(args) {
11 changes: 5 additions & 6 deletions package.json
@@ -1,24 +1,23 @@
{
"name": "llama3",
"name": "@frost-beta/llm",
"version": "0.0.1-dev",
"main": "llm.js",
"type": "module",
"bin": {
"llama3-chat": "chat.js",
"llama3-generate": "generate.js"
"llm-chat": "chat.js",
"llm-generate": "generate.js"
},
"author": "zcbenz",
"license": "MIT",
"repository": {
"type": "git",
"url": "git+https://github.com/frost-beta/llama3.js.git"
"url": "git+https://github.com/frost-beta/llm.js.git"
},
"bugs": {
"url": "https://github.com/frost-beta/llama3.js/issues"
"url": "https://github.com/frost-beta/llm.js/issues"
},
"dependencies": {
"@frost-beta/mlx": "0.0.14",
"@lenml/tokenizer-llama3": "1.0.10",
"@lenml/tokenizers": "1.0.9",
"tick-promise": "1.0.0"
},
