-
Notifications
You must be signed in to change notification settings - Fork 144
Add bitwise_shift_left and bitwise_shift_right #990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Summary of ChangesHello @Blury233, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces bitwise_shift_left
and bitwise_shift_right
operators, complete with Triton kernels, registration, and both accuracy and performance tests. The implementation leverages pointwise_dynamic
for handling broadcasting, which is a good approach. The overall structure is clean and follows the existing patterns in the codebase. My review comments focus on improving the comprehensiveness of the tests by using a wider, more realistic range of shift values, and on maintaining consistency in logging across the new operator files.
shift_amount = torch.randint( | ||
0, 8, shape, dtype=cur_dtype, device=self.device | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The range of shift amounts [0, 8)
is a bit narrow for integer types larger than 8 bits. To ensure the performance benchmark covers a more realistic range of inputs, consider using a range for the shift amount that is dependent on the cur_dtype
, up to the number of bits in the type.
shift_amount = torch.randint( | |
0, 8, shape, dtype=cur_dtype, device=self.device | |
) | |
shift_amount = torch.randint( | |
0, torch.iinfo(cur_dtype).bits, shape, dtype=cur_dtype, device=self.device | |
) |
def bitwise_right_shift(self, other, *, out=None): | ||
return bitwise_right_shift_kernel(self, other, out=out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency with bitwise_left_shift
and to aid in debugging, it would be good to add a debug log message here.
def bitwise_right_shift(self, other, *, out=None): | |
return bitwise_right_shift_kernel(self, other, out=out) | |
def bitwise_right_shift(self, other, *, out=None): | |
logger.debug("GEMS BITWISE_RIGHT_SHIFT") | |
return bitwise_right_shift_kernel(self, other, out=out) |
def test_accuracy_bitwise_left_shift(shapes, dtype): | ||
shape_a, shape_b = shapes | ||
res_a = torch.randint(0, 100, shape_a, dtype=dtype, device=flag_gems.device) | ||
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The range of shift amounts [0, 8)
is a bit narrow for integer types larger than 8 bits. To improve test coverage, consider using a range that is dependent on the dtype
, up to the number of bits in the type. This will ensure that shifts by larger amounts are also tested.
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) | |
res_b = torch.randint(0, torch.iinfo(dtype).bits, shape_b, dtype=dtype, device=flag_gems.device) |
def test_accuracy_bitwise_right_shift(shapes, dtype): | ||
shape_a, shape_b = shapes | ||
res_a = torch.randint(0, 100, shape_a, dtype=dtype, device=flag_gems.device) | ||
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The range of shift amounts [0, 8)
is a bit narrow for integer types larger than 8 bits. To improve test coverage, consider using a range that is dependent on the dtype
, up to the number of bits in the type. This will ensure that shifts by larger amounts are also tested.
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) | |
res_b = torch.randint(0, torch.iinfo(dtype).bits, shape_b, dtype=dtype, device=flag_gems.device) |
def test_accuracy_bitwise_left_shift_(shapes, dtype): | ||
shape_a, shape_b = shapes | ||
res_a = torch.randint(0, 100, shape_a, dtype=dtype, device=flag_gems.device) | ||
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The range of shift amounts [0, 8)
is a bit narrow for integer types larger than 8 bits. To improve test coverage, consider using a range that is dependent on the dtype
, up to the number of bits in the type. This will ensure that shifts by larger amounts are also tested for the in-place operation.
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) | |
res_b = torch.randint(0, torch.iinfo(dtype).bits, shape_b, dtype=dtype, device=flag_gems.device) |
def test_accuracy_bitwise_right_shift_(shapes, dtype): | ||
shape_a, shape_b = shapes | ||
res_a = torch.randint(0, 100, shape_a, dtype=dtype, device=flag_gems.device) | ||
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The range of shift amounts [0, 8)
is a bit narrow for integer types larger than 8 bits. To improve test coverage, consider using a range that is dependent on the dtype
, up to the number of bits in the type. This will ensure that shifts by larger amounts are also tested for the in-place operation.
res_b = torch.randint(0, 8, shape_b, dtype=dtype, device=flag_gems.device) | |
res_b = torch.randint(0, torch.iinfo(dtype).bits, shape_b, dtype=dtype, device=flag_gems.device) |
PR Category
Operator
Type of Change
New Feature
Description
This PR adds:
Issue
Progress
Performance
Accuracy Test



Performance Test