[AdvancedCompiler]Optimize mean #979
base: master
Conversation
Summary of Changes (Gemini Code Assist): This pull request optimizes the mean operator, particularly dimension-wise mean calculations, by adding specialized Triton kernels and heuristic-driven tuning. The changes improve both the performance and the numerical stability of the operation, especially for mixed-precision inputs and varying tensor shapes.
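The following snippet is illustrative only (not code from this PR): it shows why accumulating a long float16 reduction in float16 is problematic, which is the mixed-precision concern behind upcasting in the mean kernels.

import torch

x = torch.ones(100_000, dtype=torch.float16)

# Emulate keeping the running sum in float16: the accumulator overflows
# past 65504 long before the reduction finishes.
acc16 = torch.zeros((), dtype=torch.float16)
for chunk in x.split(1024):
    acc16 = (acc16 + chunk.sum(dtype=torch.float32)).to(torch.float16)

# Accumulate in float32 instead and only cast the final result.
acc32 = x.sum(dtype=torch.float32)

print((acc16 / x.numel()).item())  # inf: the fp16 running sum overflowed
print((acc32 / x.numel()).item())  # 1.0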
Code Review
This pull request introduces significant performance optimizations for the mean operator by adding specialized Triton kernels for different reduction scenarios. The changes include handling various data types with upcasting for precision, and adding new heuristics for kernel tuning. My review focuses on improving code clarity and removing redundancy. I've suggested combining two functions to simplify the control flow and pointed out a redundant check in the new heuristic logic.
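For orientation, the hunk below adds mean_dim_kernel_inner and mean_dim_kernel_non_inner. As a rough, hypothetical sketch of the general technique rather than the PR's actual kernels (row_mean_kernel, row_mean, and BLOCK_N are illustrative names), a row-wise mean over the last dimension that accumulates in float32 could look like this:

import torch
import triton
import triton.language as tl


@triton.jit
def row_mean_kernel(out_ptr, in_ptr, N, BLOCK_N: tl.constexpr):
    # One program per row: sum the row in a float32 accumulator, then divide by N.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for start in range(0, N, BLOCK_N):
        cols = start + offs
        mask = cols < N
        x = tl.load(in_ptr + row * N + cols, mask=mask, other=0.0)
        acc += x.to(tl.float32)
    tl.store(out_ptr + row, tl.sum(acc, axis=0) / N)


def row_mean(x):
    # x: 2D and contiguous; reduces over the last dimension.
    M, N = x.shape
    out = torch.empty(M, dtype=torch.float32, device=x.device)
    row_mean_kernel[(M,)](out, x, N, BLOCK_N=1024)
    return out.to(x.dtype)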
def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64

    if dim == []:
        # mean over all elements
        if not keepdim:
            return mean(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)

    # -------- normalize dim to a list of ints --------
    if isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            )

    dim = [d % inp.ndim for d in dim]
    # -------------------------------------------------

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out


def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    logger.debug("GEMS MEAN_DIM (wrapper)")
    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)
For simplicity and clarity, consider combining mean_dim_comm and the mean_dim wrapper into a single mean_dim function. The current implementation, with a simple wrapper, adds an unnecessary layer of indirection. Exposing the out parameter in the public mean_dim function is also consistent with the PyTorch API.
def mean_dim(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64
    if dim == []:
        # mean over all elements
        if not keepdim:
            return mean(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)
    shape = list(inp.shape)
    # -------- normalize dim to a list of ints --------
    if isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            )
    dim = [d % inp.ndim for d in dim]
    # -------------------------------------------------
    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)
        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
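For reference, a quick sanity check of the proposed mean_dim against torch.mean (a hypothetical test snippet; it assumes a CUDA device and that the surrounding module provides mean, dim_compress, torch_device_fn, and the Triton kernels) might look like:

x = torch.randn(8, 64, 32, device="cuda", dtype=torch.float16)

# single-dim reduction path
torch.testing.assert_close(
    mean_dim(x, dim=1, keepdim=True),
    torch.mean(x, dim=1, keepdim=True),
    rtol=1e-3, atol=1e-3,
)

# multi-dim reduction (dim_compress) path
torch.testing.assert_close(
    mean_dim(x, dim=[0, 2]),
    torch.mean(x, dim=[0, 2]),
    rtol=1e-3, atol=1e-3,
)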
limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k)
N = args.get("N", 1)
desired = min(max(N, _MIN_TILE_N), limit_by_k)
desired = min(desired, _MAX_ONE_TILE_N, limit_by_k)
The limit_by_k argument in this min call is redundant: the value of desired is already constrained by limit_by_k in the previous line. Removing the redundant argument makes the code clearer.
Suggested change:
desired = min(desired, _MAX_ONE_TILE_N)
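To see why: after the previous line, desired is already at most limit_by_k, so clamping by limit_by_k a second time can never change the result. A tiny standalone check (the constants here are hypothetical, not the values defined in this file):

_MIN_TILE_N, _MAX_ONE_TILE_N, _MAX_TILE_N_PER_ROW = 32, 1024, 4096  # hypothetical values

for tile_k in (1, 8, 64, 512):
    limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    for N in (1, 100, 10_000):
        desired = min(max(N, _MIN_TILE_N), limit_by_k)
        assert min(desired, _MAX_ONE_TILE_N, limit_by_k) == min(desired, _MAX_ONE_TILE_N)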
PR Category: Operator
Type of Change: Performance Optimization
Description
Issue
Progress
Performance