Question about the generated code of WeightOnlyInt8Linear #114
Comments
I was having trouble reproducing the int8 speedup. I didn't look into the generated code to verify, but it turns out I needed the following. Can you try that?
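Presumably this refers to Inductor's coordinate-descent tuning flag, which also appears in the reproduction script below; a minimal sketch of enabling it before compiling (the module here is only a stand-in):

```python
import torch
import torch.nn as nn
import torch._inductor.config

# Assumption: "the following" is this Inductor flag, which also shows up in the
# reproduction script quoted below. It must be set before torch.compile runs.
torch._inductor.config.coordinate_descent_tuning = True

model = nn.Linear(8192, 8192, bias=False).cuda().to(torch.bfloat16).eval()  # stand-in module
compiled_model = torch.compile(model, fullgraph=True)
```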
Yes, you need to add it. Here is the test script, with torch._inductor.config.coordinate_descent_tuning = True enabled:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True

D = 8192

def bench(f, name=None):
    import time
    from triton.testing import do_bench
    us_per_iter = do_bench(lambda: f()) * 1000
    # D * D int8 weight bytes are read per call, so this reports achieved bandwidth.
    print(f"{name}: {(1e6 / us_per_iter) * D * D / 1e9} GB/s")
    return 0

class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.randn(out_features, in_features).to(dtype=torch.int8))
        self.register_buffer("scales", torch.randn(out_features, dtype=dtype))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

# .cuda() moves the int8 weight and fp32 scales onto the GPU to match the CUDA input.
linear = WeightOnlyInt8Linear(D, D).cuda()
compiled_linear = torch.compile(linear.eval(), fullgraph=True)
input = torch.randn(1, D, device='cuda').to(dtype=torch.bfloat16)

with torch.no_grad():
    bench(lambda: linear(input), "eager")
    bench(lambda: compiled_linear(input), "compiled")
```
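For reference, one way to have Inductor print the generated wrapper code shown below is the output_code logging artifact; a minimal sketch, assuming a recent PyTorch 2.x build:

```python
import torch._logging

# Log Inductor's generated wrapper/Triton code for each compiled graph.
# Equivalent to running the script with the environment variable TORCH_LOGS="output_code".
torch._logging.set_logs(output_code=True)
```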
And this is the generated file:

```python
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_chilli/26/c26sd3tlnwesrvfxag3lrdmeofbgukgfrbuj2lus7rhhd7madjg6.py
# Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
# linear => convert_element_type_3, mul, sum_1
# mul => mul_1
triton_red_fused_mm_mul_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from triton.compiler.compiler import AttrsDescriptor

@reduction(
    size_hints=[8192, 8192],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*bf16', 2: '*i8', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(4, 5))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mm_mul_0', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 8192
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1), None, eviction_policy='evict_last').to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r1 + (8192*x0)), None, eviction_policy='evict_first')
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tmp1 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tmp8
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp9 = tmp7.to(tl.float32)
    tmp10 = tmp9.to(tl.float32)
    tmp12 = tmp10 * tmp11
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp12, None)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

async_compile.wait(globals())
del async_compile


def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (8192, 8192), (8192, 1))
    assert_size_stride(arg1_1, (8192, ), (1, ))
    assert_size_stride(arg2_1, (1, 8192), (8192, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, 8192), (8192, 1), torch.float32)
        buf1 = buf0; del buf0  # reuse
        # Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused_mm_mul_0.run(buf1, arg2_1, arg0_1, arg1_1, 8192, 8192, grid=grid(8192), stream=stream0)
        del arg0_1
        del arg1_1
        del arg2_1
    return (buf1, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((8192, 8192), (8192, 1), device='cuda:0', dtype=torch.int8)
    arg1_1 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg2_1 = rand_strided((1, 8192), (8192, 1), device='cuda:0', dtype=torch.bfloat16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```
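A quick way to sanity-check fusion in a dump like this is to look at the fused kernel's pointer types: the triton_meta signature of triton_red_fused_mm_mul_0 above marks argument 2 as '*i8', i.e. the int8 weight is loaded and converted inside the reduction kernel rather than materialized as a separate bf16 tensor first. A rough, hypothetical sketch of automating that check on a saved dump (the file path is a placeholder):

```python
# Hypothetical path to a saved Inductor output_code dump.
GENERATED_FILE = "output_code.py"

with open(GENERATED_FILE) as f:
    src = f.read()

# In the dump above, the reduction kernel's triton_meta signature contains '*i8',
# meaning the int8 weight is read and converted inside the fused kernel.
print("fused kernel reads int8 weight directly:", "'*i8'" in src)
```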
I think we can actually relax
@Chillee @conway-abacus Thank you,
This PR always turns on the decomposition.
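For context, a minimal sketch of what such a decomposition looks like at the module level (illustrative only, not necessarily the PR's actual change): writing the matmul as an elementwise multiply plus a reduction exposes the int8-to-activation-dtype cast to the compiler, so it can be fused into a single reduction kernel like the one above instead of materializing a full bf16 weight.

```python
import torch

def decomposed_int8_linear(input: torch.Tensor,
                           weight_int8: torch.Tensor,
                           scales: torch.Tensor) -> torch.Tensor:
    # input: (B, in_features), weight_int8: (out_features, in_features), scales: (out_features,)
    # Equivalent to F.linear(input, weight_int8.to(input.dtype)) * scales, but written as
    # mul + sum so torch.compile can fuse the int8 cast into the reduction.
    return (input.unsqueeze(1) * weight_int8.to(dtype=input.dtype)).sum(dim=-1) * scales
```

Note that in eager mode this materializes a (B, out_features, in_features) intermediate, so the decomposition only makes sense under torch.compile.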
Hey @Chillee, any tips for generating a fused kernel for BS > 1? Is it related at all to pytorch/pytorch#127056?
I wrote a simple test to get the Triton code of WeightOnlyInt8Linear; the test code is as follows. I expect the generated code to fuse the weight conversion (int8 -> bfloat16) into the GEMV function. However, I get the following code: the weight-conversion kernel is not fused, and the load of the full bf16 weight after conversion will hurt GEMV performance badly. Is the generated code reasonable, or have I made some mistake?