
Marlin slower than fp16 on larger batches #21

Open
mobicham opened this issue Apr 9, 2024 · 2 comments
mobicham commented Apr 9, 2024

I have been running some benchmarks with Marlin, but the speed-up is far from what is reported; in fact, it's actually slower than fp16:
GPU: A6000 Ada

matrix_shape:  [11008, 4096]

input_shape: [1, 1024, 11008]
time (fp16): 0.0007191438674926758
time (marlin): 0.0006200861930847168 (1.16x)

input_shape: [16, 1024, 11008]
time (fp16): 0.010448209762573242
time (marlin): 0.01280400848388672 (0.82x)

Code below:

import torch
import marlin

def forward_marlin(marlin_layer, x):
    # Allocate the output with the same leading dims as x and the layer's output features
    y = torch.empty(x.shape[:-1] + (marlin_layer.s.shape[1],), dtype=x.dtype, device=x.device)
    # marlin.mul expects 2D inputs/outputs, so flatten all leading dims into one
    marlin.mul(x.view((-1, x.shape[-1])), marlin_layer.B, y.view((-1, y.shape[-1])), marlin_layer.s, marlin_layer.workspace_fp)
    return y

print(time_it(lambda: torch.matmul(x, ref) ))
print(time_it(lambda: forward_marlin(marlin_layer, x)))
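(`time_it` above is just a simple timing helper; roughly something like the following, with warmup iterations and `torch.cuda.synchronize()` around the timed loop:)

    import time
    import torch

    def time_it(fn, warmup=10, iters=100):
        # Sketch of a timing helper: warm up first to exclude compilation / caching effects
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
        return (time.time() - start) / iters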

What could be the issue? Thanks in advance!

efrantar (Member) commented Apr 9, 2024

Hi, Marlin is primarily optimized for generative inference (a few tokens at a time), which is memory-bound and can hence be sped up via weight quantization; e.g., input shapes of (16, 1, 11008). Note that for batch sizes > 128 (meaning the overall number of tokens, in your case 16 * 1024), inference stops being memory-bound and weight-only quantization generally cannot be faster (though Marlin sometimes is a bit faster for not-too-large batch sizes, due to slightly better partitioning than the default torch kernels).
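For example, the memory-bound regime described above can be checked by benchmarking at a generative-inference shape with only a few tokens per forward pass. A rough sketch, reusing the hypothetical `time_it` helper plus the `ref` weights and `marlin_layer` from the first comment:

    import torch

    # Generative-inference shape: 16 sequences x 1 token = 16 tokens total,
    # well under the ~128-token mark where the matmul stops being memory-bound.
    x_gen = torch.randn(16, 1, 11008, dtype=torch.half, device="cuda")

    t_fp16 = time_it(lambda: torch.matmul(x_gen, ref))
    t_marlin = time_it(lambda: forward_marlin(marlin_layer, x_gen))
    print(f"fp16: {t_fp16:.6f}s  marlin: {t_marlin:.6f}s  ({t_fp16 / t_marlin:.2f}x)")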

mobicham (Author) commented

Thanks for your answer @efrantar. Understood. I am trying to integrate it with our quantization method; below are the benchmarks for the forward pass on a 3090, Llama2-7B, batch-size=1, context-size=2048:

fp16:                 0.4604 + model compile: 0.4218
int4 (torch compile): 0.4554
Marlin (int4):        0.4221 + model compile: 0.3841

It is about 10% faster than fp16 with this setup, though the LLM eval score drops a bit (51.57 -> 51.27).

Is there a way to dequantize the weights without calling the matmul with the identity matrix?
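(For reference, by the identity-matrix trick I mean roughly the following; `dequantize_via_identity` and `in_features` are just illustrative names, and it reuses the `forward_marlin` wrapper from my first comment:)

    import torch

    def dequantize_via_identity(marlin_layer, in_features, dtype=torch.half, device="cuda"):
        # I @ dequant(B) == dequant(B), so running an identity input through the
        # Marlin kernel returns the dequantized (in_features, out_features) weight.
        eye = torch.eye(in_features, dtype=dtype, device=device)
        return forward_marlin(marlin_layer, eye)

    # e.g. W_deq = dequantize_via_identity(marlin_layer, 11008)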

Thanks again for your work!
