Add kernel registration for 8bit and 32bit optimizers #1706
@matthewdouglas The PR is ready for review.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
# allow up to 15 errors for Lion
assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=15)
I'm OK with this, although the tests passed as they were for me on T4, RTX 4090 and L40S.
@matthewdouglas I updated the schema to describe all mutable arguments and default arguments. It is ready for a second round of review. The optimizer tests still pass; I checked on torch==2.5.
LGTM, thanks for taking care of this!
Merged commit 941681d into bitsandbytes-foundation:main
Add new kernels: `bitsandbytes::optimizer_update_8bit_blockwise` for 8-bit optimizers and `bitsandbytes::optimizer_update_32bit` for 32-bit optimizers. I'm not 100% sure about using a dictionary in the CUDA `ops.py`. I could replace it with a lot of branching (5 optimizers * 3 dtypes) if that's preferable.
This kernel registration is necessary for #1692.
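To make the dictionary-versus-branching trade-off concrete, here is a minimal sketch of a lookup table keyed by (optimizer, dtype). It is not the code from this PR: the optimizer names, the `make_kernel` factory, and the `dispatch` function are hypothetical stand-ins, and the "kernels" are plain Python callables rather than compiled CUDA functions.

```python
from itertools import product
from typing import Callable, Dict, List, Tuple

# Hypothetical names for illustration only; the PR mentions 5 optimizers and 3 dtypes.
OPTIMIZERS = ("adam", "momentum", "rmsprop", "lion", "adagrad")
DTYPES = ("float32", "float16", "bfloat16")

Kernel = Callable[[List[float]], str]


def make_kernel(optimizer: str, dtype: str) -> Kernel:
    """Build a stand-in for a compiled kernel specialised on (optimizer, dtype)."""

    def kernel(params: List[float]) -> str:
        return f"{optimizer}:{dtype} updated {len(params)} values"

    return kernel


# One table entry per combination replaces 5 * 3 = 15 if/elif branches.
KERNELS: Dict[Tuple[str, str], Kernel] = {
    (opt, dt): make_kernel(opt, dt) for opt, dt in product(OPTIMIZERS, DTYPES)
}


def dispatch(optimizer: str, dtype: str, params: List[float]) -> str:
    """Look up the kernel for this combination and call it."""
    try:
        kernel = KERNELS[(optimizer, dtype)]
    except KeyError:
        raise ValueError(
            f"unsupported combination: optimizer={optimizer!r}, dtype={dtype!r}"
        ) from None
    return kernel(params)
```

With a table like this, supporting a new combination means adding one entry, and an unsupported combination fails in a single place with a clear error.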
I also added an option to skip the CPU in the list of available devices, which makes it possible to test GPU-only kernels properly and simplifies their tests.
I ran the tests on CUDA; they passed locally on an RTX 4070, except for one optimizer.
The only failure was the 32-bit test for the Lion optimizer. I retested on the main branch and got the same error, and this commit doesn't change the implementation, so I increased the tolerance in the test. The issue may only appear on some client devices, such as the RTX 4070. With this PR, all tests pass.
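For readers unfamiliar with count-based tolerances, the sketch below shows the general idea behind a `max_error_count` argument: the comparison passes as long as no more than a fixed number of elements fall outside `atol + rtol * |expected|`. This is a pure-Python illustration with a hypothetical name, `assert_mostly_close`, operating on lists; it is not the `assert_most_approx_close` helper from the test suite, which works on tensors.

```python
from typing import Sequence


def assert_mostly_close(
    actual: Sequence[float],
    expected: Sequence[float],
    rtol: float = 1e-3,
    atol: float = 1e-3,
    max_error_count: int = 0,
) -> int:
    """Raise AssertionError if more than max_error_count elements are not close.

    An element counts as an error when |a - e| > atol + rtol * |e|.
    Returns the number of out-of-tolerance elements when the check passes.
    """
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")

    errors = sum(
        1 for a, e in zip(actual, expected) if abs(a - e) > atol + rtol * abs(e)
    )
    if errors > max_error_count:
        raise AssertionError(
            f"{errors} values not close (allowed: {max_error_count})"
        )
    return errors
```

Raising `max_error_count` tolerates a handful of outliers without loosening `atol` or `rtol` for every element.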
Traceback: