Description
I wrote a simple test to look at the Triton code generated for WeightOnlyInt8Linear. The test code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.randn(out_features, in_features).to(dtype=torch.int8))
        self.register_buffer("scales", torch.randn(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

linear = WeightOnlyInt8Linear(4096, 8192)
linear.to(device='cuda', dtype=torch.bfloat16)
linear = torch.compile(linear.eval(), mode='reduce-overhead', fullgraph=True)
input = torch.randn(1, 4096, device='cuda').to(dtype=torch.bfloat16)
with torch.no_grad():
    linear(input)
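As an aside, one way to dump the code Inductor generates (assuming a recent PyTorch build where the `TORCH_LOGS` environment variable is supported; `test.py` is a placeholder name for the script above) is:

```shell
# Log the output code (Triton kernels and wrapper) produced by torch.compile.
TORCH_LOGS="output_code" python test.py
```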
I expected the generated code to fuse the weight conversion (int8 -> bfloat16) into the GEMV kernel. However, I get the following code:
The weight-convert kernel is not fused, and loading the full bf16 weight after conversion will badly hurt GEMV performance. Is the generated code reasonable, or have I made some mistake?
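For context, the fusion being asked about is numerically trivial: the int8 weight can be dequantized element-by-element as it is loaded inside the GEMV loop, so the full floating-point weight matrix never needs to be materialized. A minimal pure-Python sketch of the two variants (illustrative only, not Triton or Inductor internals; all function names here are made up):

```python
def gemv_unfused(w_int8, scales, x):
    # Step 1: materialize the full float weight matrix (extra memory traffic).
    w_float = [[float(w) for w in row] for row in w_int8]
    # Step 2: plain GEMV, then apply per-output-channel scales.
    return [scales[i] * sum(w * xj for w, xj in zip(row, x))
            for i, row in enumerate(w_float)]

def gemv_fused(w_int8, scales, x):
    # Dequantize each int8 element as it is loaded; the float weight
    # matrix is never written out, so only the int8 data is read.
    return [scales[i] * sum(float(w) * xj for w, xj in zip(row, x))
            for i, row in enumerate(w_int8)]

w = [[1, -2, 3], [4, 0, -1]]  # int8 weights, shape (out=2, in=3)
s = [0.5, 2.0]                # per-output-channel scales
x = [1.0, 2.0, 3.0]           # input vector

assert gemv_unfused(w, s, x) == gemv_fused(w, s, x)  # both [3.0, 2.0]
```

Both variants compute the same result; the fused form halves the weight-related memory traffic (reading int8 instead of writing and re-reading bf16), which is why the unfused code is a concern for a memory-bound GEMV.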