[Operator] Add cov op #276

Open
wants to merge 9 commits into
base: master

RubiaCx

@RubiaCx RubiaCx commented Nov 6, 2024

PR Category

Operator

Type of Change

New Feature

Description

Add cov op

Issue

#256

Performance

Tested on NV-A100-80G
image

@tongxin
Contributor

tongxin commented Nov 6, 2024

What's the performance metric?

@RubiaCx
Author

RubiaCx commented Nov 6, 2024

> What's the performance metric?

I've added a benchmark to test the accuracy and performance of the covariance operation compared to Torch’s implementation:

import torch
import triton
import triton.testing

# NOTE: `cov` is the Triton implementation added in this PR and must be imported separately.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(4, 16)],
        x_log=True,
        line_arg='provider',
        line_vals=['torch', 'triton'],
        line_names=['Torch', 'Triton'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='GB/s',
        plot_name='covariance-benchmark',   # Name for the plot.
        args={'M': 1024, 'correction': 1},  # Default values for parameters other than N.
    )
)
def benchmark_cov(M, N, correction, provider):
    X = torch.randn(M, N, device='cuda', dtype=torch.float32)
    fweights = torch.randint(1, 5, (N,), dtype=torch.int32, device=X.device)
    aweights = torch.rand(N, device='cuda') + 0.1  # Avoid zeros in weights

    # do_bench returns one timing per quantile: median, 20th and 80th percentile.
    quantiles = [0.5, 0.2, 0.8]

    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cov(X, correction=correction, fweights=fweights, aweights=aweights),
            quantiles=quantiles
        )
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: cov(X, correction=correction, fweights=fweights, aweights=aweights),
            quantiles=quantiles
        )

        # Accuracy check: compare the Triton result against Torch's implementation.
        result = cov(X, correction=correction, fweights=fweights, aweights=aweights).cpu()
        torch_result = torch.cov(X, correction=correction, fweights=fweights, aweights=aweights).cpu()
        precision_diff = torch.max(torch.abs(torch_result - result))
        print(f'The maximum difference between Torch and Triton is {precision_diff.item()}')

    # Bandwidth estimate: counts 3x the bytes of X per call.
    gbps = lambda ms: 3 * X.numel() * X.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark_cov.run(print_data=True, show_plots=True, save_path=".")

I’ve attached the results as a picture for reference. Let me know if additional details are needed!
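
For readers checking the accuracy comparison by hand, the following is a minimal pure-Python sketch of the weighted covariance that `torch.cov` documents: rows of `X` are variables, columns are observations, and the combined weight is `fweights * aweights`. The function name `reference_cov` is hypothetical and is not part of this PR; it is far too slow for benchmarking and is only meant to make the formula concrete.

```python
def reference_cov(X, correction=1, fweights=None, aweights=None):
    """Weighted covariance of the rows of X (variables) over its columns (observations)."""
    n_obs = len(X[0])
    f = fweights if fweights is not None else [1] * n_obs
    a = aweights if aweights is not None else [1.0] * n_obs
    w = [fi * ai for fi, ai in zip(f, a)]  # combined per-observation weight
    w_sum = sum(w)

    # Weighted mean of every variable (row).
    means = [sum(wi * x for wi, x in zip(w, row)) / w_sum for row in X]

    # Normalisation: sum(w) - correction * sum(w * a) / sum(w), clamped at zero.
    # A zero normaliser raises ZeroDivisionError here.
    norm = max(w_sum - correction * sum(wi * ai for wi, ai in zip(w, a)) / w_sum, 0.0)

    centered = [[x - m for x in row] for row, m in zip(X, means)]
    return [
        [sum(wi * p * q for wi, p, q in zip(w, ri, rj)) / norm for rj in centered]
        for ri in centered
    ]
```

With integer `fweights`, the result should match repeating each observation that many times, which gives a simple sanity check independent of any GPU code.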

def cov(X, correction=1, fweights=None, aweights=None):
    logging.debug("GEMS ")

    if not X.is_cuda:
Contributor

Gems sits behind the PyTorch dispatcher, so it doesn't need to handle non-CUDA inputs.

Author

@tongxin Thanks for the feedback! I’ve updated the code to remove it in the latest commit.

mean = torch.zeros(M, device=X.device, dtype=X.dtype)
cov_matrix = torch.zeros((M, M), device=X.device, dtype=X.dtype)

BLOCK_SIZE = min(256, N)
Contributor

This won't work if N < 256 and N is not a power of 2.
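
The constraint behind this comment is that `tl.arange` in Triton requires a power-of-two extent, so a `BLOCK_SIZE` such as 100 is rejected at compile time. A minimal sketch of one way to choose the block size follows; the helper names are hypothetical, and the bit trick mirrors what `triton.next_power_of_2` provides. It assumes the kernel masks out the lanes beyond `N`.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def pick_block_size(n: int, max_block: int = 256) -> int:
    """Round N up to a power of two, capped at max_block (itself a power of two)."""
    return min(max_block, next_power_of_2(n))


# Hypothetical usage: N = 100 gives a block of 128, so the kernel must mask
# the 28 lanes that fall outside the row.
block = pick_block_size(100)
```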

BLOCK_SIZE = min(256, N)
num_blocks = (N + BLOCK_SIZE - 1) // BLOCK_SIZE

grid = lambda meta: (M, num_blocks)
Contributor

The grid.y limit is 65535 on CUDA, so there will be a kernel launch parameter error if the row size is larger than 256 * 65535, roughly 16.8M elements. This probably needs a GSL-style kernel or split kernels.
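
The arithmetic in this comment can be checked with a short sketch. It assumes the launch grid `(M, num_blocks)` shown above, with `num_blocks` on the second grid axis; the helper names are illustrative only.

```python
MAX_GRID_Y = 65535  # CUDA limit on gridDim.y


def num_blocks_for(n: int, block_size: int) -> int:
    """Ceiling division, as in the PR's num_blocks computation."""
    return (n + block_size - 1) // block_size


def fits_grid_y(n: int, block_size: int) -> bool:
    """True if a row of length n can be covered without exceeding gridDim.y."""
    return num_blocks_for(n, block_size) <= MAX_GRID_Y


# Largest row length that still launches with BLOCK_SIZE = 256.
largest_n = 256 * MAX_GRID_Y
```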

mean_kernel[grid](X, mean, M, N, weights, BLOCK_SIZE=BLOCK_SIZE)
mean = mean / total_weight

grid_cov = lambda meta: (M, M, num_blocks)
Contributor

Now M is subject to a maximum of 65535, which could be an issue.
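
With the three-axis grid `(M, M, num_blocks)`, the second `M` lands on the y axis and `num_blocks` on the z axis, and both are capped at 65535 on CUDA, while the x axis allows up to 2^31 - 1. A small host-side check along these lines (the function name is hypothetical) makes it clear which axis would fail before the kernel is launched:

```python
CUDA_MAX_GRID = (2**31 - 1, 65535, 65535)  # limits for gridDim.x, .y, .z


def oversized_axes(grid):
    """Return the indices of grid axes that exceed the CUDA launch limits."""
    return [
        axis
        for axis, (size, limit) in enumerate(zip(grid, CUDA_MAX_GRID))
        if size > limit
    ]


# Hypothetical usage: M = 70000 is fine on axis 0 but too large on axis 1.
bad = oversized_axes((70000, 70000, 4))
```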

tl.atomic_add(cov_matrix + row * M + col, cov)

def cov(X, correction=1, fweights=None, aweights=None):
    logging.debug("GEMS ")
Contributor

"GEMS COV"

Author

Oops, thanks for pointing that out.

@tongxin tongxin self-assigned this Nov 10, 2024
Contributor

@tongxin tongxin left a comment

Hello @RubiaCx, are you planning further revisions?

@RubiaCx
Author

RubiaCx commented Nov 20, 2024

> Hello @RubiaCx, are you planning further revisions?

@tongxin Yes, I am in the process of fixing bugs, but I've been quite busy recently.

@RubiaCx
Author

RubiaCx commented Nov 20, 2024

This update commit addresses the issue where the COV calculation would fail if N < 256 and N is not a power of 2. The code now ensures that BLOCK_SIZE is a power of 2, which resolves this problem.

However, I encountered an "illegal memory access was encountered" error when M exceeds MAX_GRID_NUM.

@tongxin
Contributor

tongxin commented Nov 25, 2024

> This update commit addresses the issue where the COV calculation would fail if N < 256 and N is not a power of 2. The code now ensures that BLOCK_SIZE is a power of 2, which resolves this problem.
>
> However, I encountered an "illegal memory access was encountered" error when M exceeds MAX_GRID_NUM.

You should probably try a GSL-style kernel, both to reduce the number of kernel calls and to bound the number of CTAs.

@tongxin
Contributor

tongxin commented Dec 13, 2024

Could you please try to resolve the conflicts so that we can merge this PR?

Comment on lines +93 to +107
for i in range(num_row_chunks):
    row_offset = i * MAX_GRID_NUM
    current_M = min(MAX_GRID_NUM, M - row_offset)
    grid = (current_M,)
    mean_kernel[grid](X, mean, M, N, weights, row_offset=row_offset, BLOCK_SIZE=BLOCK_SIZE)
mean = mean / sum_weights

for i in range(num_row_chunks):
    row_offset = i * MAX_GRID_NUM
    current_rows = min(MAX_GRID_NUM, M - row_offset)
    for j in range(num_row_chunks):
        col_offset = j * MAX_GRID_NUM
        current_cols = min(MAX_GRID_NUM, M - col_offset)
        grid = (current_rows, current_cols)
        covariance_kernel[grid](X, cov_matrix, mean, M, N, weights, row_offset=row_offset, col_offset=col_offset, BLOCK_SIZE=BLOCK_SIZE)
Contributor

This is not the GSL-style kernel I mentioned previously. Multiple kernel invocations should be avoided as much as possible.

Author

Oh, I found #91 and will refactor the cov function accordingly to adopt the GSL-style kernel as suggested. Thanks for pointing this out!
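
To illustrate the distinction drawn in this thread, here is a plain-Python simulation of the pattern, assuming "GSL" refers to a grid-stride loop: a fixed pool of programs is launched once, and each program walks over the tiles with a stride equal to the pool size. This is only a sketch of the idea; it is not the code from #91 and not FlagGems API, and the function name and parameters are hypothetical.

```python
def row_sums_grid_stride(X, num_programs, block_size):
    """Simulate a grid-stride loop: a fixed pool of programs covers all tiles."""
    n_rows, n_cols = len(X), len(X[0])
    tiles_per_row = (n_cols + block_size - 1) // block_size
    total_tiles = n_rows * tiles_per_row
    sums = [0.0] * n_rows

    for pid in range(num_programs):  # on a GPU these programs run in parallel
        # Each program starts at its own id and strides by the grid size,
        # so the launch size no longer depends on M or N.
        for tile in range(pid, total_tiles, num_programs):
            row, tile_in_row = divmod(tile, tiles_per_row)
            start = tile_in_row * block_size
            stop = min(start + block_size, n_cols)  # plays the role of the mask
            sums[row] += sum(X[row][start:stop])    # an atomic add on a GPU
    return sums
```

In contrast to the chunked host-side loops above, the grid size here is chosen independently of the problem size, so a single launch covers any `M` and `N` while staying inside the grid limits.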
