@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
             ct.c_int32(blocksize),
             stream,
         )
+
+
+"""C FUNCTIONS FOR OPTIMIZERS"""
+str2optimizer32bit = {
+    "adam": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum32bit_grad_32,
+        lib.cmomentum32bit_grad_16,
+    ),
+    "rmsprop": (
+        lib.crmsprop32bit_grad_32,
+        lib.crmsprop32bit_grad_16,
+    ),
+    "lion": (
+        lib.clion32bit_grad_fp32,
+        lib.clion32bit_grad_fp16,
+        lib.clion32bit_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad32bit_grad_32,
+        lib.cadagrad32bit_grad_16,
+    ),
+    "lamb": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix32bit_grad_fp32,
+        lib.cademamix32bit_grad_fp16,
+        lib.cademamix32bit_grad_bf16,
+    ),
+}
+
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_32bit_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    optim_fns = str2optimizer32bit.get(optimizer_name, None)
+    if optim_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
+        )
+    if g.dtype == torch.float32:
+        optim_func = optim_fns[0]
+    elif g.dtype == torch.float16:
+        optim_func = optim_fns[1]
+    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
+        optim_func = optim_fns[2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
+
+    with _cuda_device_of(g):
+        optim_func(
+            get_ptr(g),
+            get_ptr(p),
+            get_ptr(state1),
+            get_ptr(state2),
+            get_ptr(unorm_vec),
+            ct.c_float(max_unorm),
+            ct.c_float(param_norm),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_float(weight_decay),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    # torch._check(
+    #     g.numel() == p.numel(),
+    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    # )
+    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    # torch._check(
+    #     g.dtype in compute_dtypes,
+    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    # )
+    # torch._check(
+    #     g.dtype == p.dtype,
+    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    # )
+    # torch._check(
+    #     state1.dtype == torch.uint8,
+    #     lambda: f"state1 must be uint8, got {state1.dtype}",
+    # )
+    # torch._check(
+    #     qmap1.dtype == absmax1.dtype == torch.float32,
+    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    # )
+    # if state2 is not None:
+    #     torch._check(
+    #         state2.dtype == torch.uint8,
+    #         lambda: f"state2 must be uint8, got {state2.dtype}",
+    #     )
+    #     torch._check(
+    #         qmap2.dtype == absmax2.dtype == torch.float32,
+    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+    #     )
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
+register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)