[Operator] Add cov op #276
What's the performance metric?
I've added a benchmark that tests the accuracy and performance of the covariance operation against Torch's implementation:

import torch
import triton
import triton.testing

# `cov` is the Triton implementation added in this PR (src/flag_gems/ops/cov.py).


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(4, 16)],
        x_log=True,
        line_arg='provider',
        line_vals=['torch', 'triton'],
        line_names=['Torch', 'Triton'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='GB/s',
        plot_name='covariance-benchmark',  # Name for the plot.
        args={'M': 1024, 'correction': 1},  # Default values for parameters other than N.
    )
)
def benchmark_cov(M, N, correction, provider):
    X = torch.randn(M, N, device='cuda', dtype=torch.float32)
    fweights = torch.randint(1, 5, (N,), dtype=torch.int32, device=X.device)
    aweights = torch.rand(N, device='cuda') + 0.1  # Avoid zeros in weights
    # do_bench returns one timing per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cov(X, correction=correction, fweights=fweights, aweights=aweights),
            quantiles=quantiles
        )
        result = torch.cov(X, correction=correction, fweights=fweights, aweights=aweights).cpu()
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: cov(X, correction=correction, fweights=fweights, aweights=aweights),
            quantiles=quantiles
        )
        result = cov(X, correction=correction, fweights=fweights, aweights=aweights).cpu()
    gbps = lambda ms: 3 * X.numel() * X.element_size() * 1e-9 / (ms * 1e-3)
    if provider == 'triton':
        # Accuracy check: compare the Triton result against Torch.
        torch_result = torch.cov(X, correction=correction, fweights=fweights, aweights=aweights).cpu()
        precision_diff = torch.max(torch.abs(torch_result - result))
        print(f'The maximum difference between Torch and Triton is {precision_diff.item()}')
    else:
        precision_diff = torch.tensor(0.0)  # No difference for Torch itself
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark_cov.run(print_data=True, show_plots=True, save_path=".")

I've attached the results as a picture for reference. Let me know if additional details are needed!
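As a rough illustration of how the `gbps` lambda in the benchmark turns a timing into the `GB/s` value on the y-axis, here is a minimal worked example. The 1 ms timing is a made-up input, not a measured result, and the standalone function name `gbps_from_ms` is hypothetical; the formula itself (three times the byte size of `X`, divided by the elapsed time) is copied from the benchmark.

```python
def gbps_from_ms(num_elements: int, element_size: int, ms: float) -> float:
    """Same arithmetic as the benchmark's gbps lambda, written out as a function."""
    bytes_counted = 3 * num_elements * element_size  # the benchmark's factor of 3
    return bytes_counted * 1e-9 / (ms * 1e-3)        # bytes -> GB, ms -> s


# One of the benchmarked shapes: M = 1024 rows, N = 2**15 columns, float32 (4 bytes).
M, N = 1024, 2**15
example = gbps_from_ms(M * N, element_size=4, ms=1.0)  # hypothetical 1 ms timing
print(f"{example:.2f} GB/s")
```

Because the time is in the denominator, the 20th-percentile time maps to the higher throughput bound and the 80th-percentile time to the lower one.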
src/flag_gems/ops/cov.py
Outdated
def cov(X, correction=1, fweights=None, aweights=None):
    logging.debug("GEMS ")

    if not X.is_cuda:
Gems sits behind the PyTorch dispatcher and doesn't need to handle non-CUDA inputs.
@tongxin Thanks for the feedback! I’ve updated the code to remove it in the latest commit.
src/flag_gems/ops/cov.py
Outdated
    mean = torch.zeros(M, device=X.device, dtype=X.dtype)
    cov_matrix = torch.zeros((M, M), device=X.device, dtype=X.dtype)

    BLOCK_SIZE = min(256, N)
This won't work if N < 256 and N is not a power of 2.
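One common way to handle this, assuming the kernel builds its offsets with `tl.arange` (which requires a power-of-two range), is to round the block size up to a power of two and mask off the lanes past `N`. The helper below is a plain-Python sketch of that rounding only. The names `next_power_of_2` and `pick_block_size` are illustrative and not taken from the PR; the cap of 256 mirrors the snippet above. Triton also provides `triton.next_power_of_2` for the same rounding.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return 1 << (n - 1).bit_length()


def pick_block_size(n_cols: int, max_block: int = 256) -> int:
    """Power-of-two block size capped at max_block (max_block itself a power of two)."""
    return min(max_block, next_power_of_2(n_cols))


# N = 100 is below 256 and not a power of two: min(256, N) would give 100,
# while the rounded version gives a block size that is valid for a power-of-two range.
block = pick_block_size(100)
```

With this choice the last (or only) block covers more lanes than there are columns, so the kernel's loads and stores need a mask such as `offsets < N`.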
src/flag_gems/ops/cov.py
Outdated
    BLOCK_SIZE = min(256, N)
    num_blocks = (N + BLOCK_SIZE - 1) // BLOCK_SIZE

    grid = lambda meta: (M, num_blocks)
The grid.y limit is 65535 on CUDA, so the launch will fail with a kernel parameter error if the row size is larger than 256 * 65535, roughly 16M elements. This probably needs a GSL-style kernel or split kernels.
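Assuming "gsl" here refers to a grid-stride loop, the idea is to launch a fixed, bounded number of programs and let each one walk over several blocks, so the grid size no longer grows with the row length. The following is a plain-Python simulation of that indexing for a single row sum, not Triton code; every name in it is hypothetical.

```python
MAX_GRID_Y = 65535   # CUDA limit on grid.y
BLOCK_SIZE = 256     # block size used in the snippet above

# Largest row length that a one-block-per-program launch along grid.y can cover.
max_row_len = BLOCK_SIZE * MAX_GRID_Y


def grid_stride_row_sum(row, block_size, num_programs):
    """Simulate num_programs programs that together sum one row of any length."""
    n = len(row)
    num_blocks = (n + block_size - 1) // block_size
    partial = [0.0] * num_programs
    for pid in range(num_programs):  # each iteration stands in for one program
        # Program pid handles blocks pid, pid + num_programs, pid + 2 * num_programs, ...
        for block in range(pid, num_blocks, num_programs):
            start = block * block_size
            stop = min(start + block_size, n)  # plays the role of the bounds mask
            partial[pid] += sum(row[start:stop])
    return sum(partial)  # stands in for the final atomic add or second reduction


row = [float(i) for i in range(1000)]
total = grid_stride_row_sum(row, block_size=BLOCK_SIZE, num_programs=2)
```

Here two programs cover four blocks, and the same two programs would cover any number of blocks, which is what keeps the launch within the grid limits.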
src/flag_gems/ops/cov.py
Outdated
    mean_kernel[grid](X, mean, M, N, weights, BLOCK_SIZE=BLOCK_SIZE)
    mean = mean / total_weight

    grid_cov = lambda meta: (M, M, num_blocks)
Now M sits on grid.y as well, so it is subject to a maximum of 65535, which could be an issue.
src/flag_gems/ops/cov.py
Outdated
    tl.atomic_add(cov_matrix + row * M + col, cov)


def cov(X, correction=1, fweights=None, aweights=None):
    logging.debug("GEMS ")
"GEMS COV"
Oops, thanks for pointing that out.
Hello @RubiaCx, are you planning any further revisions?
Merge updates from upstream master into develop branch to keep it up-to-date.
This update commit addresses the issue where the COV calculation would fail if [...] However, I encountered an "illegal memory access was encountered" error when [...]
You should probably try a GSL-style kernel, both to reduce the number of kernel calls and to contain the CTA count.
Could you please resolve the conflicts so that we can merge this PR?
    for i in range(num_row_chunks):
        row_offset = i * MAX_GRID_NUM
        current_M = min(MAX_GRID_NUM, M - row_offset)
        grid = (current_M,)
        mean_kernel[grid](X, mean, M, N, weights, row_offset=row_offset, BLOCK_SIZE=BLOCK_SIZE)
    mean = mean / sum_weights

    for i in range(num_row_chunks):
        row_offset = i * MAX_GRID_NUM
        current_rows = min(MAX_GRID_NUM, M - row_offset)
        for j in range(num_row_chunks):
            col_offset = j * MAX_GRID_NUM
            current_cols = min(MAX_GRID_NUM, M - col_offset)
            grid = (current_rows, current_cols)
            covariance_kernel[grid](X, cov_matrix, mean, M, N, weights, row_offset=row_offset, col_offset=col_offset, BLOCK_SIZE=BLOCK_SIZE)
This is not the GSL-style kernel I mentioned previously. Multiple kernel invocations should be avoided as much as possible.
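To make the cost concrete: the chunked version above issues one `mean_kernel` launch per row chunk and one `covariance_kernel` launch per pair of chunks. The sketch below counts those launches, assuming `MAX_GRID_NUM` is 65535 and `num_row_chunks` is the ceiling of `M / MAX_GRID_NUM`; neither definition is shown in the snippet, so both are assumptions, and `count_launches` is a hypothetical helper.

```python
def count_launches(M: int, max_grid_num: int = 65535):
    """Return (num_row_chunks, total kernel launches) for the chunked loops."""
    num_row_chunks = (M + max_grid_num - 1) // max_grid_num  # ceiling division
    mean_launches = num_row_chunks            # first loop: one launch per chunk
    cov_launches = num_row_chunks ** 2        # nested loops: one launch per chunk pair
    return num_row_chunks, mean_launches + cov_launches


small = count_launches(1024)    # fits in a single chunk
large = count_launches(70000)   # just over the limit, so two chunks
```

The covariance launches grow quadratically with the number of chunks, whereas a kernel that loops inside each program keeps the launch count constant regardless of `M`.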
Oh, I found #91 and will refactor the cov function accordingly to adopt the GSL-style kernel as suggested. Thanks for pointing this out!
PR Category
Operator
Type of Change
New Feature
Description
Add cov op
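For readers checking results by hand, here is a small pure-Python reference for the weighted covariance, assuming the op is meant to match `torch.cov`, which uses the same weighting scheme as `numpy.cov`. It is a sketch for tiny inputs only, not the PR's Triton implementation; `reference_cov` is a hypothetical name, and unlike `torch.cov` it does not clamp a non-positive normalization factor.

```python
def reference_cov(X, correction=1, fweights=None, aweights=None):
    """Weighted covariance of the rows of X (each row is a variable, each column an observation)."""
    n = len(X[0])
    f = fweights if fweights is not None else [1] * n
    a = aweights if aweights is not None else [1.0] * n
    w = [fi * ai for fi, ai in zip(f, a)]  # combined per-observation weight
    w_sum = sum(w)
    # Normalization: sum(w) - correction * sum(w * a) / sum(w).
    # With no aweights this reduces to sum(w) - correction.
    norm = w_sum - correction * sum(wi * ai for wi, ai in zip(w, a)) / w_sum
    means = [sum(wi * xi for wi, xi in zip(w, row)) / w_sum for row in X]
    centered = [[xi - m for xi in row] for row, m in zip(X, means)]
    return [
        [sum(wi * ci * cj for wi, ci, cj in zip(w, r, c)) / norm for c in centered]
        for r in centered
    ]


unweighted = reference_cov([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]])
# An integer frequency weight of 2 behaves like repeating that observation twice.
weighted = reference_cov([[1.0, 2.0, 3.0]], fweights=[2, 1, 1])
repeated = reference_cov([[1.0, 1.0, 2.0, 3.0]])
```

The last two calls give the same value, which is a quick sanity check on the `fweights` handling.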
Issue
#256
Performance
Tested on NV-A100-80G