Skip to content

Commit

Permalink
Update KVCache implementation
Browse files (browse the repository at this point in the history)
  • Loading branch information
zcbenz committed Dec 7, 2024
1 parent 431c0a9 commit 6d86c2f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,12 @@ export function createAttentionMask(h: mx.array, cache?: BaseKVCache[]) {
  let windowSize: number | undefined;
  let offset = 0;
  if (cache) {
- offset = cache[0].offset;
- if (cache[0] instanceof RotatingKVCache)
+ if (cache[0] instanceof RotatingKVCache) {
+ offset = Math.min(cache[0].offset, cache[0].maxSize);
  windowSize = cache[0].maxSize;
+ } else {
+ offset = cache[0].offset;
+ }
  }
  return createCausalMask(T, offset, windowSize).astype(h.dtype);
  } else {
Expand Down
4 changes: 2 additions & 2 deletions src/kv-cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ export class RotatingKVCache extends BaseKVCache {
  this.keys = this.temporalOrder(this.keys);
  this.values = this.temporalOrder(this.values);

- // The largest size is self.max_size + S - 1 to ensure every token gets
+ // The largest size is self.max_size + S to ensure every token gets
  // at least this.maxSize context.
- const trimSize = this.#idx - this.maxSize + 1;
+ const trimSize = this.#idx - this.maxSize;
  this.keys = this.trim(trimSize, this.keys, keys);
  this.values = this.trim(trimSize, this.values, values);
  mx.dispose(old);
Expand Down

0 comments on commit 6d86c2f

Please sign in to comment.