Question about the generated code of WeightOnlyInt8Linear #114

Open
@feiyuvl

I wrote a simple test to get the Triton code generated for WeightOnlyInt8Linear. The test code is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        # Note: `bias` and `factory_kwargs` are not used by this module
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Random int8 weights and per-output-channel bf16 scales (test data only)
        self.register_buffer("weight", torch.randn(out_features, in_features).to(dtype=torch.int8))
        self.register_buffer("scales", torch.randn(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Cast the int8 weight to the input dtype, run the matmul, then apply the scales.
        # This int8 -> bf16 cast is what I expected to be fused into the GEMV.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

linear = WeightOnlyInt8Linear(4096, 8192)
# Module.to(dtype=...) only casts floating-point buffers, so `weight` stays int8
linear.to(device='cuda', dtype=torch.bfloat16)
linear = torch.compile(linear.eval(), mode='reduce-overhead', fullgraph=True)

# Batch size 1, so the matmul is effectively a GEMV
input = torch.randn(1, 4096, device='cuda').to(dtype=torch.bfloat16)
with torch.no_grad():
    linear(input)
```
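For reference, the forward pass computes, for each output channel n, y[n] = scales[n] · Σ_k x[k] · W[n, k], where W is first converted from int8. The pure-Python sketch below contrasts the two patterns: the unfused one materializes a converted copy of the whole weight before the dot products, while the fused one (what I expected) converts each int8 element inside the reduction loop, so no converted copy is ever written. The helper names `unfused_gemv` and `fused_gemv` are hypothetical, and Python floats stand in for bfloat16, so rounding differs from the real kernel.

```python
# Illustrative sketch only: hypothetical helpers, not PyTorch or Triton APIs.
# Python floats stand in for bfloat16 and plain ints for int8.
from typing import List


def unfused_gemv(x: List[float], w_int8: List[List[int]], scales: List[float]) -> List[float]:
    # Step 1 (separate "convert" kernel): read every int8 element and store a
    # full converted copy of the weight matrix.
    w_float = [[float(v) for v in row] for row in w_int8]
    # Step 2 (GEMV kernel): read the converted copy again, then apply the scales.
    return [scales[n] * sum(xk * wk for xk, wk in zip(x, w_float[n]))
            for n in range(len(w_float))]


def fused_gemv(x: List[float], w_int8: List[List[int]], scales: List[float]) -> List[float]:
    # Conversion happens on the fly inside the reduction loop; each int8
    # element is read once and no converted weight matrix is stored.
    out = []
    for n, row in enumerate(w_int8):
        acc = 0.0
        for xk, wk in zip(x, row):
            acc += xk * float(wk)  # on a GPU this cast would stay in registers
        out.append(scales[n] * acc)  # per-output-channel scale applied after the sum
    return out


x = [1.0, -2.0, 0.5]
w = [[1, 2, 3], [-1, 0, 4], [2, -3, 1]]
s = [0.5, 2.0, 1.0]
print(fused_gemv(x, w, s))  # → [-0.75, 2.0, 8.5]
```

Both functions return the same values; the difference is only in how much memory the unfused version has to write and re-read.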

I expected the generated code to fuse the weight conversion (int8 -> bfloat16) into the GEMV kernel. However, I got the following code:
[Screenshot: generated code]

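The weight-conversion kernel is not fused. Loading the full bf16 weight after the conversion will badly hurt GEMV performance: a batch-size-1 GEMV is memory-bandwidth bound, and the bf16 copy is twice the size of the int8 weight (and has to be written out by the conversion kernel first). Is the generated code reasonable, or have I made a mistake?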
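To put a rough number on it, here is a back-of-envelope count of weight-matrix memory traffic for the 8192 x 4096 layer in the test. This is a simplified estimate, not a measurement: it assumes each tensor is read or written exactly once, ignores caching, and ignores the small input, output and scales tensors.

```python
# Back-of-envelope DRAM traffic for the weight matrix only (out_features=8192,
# in_features=4096). Assumption: every tensor is read/written once, no caching.
OUT_FEATURES, IN_FEATURES = 8192, 4096
INT8_BYTES, BF16_BYTES = 1, 2

n_weights = OUT_FEATURES * IN_FEATURES  # 33,554,432 elements

# Fused: the GEMV kernel reads the int8 weight once and converts it on the fly.
fused_bytes = n_weights * INT8_BYTES

# Unfused: the convert kernel reads int8 and writes bf16, then the GEMV reads bf16.
unfused_bytes = n_weights * (INT8_BYTES + BF16_BYTES + BF16_BYTES)

MiB = 2**20
print(f"fused:   {fused_bytes / MiB:.0f} MiB")             # fused:   32 MiB
print(f"unfused: {unfused_bytes / MiB:.0f} MiB")           # unfused: 160 MiB
print(f"ratio:   {unfused_bytes / fused_bytes:.1f}x")      # ratio:   5.0x
```

Under these assumptions the unfused version moves about 5x more weight bytes than a fused kernel would.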