@@ -89,13 +89,31 @@ def _preprocessing(self):
8989        g_idx_trivial  =  torch .tensor (
9090            g_idx_trivial , dtype = torch .int32 , device = self .g_idx .device 
9191        )
92-         assert  torch .equal (
93-             self .g_idx , g_idx_trivial 
94-         ), "Non-trivial tensor g_idx is not supported" 
92+         sort_zeros  =  not  (torch .equal (self .g_idx , g_idx_trivial ))
9593        self .qzeros  =  self .qzeros .cpu ()
9694        zeros  =  self .unpack_zeros_from_cuda_old_format ()
97-         new_qzeros  =  pack_tensor (zeros )
98-         self .qzeros  =  new_qzeros .to (orig_device )
95+         if  sort_zeros :
96+             zeros_group_1  =  torch .zeros (
97+                 (self .infeatures , self .outfeatures ),
98+                 dtype = zeros .dtype ,
99+                 device = zeros .device ,
100+             )
101+             scales  =  self .scales .cpu ()
102+             scale_group_1  =  torch .zeros (
103+                 (self .infeatures , self .outfeatures ),
104+                 dtype = scales .dtype ,
105+                 device = scales .device ,
106+             )
107+             for  i  in  range (self .infeatures ):
108+                 zeros_group_1 [i ] =  zeros [self .g_idx [i ]]
109+                 scale_group_1 [i ] =  self .scales [self .g_idx [i ]]
110+             self .qzeros  =  pack_tensor (zeros_group_1 ).to (orig_device )
111+             self .scales  =  scale_group_1 .to (orig_device )
112+             self .groupsize  =  1 
113+             self .g_idx  =  None 
114+         else :
115+             new_qzeros  =  pack_tensor (zeros )
116+             self .qzeros  =  new_qzeros .to (orig_device )
99117
100118    @classmethod  
101119    def  new (cls , bits , groupsize , infeatures , outfeatures , bias ):
0 commit comments