Commit 929e18b

[usability] debug tools dev
1 parent 9352b50 commit 929e18b

2 files changed: +55 -188 lines

src/lmflow/pipeline/finetuner.py

Lines changed: 2 additions & 5 deletions
@@ -38,7 +38,7 @@
 from lmflow.datasets.dataset import Dataset
 from lmflow.pipeline.base_tuner import BaseTuner
 from lmflow.pipeline.utils.peft_trainer import PeftTrainer, PeftSavingCallback
-from lmflow.utils.debug import check_layerwise_grad, get_parameter_names_in_param_groups
+from lmflow.utils.debug import get_parameter_names_in_param_groups


 logger = logging.getLogger(__name__)
@@ -580,10 +580,7 @@ def on_step_end(self, args, state, control, **kwargs):
         pass

     def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        from lmflow.utils.debug import DistributedType, get_distributed_type
-        layers = eval('self.' + self.layers_attribute)
-        if get_distributed_type() != DistributedType.DEEPSPEED:
-            check_layerwise_grad(layers, note=f"optim step {state.global_step}", show_details='has_grads')
+        pass


 # Instantiate the callback
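
Note: the hunk above turns the callback's on_optimizer_step hook into a no-op. For orientation, here is a minimal sketch of how such a hook sits in a transformers TrainerCallback; the class name, the printed message, and the trainer wiring are illustrative assumptions, not LMFlow code, and it assumes a transformers version that dispatches the on_optimizer_step event used in this diff.

# Minimal sketch, assuming a transformers version that provides the
# on_optimizer_step callback event (the same hook the diff above edits).
# The class name and print statement are illustrative, not LMFlow code.
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class OptimizerStepProbe(TrainerCallback):
    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Runs once per optimizer update; a debug build could log gradient
        # statistics here, similar to what the removed check_layerwise_grad call did.
        print(f"optimizer step {state.global_step} finished")


# Hypothetical wiring: trainer = Trainer(..., callbacks=[OptimizerStepProbe()])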

src/lmflow/utils/debug.py

Lines changed: 53 additions & 183 deletions
@@ -21,6 +21,7 @@ def get_distributed_type():
 def print_tabulate_with_header(tabulate_df, header: Optional[str] = None):
     if header:
         df_len = len(tabulate_df.split('\n')[0])
+        print('\n')
         print('+' + '-'*(df_len-2) + '+')
         wrap_header = textwrap.wrap(header, df_len-4)
         for header in wrap_header:
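
The hunk above only adds a blank line before the boxed header that print_tabulate_with_header draws. For reference, a hypothetical call that boxes a tabulate table under a header; the sample DataFrame values are made up.

# Hypothetical usage of print_tabulate_with_header with made-up data.
import pandas as pd
import tabulate

from lmflow.utils.debug import print_tabulate_with_header

df = pd.DataFrame({"name": ["wte.weight", "lm_head.weight"], "grad_norm": [0.12, 0.34]})
table = tabulate.tabulate(df, headers='keys', tablefmt='psql', showindex=False)
print_tabulate_with_header(table, header="toy grad-norm table")
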
@@ -29,185 +30,6 @@ def print_tabulate_with_header(tabulate_df, header: Optional[str] = None):
     print(tabulate_df)


-def inspect_layer(layers, layer_idx: int, note: Optional[str] = None):
-    layer_info = {
-        "name": [],
-        "size": [],
-        "requires_grad": [],
-        "grad_norm": [],
-    }
-
-    for n, p in layers[layer_idx].named_parameters():
-        layer_info["name"].append(n)
-        layer_info["size"].append(p.size())
-        layer_info["requires_grad"].append(p.requires_grad)
-
-    layer_info["grad_norm"] = [
-        norm_tensor.item()
-        for norm_tensor in clip_grad_norm_(parameters=layers[layer_idx].parameters(),
-                                           max_norm=1.0,
-                                           return_norm_by_layer=True)[1]
-    ]
-
-    df = pd.DataFrame(layer_info)
-    table_to_print = tabulate.tabulate(df, headers='keys', tablefmt='psql')
-
-    print_tabulate_with_header(table_to_print, note)
-
-
-def inspect_layers(
-    layers,
-    layer_idxs: Union[int, List[int]],
-    notes: Optional[Union[str, List[str]]] = None
-):
-    if isinstance(layer_idxs, int):
-        layer_idxs = [layer_idxs]
-    if notes:
-        if isinstance(notes, str):
-            notes = [notes]
-        assert len(layer_idxs) == len(notes) if notes else True
-
-    for layer_idx in layer_idxs:
-        inspect_layer(layers, layer_idx, notes[layer_idx] if notes else None)
-
-
-def check_layerwise_grad(
-    layers,
-    layer_idx: Union[str, int, List[int]] = 'all',
-    show_details: Optional[str] = 'has_grads',
-    note: Optional[Union[str, List[str]]] = None,
-):
-    if layer_idx == 'all':
-        layer_idx = list(range(len(layers)))
-    elif isinstance(layer_idx, int):
-        layer_idx = [layer_idx]
-
-    distributed_type = get_distributed_type()
-
-    all_states = {
-        "layer_idx": layer_idx,
-        "requires_grad": [],
-        "grad_norm": []
-    }
-
-    for idx in layer_idx:
-        layer_states = {
-            "names": [],
-            "requires_grad": [],
-            "requires_grad_meta": False
-        }
-
-        for n, p in layers[idx].named_parameters():
-            layer_states["names"].append(n)
-            layer_states["requires_grad"].append(p.requires_grad)
-
-        if all(layer_states["requires_grad"]):
-            layer_states["requires_grad_meta"] = True
-
-        if show_details == 'all':
-            inspect_layer(layers, idx, f"Layer {idx} detail")
-        elif show_details == 'has_grads':
-            if any(layer_states["requires_grad"]):
-                inspect_layer(layers, idx, f"Layer {idx} detail")
-
-        all_states["requires_grad"].append(layer_states['requires_grad_meta'])
-        all_states["grad_norm"].append(clip_grad_norm_(layers[idx].parameters(), 1.0, distributed_type=distributed_type).item())
-
-    df = pd.DataFrame(all_states)
-    table_to_print = tabulate.tabulate(df, headers='keys', tablefmt='psql', showindex=False)
-
-    print_tabulate_with_header(table_to_print, f"{note}, {distributed_type=}")
-
-
-def clip_grad_norm_(
-    parameters, max_norm: float, norm_type: float = 2.0,
-    error_if_nonfinite: bool = False, foreach: Optional[bool] = None,
-    distributed_type: DistributedType = DistributedType.NO,
-    return_norm_by_layer: bool = False
-) -> Union[Tuple[torch.Tensor, List[torch.Tensor]], torch.Tensor]:
-    r"""Clip the gradient norm of an iterable of parameters.
-
-    The norm is computed over all gradients together, as if they were
-    concatenated into a single vector. Gradients are modified in-place.
-
-    Args:
-        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-            single Tensor that will have gradients normalized
-        max_norm (float): max norm of the gradients
-        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
-            infinity norm.
-        error_if_nonfinite (bool): if True, an error is thrown if the total
-            norm of the gradients from :attr:`parameters` is ``nan``,
-            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
-        foreach (bool): use the faster foreach-based implementation.
-            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
-            fall back to the slow implementation for other device types.
-            Default: ``None``
-
-    Returns:
-        Total norm of the parameter gradients (viewed as a single vector).
-    """
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    if distributed_type == DistributedType.DEEPSPEED:
-        # from deepspeed.utils import safe_get_full_grad
-        # grads = [safe_get_full_grad(p) for p in parameters]
-        return torch.tensor(0.)
-    else:
-        grads = [p.grad for p in parameters if p.grad is not None]
-        # print(f'torch grads {grads=}')
-    max_norm = float(max_norm)
-    norm_type = float(norm_type)
-    if len(grads) == 0:
-        return torch.tensor(0.)
-    first_device = grads[0].device
-    grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
-        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]
-
-    norms: List[Tensor] = []
-    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
-        if (
-            (foreach is None and _has_foreach_support(device_grads, device))
-            or (foreach and _device_has_foreach_support(device))
-        ):
-            norms.extend(torch._foreach_norm(device_grads, norm_type))
-        elif foreach:
-            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
-        else:
-            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
-
-    # print(f'torch norms {norms=}')
-    total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
-    # print(f'torch total_norm {total_norm=}')
-
-    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
-        raise RuntimeError(
-            f'The total norm of order {norm_type} for gradients from '
-            '`parameters` is non-finite, so it cannot be clipped. To disable '
-            'this error and scale the gradients by the non-finite norm anyway, '
-            'set `error_if_nonfinite=False`')
-    clip_coef = max_norm / (total_norm + 1e-6)
-    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
-    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
-    # when the gradients do not reside in CPU memory.
-    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
-        if (
-            (foreach is None and _has_foreach_support(device_grads, device))
-            or (foreach and _device_has_foreach_support(device))
-        ):
-            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
-        elif foreach:
-            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
-        else:
-            clip_coef_clamped_device = clip_coef_clamped.to(device)
-            for g in device_grads:
-                g.mul_(clip_coef_clamped_device)
-
-    # print(f'torch total_norm at end {total_norm=}')
-    return (total_norm, norms) if return_norm_by_layer else total_norm
-
-
 def get_decay_parameter_names(model: Union[PreTrainedModel, nn.Module]) -> List[str]:
     """
     From transformers.trainer
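
The block removed above bundled per-layer gradient inspection with a patched clip_grad_norm_ that could also return per-parameter norms. As a point of comparison only, here is a minimal sketch (not part of this commit) of collecting per-parameter gradient norms with plain torch calls, assuming the gradients live locally on each parameter (i.e. no DeepSpeed partitioning). The helper name per_parameter_grad_norms is made up for illustration.

# Minimal sketch, not part of the diff: per-parameter gradient norms with
# plain torch, assuming local (non-partitioned) gradients.
from typing import Dict

import torch
from torch import nn


def per_parameter_grad_norms(module: nn.Module, norm_type: float = 2.0) -> Dict[str, float]:
    # Norm of each parameter's gradient; parameters without a gradient
    # (frozen, or no backward pass yet) are skipped.
    return {
        name: torch.linalg.vector_norm(p.grad, norm_type).item()
        for name, p in module.named_parameters()
        if p.grad is not None
    }


# Usage: run a backward pass first, then inspect.
layer = nn.Linear(4, 4)
layer(torch.randn(2, 4)).sum().backward()
print(per_parameter_grad_norms(layer))
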
@@ -261,6 +83,12 @@ def get_parameter_names_in_param_groups(
     return parameter_names


+def get_parameter_names_require_grads(
+    model: Union[PreTrainedModel, nn.Module],
+) -> List[str]:
+    return [n for n, p in model.named_parameters() if p.requires_grad]
+
+
 def guess_grad_norms_from_pg(
     parameter_names: List[Dict[str, str]],
     all_norms: List[torch.Tensor],
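
A hypothetical use of the newly added get_parameter_names_require_grads: freeze a small model except its final layer norm and list what remains trainable. The model choice ("gpt2") and its module path (transformer.ln_f) are assumptions for illustration, not something this commit prescribes.

# Hypothetical example; 'gpt2' and its module names are illustrative assumptions.
from transformers import AutoModelForCausalLM

from lmflow.utils.debug import get_parameter_names_require_grads

model = AutoModelForCausalLM.from_pretrained("gpt2")
for p in model.parameters():
    p.requires_grad = False
for p in model.transformer.ln_f.parameters():
    p.requires_grad = True

print(get_parameter_names_require_grads(model))
# Expected along the lines of ['transformer.ln_f.weight', 'transformer.ln_f.bias']
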
@@ -272,12 +100,21 @@ def guess_grad_norms_from_pg(
         "layer": [],
         "grad_norm": [],
     }
-    for pg_names in parameter_names:
+    has_guess = False
+    pg_note = None
+
+    for pg_idx, pg_names in enumerate(parameter_names):
         if len(pg_names["parameter_names"]) == len(all_norms):
             all_grad_norms["name"] = pg_names["parameter_names"]
             all_grad_norms["grad_norm"] = [norm_tensor.item() for norm_tensor in all_norms]
+            if not has_guess:
+                has_guess = True
+                pg_note = 'Parameter group with weight decay' if pg_idx == 0 else 'Parameter group without weight decay'
+        else:
+            print("Failed to guess grad norms from parameter groups according to group length.")
+            return

-    if not all_grad_norms["name"]:
+    if not has_guess:
         return

     layer_pattern = re.compile(r'transformer\.h\.(\d+)\.')
@@ -294,10 +131,43 @@

     if not separate_by_layer:
         table_to_print = tabulate.tabulate(df, headers='keys', tablefmt='psql', showindex=False)
-        print_tabulate_with_header(table_to_print)
+        print_tabulate_with_header(table_to_print, pg_note)
+    else:
+        for layer_idx in df["layer"].unique():
+            table_to_print = tabulate.tabulate(
+                df[df["layer"] == layer_idx], headers='keys', tablefmt='psql', showindex=False
+            )
+            print_tabulate_with_header(table_to_print, f"Layer {layer_idx}, {pg_note}")
+
+
+def guess_grad_norms_from_hf_trainer(
+    parameter_names: List[str],
+    all_norms: List[torch.Tensor],
+    separate_by_layer: bool = False,
+    note: Optional[str] = None
+):
+    all_grad_norms = {
+        "name": parameter_names,
+        "layer": [],
+        "grad_norm": [norm_tensor.item() for norm_tensor in all_norms],
+    }
+
+    layer_pattern = re.compile(r'transformer\.h\.(\d+)\.')
+    for name in all_grad_norms["name"]:
+        layer_match = layer_pattern.search(name)
+        if layer_match:
+            all_grad_norms["layer"].append(int(layer_match.group(1)))
+        else:
+            all_grad_norms["layer"].append('other')
+
+    df = pd.DataFrame(all_grad_norms)
+
+    if not separate_by_layer:
+        table_to_print = tabulate.tabulate(df, headers='keys', tablefmt='psql', showindex=False)
+        print_tabulate_with_header(table_to_print, note)
     else:
         for layer_idx in df["layer"].unique():
             table_to_print = tabulate.tabulate(
                 df[df["layer"] == layer_idx], headers='keys', tablefmt='psql', showindex=False
             )
-            print_tabulate_with_header(table_to_print, f"Layer {layer_idx}")
+            print_tabulate_with_header(table_to_print, f"Layer {layer_idx}, {note}")
