import functools
import inspect
import os
import re
import textwrap

from transformers import PreTrainedModel
def trace_calls(func):
    """Decorator that prints the call stack the first time *func* is invoked.

    Useful for debugging: it reveals which code path reaches ``func``
    without flooding the log on every subsequent call.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapped callable with identical behavior that logs its call
        stack exactly once.
    """
    # Track "already reported" state on the function object itself so the
    # flag is shared by the single wrapper instance across all calls.
    func.has_called = False

    # functools.wraps preserves __name__/__doc__ of the wrapped callable;
    # without it the printed name below would always be 'wrapper'.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not func.has_called:
            func.has_called = True
            stack = inspect.stack()
            print(f"Function '{func.__name__}' was called. Call stack:")
            # stack[0] is this wrapper's own frame; report only the callers.
            for frame_info in stack[1:]:
                print(f"  - Function '{frame_info.function}' in {frame_info.filename}, line {frame_info.lineno}")
        return func(*args, **kwargs)

    return wrapper
def get_distributed_type():
    """Detect the distributed backend selected via the environment.

    Returns:
        ``DistributedType.DEEPSPEED`` when Accelerate exported the
        ``ACCELERATE_USE_DEEPSPEED`` environment variable,
        ``DistributedType.NO`` otherwise.
    """
    if "ACCELERATE_USE_DEEPSPEED" in os.environ:
        return DistributedType.DEEPSPEED
    return DistributedType.NO
@@ -170,4 +186,59 @@ def guess_grad_norms_from_hf_trainer(
170
186
table_to_print = tabulate .tabulate (
171
187
df [df ["layer" ] == layer_idx ], headers = 'keys' , tablefmt = 'psql' , showindex = False
172
188
)
173
- print_tabulate_with_header (table_to_print , f"Layer { layer_idx } , { note } " )
189
+ print_tabulate_with_header (table_to_print , f"Layer { layer_idx } , { note } " )
def tensor_all_zero(tensor: torch.Tensor) -> bool:
    """Return True when every element of *tensor* is zero.

    Equivalent to ``torch.equal(tensor, torch.zeros_like(tensor))`` but
    avoids allocating a throwaway zeros tensor of the same size.

    Args:
        tensor: Tensor of any shape or dtype; an empty tensor counts as
            all-zero.

    Returns:
        True iff the tensor contains no non-zero element.
    """
    # torch.any() treats NaN as non-zero, matching torch.equal's False
    # result for tensors containing NaN.
    return not bool(torch.any(tensor))
def guess_grad_all_zero_from_pg(
    parameter_names: List[Dict[str, List[str]]],
    all_grads: List[torch.Tensor],
    show_zero_grads: bool = False,
    separate_by_layer: bool = False,
):
    """Heuristically map flat gradients to parameter names and report all-zero ones.

    Each entry of ``parameter_names`` describes one optimizer parameter
    group and is expected to hold that group's names under the
    ``"parameter_names"`` key.  A group is taken as the match for
    ``all_grads`` when the two lists have equal length; the first group
    with a mismatching length aborts the guess entirely.

    Args:
        parameter_names: Per-parameter-group dicts, each with a
            ``"parameter_names"`` list of parameter names.
        all_grads: Flat list of gradient tensors, assumed to line up
            positionally with one group's name list.
        show_zero_grads: When False, rows whose gradient is entirely zero
            are filtered out of the printed table.
        separate_by_layer: When True, print one table per transformer
            layer instead of a single combined table.

    Returns:
        None.  Results are printed via ``print_tabulate_with_header``.
    """
    # Columns of the report table, filled in below.
    all_grad_status = {
        "name": [],
        "layer": [],
        "grad_all_zero": [],
    }
    has_guess = False
    pg_note = None

    for pg_idx, pg_names in enumerate(parameter_names):
        if len(pg_names["parameter_names"]) == len(all_grads):
            # Length match -> assume this group's names align positionally
            # with all_grads.
            # NOTE(review): if several groups match, a later one overwrites
            # the data while pg_note keeps describing the first match —
            # confirm this is intended.
            all_grad_status["name"] = pg_names["parameter_names"]
            all_grad_status["grad_all_zero"] = [tensor_all_zero(grad_tensor) for grad_tensor in all_grads]
            if not has_guess:
                has_guess = True
                # Convention here: group 0 carries weight decay, any other
                # group does not.
                pg_note = 'Parameter group with weight decay' if pg_idx == 0 else 'Parameter group without weight decay'
        else:
            # Any mismatching group makes the positional guess unreliable,
            # so give up immediately.
            print("Failed to guess grad norms from parameter groups according to group length.")
            return

    if not has_guess:
        return

    # Derive the layer index from names like "transformer.h.<idx>....";
    # names that don't match (embeddings, final norm, ...) go to 'other'.
    layer_pattern = re.compile(r'transformer\.h\.(\d+)\.')
    for name in all_grad_status["name"]:
        layer_match = layer_pattern.search(name)
        if layer_match:
            all_grad_status["layer"].append(int(layer_match.group(1)))
        else:
            all_grad_status["layer"].append('other')

    df = pd.DataFrame(all_grad_status)
    if not show_zero_grads:
        # Keep only parameters whose gradient has at least one non-zero entry.
        df = df[df["grad_all_zero"] == False]

    if not separate_by_layer:
        table_to_print = tabulate.tabulate(df, headers='keys', tablefmt='psql', showindex=False)
        print_tabulate_with_header(table_to_print, pg_note)
    else:
        # One table per distinct layer bucket (ints plus possibly 'other').
        for layer_idx in df["layer"].unique():
            table_to_print = tabulate.tabulate(
                df[df["layer"] == layer_idx], headers='keys', tablefmt='psql', showindex=False
            )
            print_tabulate_with_header(table_to_print, f"Layer {layer_idx}, {pg_note}")