Description
When using weight-only int8 quantization, multi-batch inference is significantly slower than single-batch inference. Single-batch performance is good, but it does not scale well as the batch size increases.
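```python
import torch
import torch.nn.functional as F


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()
        # `bias` and `dtype` are accepted for nn.Linear API compatibility but are unused:
        # the layer has no bias, and the buffers below use fixed dtypes.
        self.in_features = in_features
        self.out_features = out_features
        # Quantized weight: one int8 row per output feature.
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device))
        # Per-output-channel dequantization scales.
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16, device=device))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Cast the int8 weight to the activation dtype, run the GEMM, then apply
        # the per-channel scales to the output columns.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
```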
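`forward` applies `scales` to the GEMM output instead of to the weight. This works because output column j depends only on weight row j, so scaling row j by `scales[j]` before the GEMM gives the same result as scaling column j after it. The sketch below checks that identity on a tiny example in plain Python (no PyTorch); the helpers `linear`, `dequant_then_linear` and `linear_then_scale` are hypothetical names for illustration only.

```python
from typing import List

Matrix = List[List[float]]


def linear(x: Matrix, w: Matrix) -> Matrix:
    """y = x @ w.T (like F.linear without bias). x: [batch, in], w: [out, in]."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in w] for row in x]


def dequant_then_linear(x: Matrix, w_int8: List[List[int]], scales: List[float]) -> Matrix:
    # Scale each weight row first, materializing a dequantized weight.
    w_deq = [[v * s for v in w_row] for w_row, s in zip(w_int8, scales)]
    return linear(x, w_deq)


def linear_then_scale(x: Matrix, w_int8: List[List[int]], scales: List[float]) -> Matrix:
    # Mirrors WeightOnlyInt8Linear.forward: GEMM on the cast weight, then per-column scaling.
    y = linear(x, [[float(v) for v in w_row] for w_row in w_int8])
    return [[v * s for v, s in zip(row, scales)] for row in y]


x = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]   # batch of 2, in_features = 3
w_int8 = [[12, -7, 3], [-128, 5, 127]]    # out_features = 2, values within int8 range
scales = [0.5, 0.25]                      # powers of two keep the float math exact

print(linear_then_scale(x, w_int8, scales))  # [[-2.5, -61.25], [7.5, 79.25]]
```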
Current Behavior
- The explicit `.to(dtype=input.dtype)` introduces a separate type-conversion op.
- In the single-batch case, Inductor successfully fuses this conversion into the GEMM.
- In the multi-batch case, the fusion fails and we get:
  - one kernel for the int8 -> fp16 conversion
  - another kernel for the GEMM computation
- Because the converted fp16 weight is written to memory and then read back by the GEMM, this leads to extra memory traffic and lower performance.
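A rough estimate of the extra weight traffic, as a sketch under simplifying assumptions: fp16 activations (2 bytes per element), only weight reads and writes counted (activations and outputs ignored), an idealized fused kernel that reads each int8 weight element from memory once, and an unfused path whose conversion kernel writes the full fp16 weight before the GEMM reads it back. The function `weight_bytes` and the 4096 x 4096 layer shape are hypothetical.

```python
def weight_bytes(out_features: int, in_features: int, fused: bool,
                 act_bytes: int = 2) -> int:
    """Idealized bytes of weight traffic for one forward pass of the layer."""
    n = out_features * in_features
    int8_read = n * 1  # int8 weight: 1 byte per element
    if fused:
        # Conversion happens inside the GEMM kernel: only the int8 weight is read.
        return int8_read
    # Separate conversion kernel: read int8, write fp16; the GEMM then reads the fp16 copy.
    return int8_read + n * act_bytes + n * act_bytes


shape = (4096, 4096)  # hypothetical layer size
fused = weight_bytes(*shape, fused=True)
unfused = weight_bytes(*shape, fused=False)
print(fused, unfused, unfused / fused)  # 16777216 83886080 5.0
```

Under these assumptions the unfused path moves about 5x more weight bytes regardless of layer shape.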