
Conversation

Blury233
Contributor

PR Category

Operator

Type of Change

New Feature

Description

Compared with various reduction strategies, the single-thread-block implementation shows better performance for this operator.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Accuracy Test

[screenshot: accuracy test results]

Performance Test

[screenshots: performance benchmark results (two images)]

@gemini-code-assist
Contributor

Summary of Changes

Hello @Blury233, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new trace operator to the flag_gems library, designed to efficiently compute the sum of the diagonal elements of a 2D tensor. The implementation leverages a custom Triton kernel, aiming for improved performance compared to standard reduction methods, as indicated by the PR description. The changes include the core Triton kernel, its Python wrapper, integration into the flag_gems operator registry, and comprehensive accuracy and performance tests to ensure correctness and evaluate efficiency across diverse input scenarios.

Highlights

  • New trace Operator: Introduced a new trace operator, implemented with a custom Triton kernel for potentially optimized performance.
  • Triton Kernel Implementation: A trace_kernel using the Triton language (tl) was added to efficiently calculate the sum of diagonal elements for 2D tensors, handling various data types and block sizes (a minimal sketch follows this list).
  • Comprehensive Testing: New accuracy tests (test_accuracy_trace) cover various tensor shapes and data types (float, int, bool), including edge cases like empty diagonals, ensuring correctness against torch.trace.
  • Performance Benchmarking: A performance benchmark (test_perf_trace) has been added to evaluate the efficiency of the new trace operator using GenericBenchmark2DOnly.
  • Framework Integration: The trace operator has been integrated into the flag_gems framework, making it available for use by importing it and adding it to the enable function.
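
For reviewers who want a concrete picture, here is a minimal sketch of a single-block trace kernel in this style. All names and signatures are illustrative, not the PR's actual code, and the float32 accumulator is a simplification (a production kernel would choose the accumulator per dtype):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def trace_kernel_sketch(inp_ptr, out_ptr, num_diag, diag_stride, BLOCK_SIZE: tl.constexpr):
    # A single program (grid = (1,)) walks the diagonal in BLOCK_SIZE chunks.
    # Element (i, i) lives at offset i * (stride(0) + stride(1)).
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for start in range(0, num_diag, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_diag
        vals = tl.load(inp_ptr + offsets * diag_stride, mask=mask, other=0.0)
        acc += vals.to(tl.float32)
    tl.store(out_ptr, tl.sum(acc, axis=0))


def trace_sketch(x: torch.Tensor) -> torch.Tensor:
    assert x.dim() == 2, "trace expects a 2D tensor"
    num_diag = min(x.shape)
    out = torch.zeros((), dtype=x.dtype, device=x.device)
    if num_diag == 0:
        return out  # empty diagonal: trace is 0
    block_size = min(1024, triton.next_power_of_2(num_diag))
    trace_kernel_sketch[(1,)](x, out, num_diag, x.stride(0) + x.stride(1), BLOCK_SIZE=block_size)
    return out
```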

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a new trace operator implemented with Triton for improved performance. The implementation is clean and follows the project's structure. The accompanying accuracy and performance tests are comprehensive, covering various shapes and data types. I've identified a couple of minor opportunities for improvement: one for code simplification by removing a redundant check, and another for enhancing test coverage. Overall, this is a great addition.

Comment on lines +82 to +83
if BLOCK_SIZE == 0:
    BLOCK_SIZE = 1
Contributor


Severity: medium

This check for BLOCK_SIZE == 0 is redundant and can be removed. The num_diag value is checked for 0 on line 66, and the function returns early if it is. Therefore, num_diag is guaranteed to be at least 1 at this point. Since triton.next_power_of_2(x) returns a positive integer for any positive x, BLOCK_SIZE will never be 0 here.
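
For illustration, the surviving logic can be read as the following helper (`pick_block_size` is a hypothetical name, not part of the diff):

```python
import triton


def pick_block_size(num_diag: int) -> int:
    # The caller early-returns when num_diag == 0, so num_diag >= 1 here, and
    # triton.next_power_of_2 maps any positive integer to a positive power of
    # two, so the `if BLOCK_SIZE == 0: BLOCK_SIZE = 1` guard can never fire.
    BLOCK_SIZE = 1024
    if num_diag < BLOCK_SIZE:
        BLOCK_SIZE = triton.next_power_of_2(num_diag)
    return BLOCK_SIZE
```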

if dtype in FLOAT_DTYPES:
    gems_assert_close(res_out, ref_out, dtype)
else:
    gems_assert_equal(res_out, ref_out)
Contributor


Severity: medium

The test coverage is good. To make it even more robust, consider adding test cases for non-contiguous tensors, such as a transposed matrix. This would ensure that the implementation correctly handles different memory layouts via strides. You could achieve this by creating a transposed tensor within the existing test_accuracy_trace function and running the same assertions on it.
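
A sketch of such a case, assuming the `to_reference` and `gems_assert_close` helpers already used elsewhere in this test file:

```python
import torch
import flag_gems


def check_trace_transposed(shape, dtype):
    # A transposed view is non-contiguous, so this exercises stride handling.
    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device).t()
    assert not inp.is_contiguous()

    ref_out = torch.trace(to_reference(inp))
    with flag_gems.use_gems():
        res_out = torch.trace(inp)
    gems_assert_close(res_out, ref_out, dtype)


check_trace_transposed((128, 256), torch.float32)
```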

Collaborator

@iclementine left a comment


LGTM

grid = (1,)
BLOCK_SIZE = 1024
if num_diag < BLOCK_SIZE:
    BLOCK_SIZE = triton.next_power_of_2(num_diag)
Collaborator


Consider partitioning the task of summing the diagonal elements across several blocks when the number of elements to sum is large.
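
A sketch of this suggestion with a two-pass scheme (illustrative names; a single-pass variant could combine partials with tl.atomic_add instead):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def trace_partial_kernel(inp_ptr, partial_ptr, num_diag, diag_stride, BLOCK_SIZE: tl.constexpr):
    # Each program reduces one BLOCK_SIZE-wide slice of the diagonal into a
    # partial sum; the host combines the partials afterwards.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_diag
    vals = tl.load(inp_ptr + offsets * diag_stride, mask=mask, other=0.0)
    tl.store(partial_ptr + pid, tl.sum(vals.to(tl.float32), axis=0))


def trace_multi_block(x: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    num_diag = min(x.shape)
    if num_diag == 0:
        return torch.zeros((), dtype=x.dtype, device=x.device)
    num_blocks = triton.cdiv(num_diag, block_size)
    partials = torch.empty(num_blocks, dtype=torch.float32, device=x.device)
    trace_partial_kernel[(num_blocks,)](
        x, partials, num_diag, x.stride(0) + x.stride(1), BLOCK_SIZE=block_size
    )
    return partials.sum().to(x.dtype)
```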

Collaborator


You can treat it like a normal reduction on a vector whose stride is stride_dim0 + stride_dim1.
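
The observation can be checked in plain PyTorch (illustration only, not code from the PR):

```python
import torch

x = torch.randn(128, 256)
num_diag = min(x.shape)
# Element (i, i) sits at offset i * (stride(0) + stride(1)), so the diagonal
# is an ordinary strided 1D view and trace is just a vector reduction over it.
diag = torch.as_strided(x, (num_diag,), (x.stride(0) + x.stride(1),))
assert torch.equal(diag, torch.diagonal(x))
assert torch.allclose(diag.sum(), torch.trace(x))
```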

@iclementine
Collaborator

Since torch.trace's CPU implementation does not support bool, the test fails on op-test-quick-cpu. Please remove torch.bool from the trace tests, or only test it when the reference is not on CPU.

______________________ test_accuracy_trace[dtype3-shape6] ______________________

shape = (128, 256), dtype = torch.bool

    @pytest.mark.trace
    @pytest.mark.parametrize("shape", TRACE_SHAPES)
    @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + [torch.bool])
    def test_accuracy_trace(shape, dtype):
        if dtype == torch.bool:
            inp = torch.randint(0, 2, size=shape, device=flag_gems.device).to(dtype)
        elif dtype in INT_DTYPES:
            inp = torch.randint(-100, 100, size=shape, device=flag_gems.device).to(dtype)
        else:
            inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    
        ref_inp = to_reference(inp)
>       ref_out = torch.trace(ref_inp)
E       RuntimeError: "trace" not implemented for 'Bool'
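
One way to implement the suggested guard, reusing the parametrization from the failing test above (`REF_IS_CPU` is a placeholder for however the suite detects that references run on CPU):

```python
import pytest
import torch


@pytest.mark.trace
@pytest.mark.parametrize("shape", TRACE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + [torch.bool])
def test_accuracy_trace(shape, dtype):
    if dtype == torch.bool and REF_IS_CPU:
        pytest.skip("torch.trace is not implemented for Bool on CPU")
    # ... rest of the test body unchanged ...
```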

