Question about the generated code of WeightOnlyInt8Linear #114

Open
feiyuvl opened this issue Feb 29, 2024 · 6 comments
@feiyuvl

feiyuvl commented Feb 29, 2024

I wrote a simple test to get the Triton code for WeightOnlyInt8Linear. The test code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.randn(out_features, in_features).to(dtype=torch.int8))
        self.register_buffer("scales", torch.randn(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: upcast the int8 weight, matmul, then scale.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

linear = WeightOnlyInt8Linear(4096, 8192)
linear.to(device='cuda', dtype=torch.bfloat16)
linear = torch.compile(linear.eval(), mode='reduce-overhead', fullgraph=True)

input = torch.randn(1, 4096, device='cuda').to(dtype=torch.bfloat16)
with torch.no_grad():
    linear(input)

I expected the generated code to fuse the weight conversion (int8 → bfloat16) into the gemv kernel. However, I get the following code:
[screenshot of the generated code: the int8 → bf16 weight conversion runs as a separate kernel ahead of the matmul]

The weight-conversion kernel is not fused. Loading the full bf16 weight after conversion will badly hurt gemv performance. Is the generated code reasonable, or have I made a mistake?
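
For reference, one way to dump the code Inductor generates without relying on screenshots is to enable artifact logging. A minimal sketch, assuming PyTorch ≥ 2.1, where the TORCH_LOGS environment variable controls this:

import os

# Assumption: PyTorch >= 2.1. TORCH_LOGS must be set before torch is
# imported so the logging setup picks it up; running the script as
# `TORCH_LOGS="output_code" python repro.py` works as well.
os.environ["TORCH_LOGS"] = "output_code"

import torch  # noqa: E402

# ... build and compile WeightOnlyInt8Linear as above; the first call to
# the compiled module prints the generated wrapper and Triton kernels.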

@conway-abacus

I was having trouble reproducing the int8 speedup. I didn't look into the generated code to verify, but it turns out I needed the following:

import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True

Can you try that?

@Chillee
Contributor

Chillee commented Feb 29, 2024

Yes, you need to set coordinate_descent_tuning to True.

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True

D = 8192

def bench(f, name=None):
    from triton.testing import do_bench

    # do_bench returns milliseconds per iteration; convert to microseconds.
    us_per_iter = do_bench(lambda: f())*1000
    # Report effective bandwidth, counting only the D*D int8 weight bytes.
    print(f"{name}: {(1e6/us_per_iter) * D * D / 1e9} GB/s")

    return 0

class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.randn(out_features, in_features).to(dtype=torch.int8))
        self.register_buffer("scales", torch.randn(out_features, dtype=dtype))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

linear = WeightOnlyInt8Linear(D, D).to(device='cuda')  # buffers must be on cuda to match the input
compiled_linear = torch.compile(linear.eval(), fullgraph=True)

input = torch.randn(1, D, device='cuda').to(dtype=torch.bfloat16)

with torch.no_grad():
    bench(lambda: linear(input), "eager")
    bench(lambda: compiled_linear(input), "compiled")
eager: 163.15397820521537 GB/s
compiled: 1277.1920906791986 GB/s
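
(The GB/s figure counts only the D * D int8 weight bytes, about 67 MB per call at D = 8192, so the compiled kernel streams the quantized weight at close to HBM bandwidth; eager presumably loses most of its time materializing and re-reading the full bf16 copy.)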

And this is the generated file:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_chilli/26/c26sd3tlnwesrvfxag3lrdmeofbgukgfrbuj2lus7rhhd7madjg6.py
# Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
# linear => convert_element_type_3, mul, sum_1
# mul => mul_1
triton_red_fused_mm_mul_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
from triton.compiler.compiler import AttrsDescriptor

@reduction(
    size_hints=[8192, 8192],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*bf16', 2: '*i8', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(4, 5))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mm_mul_0', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 8192
    rnumel = 8192
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp7 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1), None, eviction_policy='evict_last').to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r1 + (8192*x0)), None, eviction_policy='evict_first')
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp5 = tmp1 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tmp8
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp9 = tmp7.to(tl.float32)
    tmp10 = tmp9.to(tl.float32)
    tmp12 = tmp10 * tmp11
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp12, None)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (8192, 8192), (8192, 1))
    assert_size_stride(arg1_1, (8192, ), (1, ))
    assert_size_stride(arg2_1, (1, 8192), (8192, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, 8192), (8192, 1), torch.float32)
        buf1 = buf0; del buf0  # reuse
        # Source Nodes: [linear, mul], Original ATen: [aten.mm, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused_mm_mul_0.run(buf1, arg2_1, arg0_1, arg1_1, 8192, 8192, grid=grid(8192), stream=stream0)
        del arg0_1
        del arg1_1
        del arg2_1
    return (buf1, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((8192, 8192), (8192, 1), device='cuda:0', dtype=torch.int8)
    arg1_1 = rand_strided((8192, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg2_1 = rand_strided((1, 8192), (8192, 1), device='cuda:0', dtype=torch.bfloat16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

@Chillee
Contributor

Chillee commented Feb 29, 2024

I think we can actually relax coordinate_descent_tuning, although we still need the BS=1 restriction.
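
For context, a minimal sketch of that restriction. It reuses the WeightOnlyInt8Linear module defined above, and the bs > 1 fallback described in the comments is my reading of this thread, not Inductor's exact heuristic:

import torch
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True

# Assumes the WeightOnlyInt8Linear class from the repro above is in scope.
linear = WeightOnlyInt8Linear(4096, 4096).to(device='cuda')
compiled = torch.compile(linear.eval(), fullgraph=True)

with torch.no_grad():
    for bs in (1, 4):
        x = torch.randn(bs, 4096, device='cuda', dtype=torch.bfloat16)
        compiled(x)
        # With bs == 1 the matmul is a gemv, which Inductor can decompose
        # into the fused int8-load reduction shown above; with bs > 1 it
        # falls back to an mm on a separately converted bf16 weight.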

@feiyuvl
Author

feiyuvl commented Mar 1, 2024

@Chillee @conway-abacus Thank you, coordinate_descent_tuning=True generates the expected code.

@Chillee
Contributor

Chillee commented Mar 1, 2024

pytorch/pytorch#120954

This PR always turns on the decomposition.

@conway-abacus

Hey @Chillee, any tips for generating a fused kernel for BS > 1? Is it related at all to pytorch/pytorch#127056?
