
Commit 941681d

Merge pull request #1706 from Egor-Krivov/egor/8bit_int
Add kernel registration for 8bit and 32bit optimizers
2 parents: adc7fda + 0f6fe6b

File tree: 7 files changed, +417 -172 lines
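
The diffs below expose bitsandbytes' fused optimizer updates as PyTorch custom ops via torch.library. For context, here is a minimal, self-contained sketch of the same define / fake / device-kernel pattern (PyTorch >= 2.4). `mylib::scale_` is a made-up op used only for illustration; it is not part of this PR:

    import torch

    # Define the op schema: "(a0!)" marks the first tensor as mutated in place.
    torch.library.define(
        "mylib::scale_",
        "(Tensor(a0!) x, float factor) -> ()",
    )

    # Fake (meta) implementation: metadata checks only, no real computation.
    @torch.library.register_fake("mylib::scale_")
    def _(x, factor):
        torch._check(x.dtype == torch.float32, lambda: f"expected float32, got {x.dtype}")

    # Device implementation, registered for CPU in this sketch.
    @torch.library.impl("mylib::scale_", "cpu")
    def _(x, factor):
        x.mul_(factor)

    t = torch.ones(4)
    torch.ops.mylib.scale_(t, 2.0)  # dispatches to the CPU impl; t is now all 2.0

The fake (meta) function is what lets torch.compile and meta-device tracing reason about the op without running a kernel; the device implementations are registered separately, as the two files below do for these optimizer ops.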

bitsandbytes/_ops.py

Lines changed: 104 additions & 0 deletions
@@ -348,3 +348,107 @@ def _(
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
     return torch.empty(shape, dtype=dtype, device=A.device)
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_32bit",
+    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_32bit")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_8bit_blockwise",
+    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+    torch._check(
+        state1.dtype == torch.uint8,
+        lambda: f"state1 must be uint8, got {state1.dtype}",
+    )
+    torch._check(
+        qmap1.dtype == absmax1.dtype == torch.float32,
+        lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    )
+    if state2 is not None:
+        torch._check(
+            state2.dtype == torch.uint8,
+            lambda: f"state2 must be uint8, got {state2.dtype}",
+        )
+        torch._check(
+            qmap2.dtype == absmax2.dtype == torch.float32,
+            lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+        )
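
Not part of the diff, but a quick way to see what the fake registration above buys: calling the new op on meta tensors runs only the shape/dtype checks, with no GPU and no kernel involved. The shapes and hyperparameters below are illustrative, and `import bitsandbytes` is assumed to trigger the registrations above:

    import torch
    import bitsandbytes  # assumed to register the ops defined above on import

    g = torch.zeros(1024, dtype=torch.float16, device="meta")
    p = torch.zeros(1024, dtype=torch.float16, device="meta")
    state1 = torch.zeros(1024, dtype=torch.float32, device="meta")
    state2 = torch.zeros(1024, dtype=torch.float32, device="meta")

    # Dispatches to the fake impl registered for this schema; validates
    # metadata only. Argument order follows the schema in _ops.py.
    torch.ops.bitsandbytes.optimizer_update_32bit(
        "adam", g, p, state1, state2, None,  # unorm_vec=None
        0.0, 0.0,         # max_unorm, param_norm
        0.9, 0.999, 0.0,  # beta1, beta2, beta3
        0.0,              # alpha
        1e-8,             # eps
        0.0,              # weight_decay
        1,                # step
        1e-3,             # lr
        1.0,              # gnorm_scale
    )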

bitsandbytes/backends/cuda/ops.py

Lines changed: 226 additions & 0 deletions
@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
         ct.c_int32(blocksize),
         stream,
     )
+
+
+"""C FUNCTIONS FOR OPTIMIZERS"""
+str2optimizer32bit = {
+    "adam": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum32bit_grad_32,
+        lib.cmomentum32bit_grad_16,
+    ),
+    "rmsprop": (
+        lib.crmsprop32bit_grad_32,
+        lib.crmsprop32bit_grad_16,
+    ),
+    "lion": (
+        lib.clion32bit_grad_fp32,
+        lib.clion32bit_grad_fp16,
+        lib.clion32bit_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad32bit_grad_32,
+        lib.cadagrad32bit_grad_16,
+    ),
+    "lamb": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix32bit_grad_fp32,
+        lib.cademamix32bit_grad_fp16,
+        lib.cademamix32bit_grad_bf16,
+    ),
+}
+
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_32bit_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    optim_fns = str2optimizer32bit.get(optimizer_name, None)
+    if optim_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
+        )
+    if g.dtype == torch.float32:
+        optim_func = optim_fns[0]
+    elif g.dtype == torch.float16:
+        optim_func = optim_fns[1]
+    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
+        optim_func = optim_fns[2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
+
+    with _cuda_device_of(g):
+        optim_func(
+            get_ptr(g),
+            get_ptr(p),
+            get_ptr(state1),
+            get_ptr(state2),
+            get_ptr(unorm_vec),
+            ct.c_float(max_unorm),
+            ct.c_float(param_norm),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_float(weight_decay),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    # Shape/dtype validation for these arguments is performed by the fake op
+    # registered for this schema in bitsandbytes/_ops.py, so it is not
+    # duplicated here.
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
+register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
