Conversation

AdvancedCompiler (Contributor)

PR Category

Operator

Type of Change

Performance Optimization

Description

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance


Summary of Changes

Hello @AdvancedCompiler, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes the "mean" operator, particularly dimension-wise mean calculations, by introducing specialized Triton kernels and heuristic-driven tuning. The changes improve both performance and numerical stability, especially for mixed-precision inputs and varying input tensor shapes.

Highlights

  • Numerical Stability: Implemented "float32" accumulation for "float16" and "bfloat16" inputs in the mean kernels to prevent precision loss during summation, yielding more accurate results for lower-precision data types (see the sketch after this list).
  • Specialized Triton Kernels: Introduced two new Triton kernels, "mean_dim_kernel_non_inner" and "mean_dim_kernel_inner", specifically designed for efficient calculation of the mean along a given dimension, optimizing for different reduction patterns.
  • Dynamic Kernel Dispatch: The "mean_dim" function now dynamically dispatches to the most appropriate specialized kernel based on the input tensor's dimensions and the characteristics of the reduction axis, leading to improved performance.
  • Heuristic-driven Tuning: New heuristic functions ("mean_heur_tile_k", "mean_heur_tile_n_non_inner", "mean_heur_one_tile_per_cta") have been added and integrated into the Triton auto-tuning system for "mean_non_inner" operations, allowing for adaptive performance optimization.
  • Boolean Input Handling: Added explicit conversion of "torch.bool" inputs to "torch.int64" for mean calculations, ensuring correct and consistent behavior when dealing with boolean tensors.
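To make the numerical-stability highlight concrete, here is a minimal, hypothetical sketch of an inner-dimension mean reduction that accumulates in float32 for float16/bfloat16 inputs and casts back on store. The kernel and wrapper names (_mean_inner_sketch, mean_last_dim) and the block size are invented for illustration and are not the PR's actual kernels.

import torch
import triton
import triton.language as tl


@triton.jit
def _mean_inner_sketch(out_ptr, in_ptr, N, BLOCK_N: tl.constexpr):
    # One program per output row; reduce the last (inner) dimension of length N.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for start in range(0, N, BLOCK_N):
        cols = start + offs
        mask = cols < N
        x = tl.load(in_ptr + row * N + cols, mask=mask, other=0.0)
        acc += x.to(tl.float32)  # accumulate in float32 even for fp16/bf16 inputs
    total = tl.sum(acc, axis=0)
    # Cast back to the output dtype only at the very end.
    tl.store(out_ptr + row, (total / N).to(out_ptr.dtype.element_ty))


def mean_last_dim(x: torch.Tensor) -> torch.Tensor:
    x = x.contiguous()
    N = x.shape[-1]
    M = x.numel() // N
    out = torch.empty(x.shape[:-1], dtype=x.dtype, device=x.device)
    _mean_inner_sketch[(M,)](out, x, N, BLOCK_N=1024)
    return out

The upcast-accumulate-downcast pattern is what keeps long reductions from losing precision when the inputs are half precision.
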
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

[Table of supported commands: Feature | Command | Description]

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.

gemini-code-assist bot left a comment

Code Review

This pull request introduces significant performance optimizations for the mean operator by adding specialized Triton kernels for different reduction scenarios. The changes include handling various data types with upcasting for precision, and adding new heuristics for kernel tuning. My review focuses on improving code clarity and removing redundancy. I've suggested combining two functions to simplify the control flow and pointed out a redundant check in the new heuristic logic.

Comment on lines +226 to +310
def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    if dim == []:
        # mean over all elements
        if not keepdim:
            return mean(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)

    # -------- normalize dim to a list of ints --------
    if isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            )

    dim = [d % inp.ndim for d in dim]
    # -------------------------------------------------

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out


def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    logger.debug("GEMS MEAN_DIM (wrapper)")
    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)

Severity: medium

For simplicity and clarity, consider combining mean_dim_comm and the mean_dim wrapper into a single mean_dim function. The current implementation with a simple wrapper adds an unnecessary layer of indirection. Exposing the out parameter in the public mean_dim function is also consistent with the PyTorch API.

def mean_dim(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    if dim == []:
        # mean over all elements
        if not keepdim:
            return mean(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)

    # -------- normalize dim to a list of ints --------
    if isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            )

    dim = [d % inp.ndim for d in dim]
    # -------------------------------------------------

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out

Comment on the heuristic tuning code:

limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k)
N = args.get("N", 1)
desired = min(max(N, _MIN_TILE_N), limit_by_k)
desired = min(desired, _MAX_ONE_TILE_N, limit_by_k)

Severity: medium

The limit_by_k variable in this min call is redundant. The value of desired is already constrained by limit_by_k in the previous line. Removing the redundant variable will make the code clearer.

Suggested change

- desired = min(desired, _MAX_ONE_TILE_N, limit_by_k)
+ desired = min(desired, _MAX_ONE_TILE_N)
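To illustrate why the extra clamp is a no-op, here is a minimal standalone sketch of the surrounding heuristic logic. The constant values and the function name one_tile_n_sketch are hypothetical placeholders, not the PR's actual definitions.

# Hypothetical constants and function name, for illustration only.
_MIN_TILE_N = 32
_MAX_ONE_TILE_N = 2048
_MAX_TILE_N_PER_ROW = 8192

def one_tile_n_sketch(args):
    tile_k = args["TILE_K"]
    limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    N = args.get("N", 1)
    # desired is already clamped to limit_by_k here ...
    desired = min(max(N, _MIN_TILE_N), limit_by_k)
    # ... so a second clamp by limit_by_k below can never change the value.
    desired = min(desired, _MAX_ONE_TILE_N)
    return desired

Since desired never exceeds limit_by_k after the first min, dropping the redundant argument preserves behavior for every input.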
