
Test llava with quantized weights
zcbenz committed Sep 26, 2024
1 parent 6ebc13c commit c960a33
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -24,8 +24,8 @@ jobs:
 huggingface download --silent Qwen/Qwen2-0.5B
 yarn tsx src/generate.ts --max-tokens=128 Qwen2-0.5B
-huggingface download --silent llava-hf/llava-1.5-7b-hf
-yarn tsx src/generate.ts --max-tokens=128 llava-1.5-7b-hf 'USER: How are you?\nASSISTANT:'
+huggingface download --silent mlx-community/llava-1.5-7b-4bit
+yarn tsx src/generate.ts --max-tokens=128 llava-1.5-7b-4bit 'USER: How are you?\nASSISTANT:'
 publish:
   if: startsWith(github.ref, 'refs/tags/')
8 changes: 6 additions & 2 deletions src/models/llava.ts
@@ -121,8 +121,12 @@ export class Model extends BaseModel {
   // [out_channels, in_channels, kH, KW]
   // MLX Conv2d expects the weight tensor to be of shape:
   // [out_channels, kH, KW, in_channels]
-  if (key.endsWith('patch_embedding.weight'))
-    weights[key] = weights[key].transpose(0, 2, 3, 1);
+  if (key.endsWith('patch_embedding.weight')) {
+    // Some mlx-community models already transposed it for us.
+    const {shape} = weights[key];
+    if (shape[1] != shape[2])
+      weights[key] = weights[key].transpose(0, 2, 3, 1);
+  }
 }
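
The shape check works because ViT patch embeddings are square (kH == kW), so only the untransposed PyTorch layout has differing sizes at indexes 1 and 2. A minimal sketch of that heuristic in isolation, using a hypothetical helper and example shapes for llava's CLIP vision tower (hidden size 1024, 3 input channels, 14x14 patches), not code from this commit:

// PyTorch Conv2d weights: [out_channels, in_channels, kH, kW] -> indexes 1 and 2 differ
// MLX Conv2d weights:     [out_channels, kH, kW, in_channels] -> indexes 1 and 2 are both the patch size
function needsTranspose(shape: number[]): boolean {
  return shape[1] !== shape[2];
}

needsTranspose([1024, 3, 14, 14]);  // true  -> transpose(0, 2, 3, 1) into MLX layout
needsTranspose([1024, 14, 14, 3]);  // false -> checkpoint already ships the MLX layout

The heuristic assumes the number of input channels never equals the patch size, which holds for the 3-channel, 14x14-patch CLIP tower used here.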

