Commit 1a2a86e

Add createAttentionMask helper

zcbenz committed Aug 4, 2024
1 parent 7676b9f commit 1a2a86e
Showing 6 changed files with 23 additions and 28 deletions.
14 changes: 13 additions & 1 deletion llm.js
@@ -20,8 +20,9 @@ export class KVCache {
  updateAndFetch(keys, values) {
    const prev = this.offset
    if (!this.keys || (prev + keys.shape[2] > this.keys.shape[2])) {
+     const B = keys.shape[0]
      const nSteps = Math.floor((this.step + keys.shape[2] - 1) / this.step)
-     const shape = [1, this.nKVHeads, nSteps * this.step, this.headDim]
+     const shape = [B, this.nKVHeads, nSteps * this.step, this.headDim]
      const newK = mx.zeros(shape, keys.dtype)
      const newV = mx.zeros(shape, values.dtype)
      if (this.keys) {
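Why the first hunk matters: the reallocated cache buffer must share the batch dimension of the incoming keys, otherwise a batched prefill cannot be stored. A small sketch of the new sizing logic (the step size and shapes below are illustrative, not taken from this commit):

// Stand-in for an mx.array of keys arriving during a batched prefill,
// laid out as [B, nKVHeads, T, headDim].
const keys = {shape: [4, 8, 32, 64]}
const step = 256                         // illustrative cache growth step
const B = keys.shape[0]                  // 4; the old code hard-coded 1 here
const nSteps = Math.floor((step + keys.shape[2] - 1) / step)  // 1
const shape = [B, 8, nSteps * step, 64]  // [4, 8, 256, 64]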
@@ -70,6 +71,17 @@ export function createAdditiveCausalMask(N, offset = 0) {
  return mx.multiply(mask, -1e9)
}

+ // Create an attention mask.
+ export function createAttentionMask(h, cache) {
+   const T = h.shape[1]
+   if (T > 1) {
+     const offset = cache && cache[0] ? cache[0].offset : 0
+     return createAdditiveCausalMask(T, offset).astype(h.dtype)
+   } else {
+     return null
+   }
+ }
+
// Return a tokenizer.
export async function loadTokenizer(dir) {
  return TokenizerLoader.fromPreTrained({
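For reference, a plain-array sketch of the additive mask that createAttentionMask returns while a prompt is being processed (referenceCausalMask is illustrative and assumes createAdditiveCausalMask(N, offset) yields an [N, offset + N] mask, as the offset handling above suggests):

// Illustrative only: row i (a new query position) may attend to every cached
// position and to new positions <= i; blocked positions get -1e9, which
// softmax turns into ~0 attention. For a single decoded token (T === 1) the
// helper returns null instead, since nothing ahead of it needs masking.
function referenceCausalMask(N, offset = 0) {
  const mask = []
  for (let i = 0; i < N; i++) {
    const row = []
    for (let j = 0; j < offset + N; j++)
      row.push(j <= offset + i ? 0 : -1e9)
    mask.push(row)
  }
  return mask  // shape [N, offset + N]
}

// referenceCausalMask(3) =>
// [[0, -1e9, -1e9],
//  [0,    0, -1e9],
//  [0,    0,    0]]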
8 changes: 2 additions & 6 deletions models/gemma.js
@@ -1,5 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx'
- import {baseModelArgs} from '../llm.js'
+ import {baseModelArgs, createAttentionMask} from '../llm.js'

function modelArgs(args) {
  return Object.assign({
@@ -113,11 +113,7 @@ class GemmaModel extends nn.Module {
    let h = this.embedTokens.forward(inputs)
    h = mx.multiply(h, this.hiddenSize ** 0.5)

-   let mask
-   if (h.shape[1] > 1) {
-     mask = nn.MultiHeadAttention.createAdditiveCausalMask(h.shape[1])
-     mask = mask.astype(h.dtype)
-   }
+   const mask = createAttentionMask(h, cache)

    cache = cache ?? new Array(this.layers.length)
8 changes: 2 additions & 6 deletions models/gemma2.js
@@ -1,5 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx'
- import {baseModelArgs} from '../llm.js'
+ import {baseModelArgs, createAttentionMask} from '../llm.js'

function ModelArgs(args) {
  return Object.assign({
@@ -135,11 +135,7 @@ class GemmaModel extends nn.Module {
    let h = this.embedTokens.forward(inputs)
    h = mx.multiply(h, this.args.hiddenSize ** 0.5)

-   let mask
-   if (h.shape[1] > 1) {
-     mask = nn.MultiHeadAttention.createAdditiveCausalMask(h.shape[1])
-     mask = mask.astype(h.dtype)
-   }
+   const mask = createAttentionMask(h, cache)

    cache = cache ?? new Array(this.layers.length)
5 changes: 2 additions & 3 deletions models/gpt2.js
@@ -1,5 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx'
- import {baseModelArgs, createAdditiveCausalMask} from '../llm.js'
+ import {baseModelArgs, createAttentionMask} from '../llm.js'

function modelArgs(args) {
  args = baseModelArgs(args)
@@ -105,8 +105,7 @@ class GPT2Model extends nn.Module {
    if (hiddenStates.shape[1] > 1) {
      const positionIds = mx.arange(L)
      hiddenStates = mx.add(hiddenStates, this.wpe.forward(positionIds))
-     mask = createAdditiveCausalMask(hiddenStates.shape[1], cache ? cache[0].offset : 0)
-     mask = mask.astype(hiddenStates.dtype)
+     mask = createAttentionMask(hiddenStates, cache)
    }

    cache = cache ?? new Array(this.h.length)
8 changes: 2 additions & 6 deletions models/llama.js
@@ -1,5 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx'
- import {baseModelArgs, createAdditiveCausalMask} from '../llm.js'
+ import {baseModelArgs, createAttentionMask} from '../llm.js'

function modelArgs(args) {
  args = Object.assign({
@@ -203,11 +203,7 @@ class LlamaModel extends nn.Module {
  forward(inputs, cache) {
    let h = this.embedTokens.forward(inputs)

-   let mask
-   if (h.shape[1] > 1) {
-     mask = createAdditiveCausalMask(h.shape[1], cache ? cache[0].offset : 0)
-     mask = mask.astype(h.dtype)
-   }
+   const mask = createAttentionMask(h, cache)

    cache = cache ?? new Array(this.layers.length)
8 changes: 2 additions & 6 deletions models/qwen2.js
@@ -1,5 +1,5 @@
import {core as mx, nn} from '@frost-beta/mlx'
- import {baseModelArgs, createAdditiveCausalMask} from '../llm.js'
+ import {baseModelArgs, createAttentionMask} from '../llm.js'

function modelArgs(args) {
  args = Object.assign({
@@ -116,11 +116,7 @@ class Qwen2Model extends nn.Module {
  forward(inputs, cache) {
    let h = this.embedTokens.forward(inputs)

-   let mask
-   if (h.shape[1] > 1) {
-     mask = nn.MultiHeadAttention.createAdditiveCausalMask(h.shape[1])
-     mask = mask.astype(h.dtype)
-   }
+   const mask = createAttentionMask(h, cache)

    cache = cache ?? new Array(this.layers.length)
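After this change every updated model's forward() reduces to the same pattern; a condensed sketch of that shared shape (ExampleModel and its layer/field names are placeholders, not code from this repository):

import {nn} from '@frost-beta/mlx'
import {createAttentionMask} from '../llm.js'

class ExampleModel extends nn.Module {
  forward(inputs, cache) {
    let h = this.embedTokens.forward(inputs)
    // Additive causal mask while processing a prompt, null while decoding a
    // single token against the KV cache.
    const mask = createAttentionMask(h, cache)
    cache = cache ?? new Array(this.layers.length)
    for (let i = 0; i < this.layers.length; i++)
      h = this.layers[i].forward(h, mask, cache[i])
    return this.norm.forward(h)
  }
}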
