
Add kernel registration for 8bit and 32bit optimizers #1706


Merged
merged 11 commits into bitsandbytes-foundation:main from egor/8bit_int on Jul 14, 2025

Conversation

@Egor-Krivov (Contributor) commented Jul 11, 2025

Adds the new kernels bitsandbytes::optimizer_update_8bit_blockwise for 8-bit optimizers and bitsandbytes::optimizer_update_32bit for 32-bit optimizers. I'm not 100% sure about using a dictionary in the CUDA ops.py; I could replace it with a lot of branching (5 optimizers × 3 dtypes) if that's preferable.
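
For context, this is roughly the trade-off being described: a nested dict keyed by optimizer name and dtype versus 15 explicit branches. A minimal sketch, with placeholder kernel callables (this is not the actual ops.py code):

import torch

# Hypothetical dispatch table: optimizer name -> dtype -> native kernel.
# The kernel functions here are stand-ins for the real C/CUDA entry points.
def _adam_8bit_blockwise_fp32(*args): ...
def _adam_8bit_blockwise_fp16(*args): ...
def _adam_8bit_blockwise_bf16(*args): ...

_KERNELS_8BIT_BLOCKWISE = {
    "adam": {
        torch.float32: _adam_8bit_blockwise_fp32,
        torch.float16: _adam_8bit_blockwise_fp16,
        torch.bfloat16: _adam_8bit_blockwise_bf16,
    },
    # momentum, rmsprop, lion, adagrad would follow the same pattern
}

def _get_kernel(optimizer_name: str, dtype: torch.dtype):
    try:
        return _KERNELS_8BIT_BLOCKWISE[optimizer_name][dtype]
    except KeyError:
        raise ValueError(
            f"no 8-bit blockwise kernel for optimizer={optimizer_name!r}, dtype={dtype}"
        ) from None

A table like this keeps the op body flat and turns an unsupported (optimizer, dtype) pair into one explicit error, which is arguably easier to maintain than the equivalent branching.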

This kernel registration is necessary for #1692

I also added an option to skip the CPU in the list of available devices, to test properly and to simplify testing of GPU-only kernels.
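
A minimal sketch of what that option could look like, assuming a get_available_devices helper of roughly this shape in the test utilities (the real implementation may enumerate more backends):

import torch

def get_available_devices(no_cpu: bool = False) -> list[str]:
    # Skip "cpu" when the caller only wants devices that can run GPU-only kernels.
    devices = [] if no_cpu else ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    # Other accelerators (e.g. torch.xpu) would be appended the same way.
    return devices

GPU-only tests can then parametrize over get_available_devices(no_cpu=True), as seen in the traceback below.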

I ran the tests on CUDA; they passed locally on an RTX 4070, except for one optimizer.

The only issue was the 32-bit test for the Lion optimizer. I retested on the main branch and got the same error, and this commit doesn't change the implementation. So I increased the tolerance in the test; maybe the issue only appears on some consumer devices, like the RTX 4070. With this PR, all tests pass.

Traceback:

=================================== FAILURES ======================================================================================
_________________________________________________________ test_optimizer32bit[device=cuda-dim2=4097-dim1=1024-fp32-opt=lion] _________________________________________________________

dim1 = 1024, dim2 = 4097, gtype = torch.float32, optim_name = 'lion', device = 'cuda'

    @pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
    @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
    @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
    @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
    @pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
    def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
        if optim_name.startswith("paged_") and sys.platform == "win32":
            pytest.skip("Paged optimizers can have issues on Windows.")
    
        if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
            pytest.skip()
        if dim1 == 1 and dim2 == 1:
            return
        p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
        p2 = p1.clone()
        p1 = p1.float()
    
        torch_optimizer = str2optimizers[optim_name][0]([p1])
        bnb_optimizer = str2optimizers[optim_name][1]([p2])
    
        if gtype == torch.float32:
            atol, rtol = 1e-6, 1e-5
        elif gtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-2
        else:
            atol, rtol = 1e-4, 1e-3
    
        for i in range(k):
            g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
            p1.grad = g.clone().float()
            p2.grad = g.clone()
    
            bnb_optimizer.step()
            torch_optimizer.step()
    
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2].to(device),
                    atol=atol,
                    rtol=rtol,
                )
    
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
>           assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)

tests/test_optim.py:213: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

a = tensor([[-0.0921, -0.0439, -0.2642,  ..., -0.1074,  0.1601,  0.2297],
        [ 0.0158,  0.2077, -0.1181,  ..., -0.010...1670, -0.1418,  0.0079],
        [ 0.0991,  0.0394, -0.1615,  ...,  0.0141, -0.1775, -0.0713]],
       device='cuda:0')
b = tensor([[-0.0921, -0.0439, -0.2642,  ..., -0.1074,  0.1601,  0.2297],
        [ 0.0158,  0.2077, -0.1181,  ..., -0.010...1670, -0.1418,  0.0079],
        [ 0.0991,  0.0394, -0.1615,  ...,  0.0141, -0.1775, -0.0713]],
       device='cuda:0')
rtol = 1e-05, atol = 1e-06, max_error_count = 10

    def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
        idx = torch.isclose(a, b, rtol=rtol, atol=atol)
        error_count = (idx == 0).sum().item()
        if error_count > max_error_count:
            print(f"Too many values not close: assert {error_count} < {max_error_count}")
>           torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 11 / 4195328 (0.0%)
E           Greatest absolute difference: 0.00020000338554382324 at index (299, 340) (up to 1e-06 allowed)
E           Greatest relative difference: 0.06716564297676086 at index (601, 1638) (up to 1e-05 allowed)

tests/test_optim.py:27: AssertionError
-------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------
Too many values not close: assert 11 < 10
______________________________________________________ test_optimizer32bit[device=cuda-dim2=4097-dim1=1024-fp32-opt=paged_lion] ______________________________________________________

dim1 = 1024, dim2 = 4097, gtype = torch.float32, optim_name = 'paged_lion', device = 'cuda'

    @pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
    @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
    @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
    @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
    @pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
    def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
        if optim_name.startswith("paged_") and sys.platform == "win32":
            pytest.skip("Paged optimizers can have issues on Windows.")
    
        if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
            pytest.skip()
        if dim1 == 1 and dim2 == 1:
            return
        p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
        p2 = p1.clone()
        p1 = p1.float()
    
        torch_optimizer = str2optimizers[optim_name][0]([p1])
        bnb_optimizer = str2optimizers[optim_name][1]([p2])
    
        if gtype == torch.float32:
            atol, rtol = 1e-6, 1e-5
        elif gtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-2
        else:
            atol, rtol = 1e-4, 1e-3
    
        for i in range(k):
            g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
            p1.grad = g.clone().float()
            p2.grad = g.clone()
    
            bnb_optimizer.step()
            torch_optimizer.step()
    
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2].to(device),
                    atol=atol,
                    rtol=rtol,
                )
    
            # since Lion can have pretty noisy updates where things lie at the boundary
            # allow up to 10 errors for Lion
>           assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)

tests/test_optim.py:213: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

a = tensor([[-0.0921, -0.0439, -0.2642,  ..., -0.1074,  0.1601,  0.2297],
        [ 0.0158,  0.2077, -0.1181,  ..., -0.010...1670, -0.1418,  0.0079],
        [ 0.0991,  0.0394, -0.1615,  ...,  0.0141, -0.1775, -0.0713]],
       device='cuda:0')
b = tensor([[-0.0921, -0.0439, -0.2642,  ..., -0.1074,  0.1601,  0.2297],
        [ 0.0158,  0.2077, -0.1181,  ..., -0.010...1670, -0.1418,  0.0079],
        [ 0.0991,  0.0394, -0.1615,  ...,  0.0141, -0.1775, -0.0713]],
       device='cuda:0')
rtol = 1e-05, atol = 1e-06, max_error_count = 10

    def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
        idx = torch.isclose(a, b, rtol=rtol, atol=atol)
        error_count = (idx == 0).sum().item()
        if error_count > max_error_count:
            print(f"Too many values not close: assert {error_count} < {max_error_count}")
>           torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 11 / 4195328 (0.0%)
E           Greatest absolute difference: 0.00020000338554382324 at index (299, 340) (up to 1e-06 allowed)
E           Greatest relative difference: 0.06716564297676086 at index (601, 1638) (up to 1e-05 allowed)

tests/test_optim.py:27: AssertionError
-------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------
Too many values not close: assert 11 < 10
================================================================================== warnings summary ==================================================================================
tests/test_optim.py: 124 warnings
  /workspace/bitsandbytes/tests/test_optim.py:221: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))

tests/test_optim.py: 54 warnings
  /workspace/bitsandbytes/tests/test_optim.py:401: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================= PASSES =======================================================================================
================================================== 2 failed, 182 passed, 8 skipped, 33 deselected, 178 warnings in 85.31s (0:01:25) ========

@Egor-Krivov changed the title from "Add kernel registration for 8bit optimizers" to "Add kernel registration for 8bit and 32bit optimizers" on Jul 14, 2025
@Egor-Krivov (Contributor, Author) commented:

@matthewdouglas The PR is ready for review


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@matthewdouglas added the Optimizers label (Issues or feature requests relating to optimizers) on Jul 14, 2025
@matthewdouglas linked an issue on Jul 14, 2025 that may be closed by this pull request
@matthewdouglas added this to the v0.47.0 milestone on Jul 14, 2025
Comment on lines +212 to +213
# allow up to 15 errors for Lion
assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=15)
Member

I'm OK with this, although the tests passed as they were for me on T4, RTX 4090 and L40S.

@Egor-Krivov (Contributor, Author) commented:

@matthewdouglas I updated the schema to describe all mutable arguments and default arguments. Ready for the second round of review. The optimizer tests still pass; I checked on torch==2.5.
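
For readers unfamiliar with the schema syntax: torch.library marks in-place-mutated tensors with alias annotations like Tensor(a!) and encodes default values inline. A minimal sketch under a throwaway namespace; the argument list here is an assumption, not the schema this PR actually registers:

import torch

# Sketch only: a toy 32-bit optimizer-update schema. "(a!)" / "(b!)" mark
# tensors the op mutates in place; "=..." gives schema-level defaults.
# A throwaway namespace is used so this doesn't collide with the real
# bitsandbytes::optimizer_update_32bit op if the library is installed.
torch.library.define(
    "sketch::optimizer_update_32bit",
    "(str optimizer_name, Tensor g, Tensor(a!) p, Tensor(b!) state1, "
    "Tensor(c!)? state2, float beta1, float beta2, float eps, int step, "
    "float lr, float weight_decay=0.0, float gnorm_scale=1.0, "
    "bool skip_zeros=False) -> ()",
)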

@matthewdouglas (Member) commented:

LGTM, thanks for taking care of this!

@matthewdouglas matthewdouglas merged commit 941681d into bitsandbytes-foundation:main Jul 14, 2025
39 checks passed
@Egor-Krivov Egor-Krivov deleted the egor/8bit_int branch July 14, 2025 18:15
Labels
Optimizers (Issues or feature requests relating to optimizers)
Development

Successfully merging this pull request may close these issues.

Optimizer custom ops
2 participants