-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
496 lines (417 loc) · 21.3 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
import importlib.metadata
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data import DataLoader
from transformers import Trainer
from transformers.integrations.deepspeed import deepspeed_init
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer_pt_utils import (
EvalLoopContainer,
IterableDatasetShard,
find_batch_size,
nested_detach,
)
from transformers.trainer_utils import (
EvalLoopOutput,
EvalPrediction,
denumpify_detensorize,
has_length,
is_main_process,
)
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, is_torch_xla_available, logging
if is_peft_available():
from peft import PeftModel
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
from transformers.trainer_pt_utils import smp_forward_only, smp_nested_concat
else:
IS_SAGEMAKER_MP_POST_1_10 = False
def _is_peft_model(model):
    """Return whether `model` is an instance of a PEFT wrapper class.

    Always returns ``False`` when `peft` is not available. Otherwise the check
    covers `PeftModel` and, for peft>=0.7.0, `PeftMixedModel` as well.
    """
    if not is_peft_available():
        return False

    peft_classes = [PeftModel]
    installed_peft = version.parse(importlib.metadata.version("peft"))
    # `PeftMixedModel` was introduced in peft>=0.7.0:
    # https://github.com/huggingface/transformers/pull/28321
    if installed_peft >= version.parse("0.7.0"):
        from peft import PeftMixedModel

        peft_classes.append(PeftMixedModel)
    return isinstance(model, tuple(peft_classes))
# Module-level logger, created through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
def greedy_search(
    model,
    audio_features: torch.Tensor,
    hyps_ls=None,
    max_symbols_per_frame: int = 10,
) -> List[int]:
    """Greedily decode one chunk of encoder output with a transducer model.

    For every frame the joint network is queried with ``frame + decoder_out``.
    A non-blank prediction is appended to the hypothesis and the decoder state
    is recomputed; a blank prediction (or reaching `max_symbols_per_frame`
    emissions on the same frame) moves on to the next frame.

    Args:
        model: Object exposing ``config.blk_token_ids``, ``device``,
            ``get_text_features(token_tensor)`` and ``joint_network(tensor)``.
        audio_features: Encoder output; only ``audio_features[0]`` is decoded
            and its first dimension is iterated as frames.
            NOTE(review): presumably shaped (1, frames, dim) -- confirm with the model.
        hyps_ls: Hypothesis carried over from a previous chunk. When empty or
            ``None`` decoding starts from ``[blank_id]``. The list is copied,
            never mutated in place.
        max_symbols_per_frame: Upper bound on non-blank tokens emitted for a
            single frame (previously a hard-coded 10).

    Returns:
        The token-id hypothesis, including the leading blank token and any
        tokens passed in through `hyps_ls`.
    """
    frames = audio_features[0]
    blank_id = model.config.blk_token_ids
    # Copy so that the caller's list is not modified behind its back.
    hyps = list(hyps_ls) if hyps_ls else [blank_id]

    def _last_decoder_state() -> torch.Tensor:
        # The decoder is re-run on the whole hypothesis and the state of the
        # LAST token is used. The original seeded decoding with position 0,
        # which ignored every token carried over from a previous chunk.
        tokens = torch.tensor([hyps], device=model.device)
        return model.get_text_features(tokens)[0][-1]

    decoder_out = _last_decoder_state()
    for frame in frames:
        # The emission budget is per frame. The original never reset its
        # counter after hitting the cap, so all later frames were skipped.
        for _ in range(max_symbols_per_frame):
            joint_out = model.joint_network(frame + decoder_out)
            # argmax over raw scores equals argmax over log_softmax (monotonic).
            pred = joint_out.argmax(-1).item()
            if pred == blank_id:
                break
            hyps.append(pred)
            decoder_out = _last_decoder_state()
    return hyps
class TransformerTransducerTrainer(Trainer):
    """`Trainer` subclass for a transducer speech model.

    Differences from the base `Trainer` that are visible in this class:

    * `compute_loss` can periodically run a fixed `example_sample` through the
      model and save its `total_grad` output as an image.
    * `prediction_step` replaces the model's raw outputs with token ids
      produced by chunk-wise greedy decoding (`greedy_search`).
    * `evaluation_loop` is a copy of the upstream loop adapted to this file.
    """

    def __init__(self, example_sample: Optional[Dict[str, torch.Tensor]] = None, **kwargs):
        """Create the trainer.

        Args:
            example_sample: Optional model inputs (name -> tensor) used only for
                the periodic gradient image in `compute_loss`.
            **kwargs: Forwarded unchanged to `Trainer.__init__`.
        """
        super().__init__(**kwargs)
        self.example_sample = example_sample

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        In addition to the default behaviour, a gradient image of `self.example_sample` is saved when
        `_should_save_rnn_t_grad_image` says so. That side task never alters the loss returned here.

        Args:
            model: The model to run.
            inputs: Keyword arguments for `model`. `labels` is popped when a label smoother is configured.
            return_outputs: When true, return `(loss, outputs)` instead of only the loss.

        Raises:
            ValueError: If the model returns a dict without a `loss` key and no label smoother is used.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # NOTE: Compute Fast RNN-T loss gradient
        # The image is produced from its own local variables. The original reused the name
        # `loss` here, so on save steps a detached Python float was returned to the caller.
        if self._should_save_rnn_t_grad_image(model):
            self._save_rnn_t_grad_image(model)

        return (loss, outputs) if return_outputs else loss

    def _should_save_rnn_t_grad_image(self, model) -> bool:
        """Tell whether the current call to `compute_loss` should save a gradient image."""
        # A missing or zero interval / empty path / missing sample disables the feature
        # instead of raising AttributeError or ZeroDivisionError.
        interval = getattr(self.args, "rnn_t_grad_img_save_iterval", None)
        save_path = getattr(self.args, "rnn_t_grad_img_save_path", None)
        if not interval or not save_path or not self.example_sample:
            return False
        return bool(
            self.state.global_step % interval == 0 and is_main_process(self.args.local_rank) and model.training
        )

    def _save_rnn_t_grad_image(self, model) -> None:
        """Run `self.example_sample` through `model` in eval mode and save `total_grad` as a PNG."""
        import os

        step = self.state.global_step
        save_dir = self.args.rnn_t_grad_img_save_path

        model.eval()
        try:
            with torch.no_grad():
                example_sample = {k: v.to(model.device) for k, v in self.example_sample.items()}
                check_outputs = model(**example_sample)
        finally:
            # This helper is only reached while training, so always go back to train mode,
            # even if the forward pass above raised.
            model.train()

        # NOTE(review): `total_grad` is assumed to be a 3-D tensor with the batch first -- confirm with the model.
        grad_map = check_outputs.total_grad.detach().cpu().permute(0, 2, 1)
        sample_loss = check_outputs.loss.detach().cpu().item()

        title = f"[{step}]{sample_loss:.2f}-grad"
        img_save_path = f"{save_dir}/{step}-grad_img.png"
        os.makedirs(save_dir, exist_ok=True)

        plt.matshow(grad_map.unbind(0)[0].t(), origin="lower")
        plt.xlabel("t")
        plt.ylabel("u")
        plt.title(title)
        plt.savefig(img_save_path)
        # `matshow` opens a new figure on every call; close it so figures do not pile up.
        plt.close()

    def _greedy_decode_batch(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Greedy-decode every sample of a batch chunk by chunk and pad the results with -100.

        Each sample of `inputs["input_features"]` is split along its first dimension into chunks of
        `config.one_sec_mel_shape`; `mems` and the running hypothesis are carried from chunk to chunk.
        """
        # The custom decoding entry points live on the bare model, not on a distributed wrapper.
        decode_model = self.accelerator.unwrap_model(model)
        chunk_len = decode_model.config.one_sec_mel_shape

        hypotheses = []
        for input_features, attention_mask in zip(inputs["input_features"], inputs["attention_mask"]):
            mems = None
            hyps_ls = None
            for start in range(0, input_features.shape[0], chunk_len):
                audio_features, mems = decode_model.get_audio_features(
                    input_features=input_features[start : start + chunk_len].unsqueeze(0),
                    attention_mask=attention_mask[start : start + chunk_len].unsqueeze(0),
                    mems=mems,
                )
                hyps_ls = greedy_search(decode_model, audio_features, hyps_ls)
            if hyps_ls is None:
                # Sample with zero frames: fall back to the start hypothesis `greedy_search` would use.
                hyps_ls = [decode_model.config.blk_token_ids]
            hypotheses.append(torch.tensor(hyps_ls, device=decode_model.device))
        return torch.nn.utils.rnn.pad_sequence(hypotheses, batch_first=True, padding_value=-100)

    @torch.no_grad()
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        When a loss can be computed (and SageMaker model parallelism is off), the returned "logits" are
        greedy-decoded token ids padded with -100 rather than raw model outputs. Decoding is skipped when
        `prediction_loss_only` is true because its result would be discarded.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
        can_compute_loss = has_labels or loss_without_labels

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if can_compute_loss:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        # The whole method already runs under the `torch.no_grad()` decorator.
        if is_sagemaker_mp_enabled():
            raw_outputs = smp_forward_only(model, inputs)
            if can_compute_loss:
                if isinstance(raw_outputs, dict):
                    loss_mb = raw_outputs["loss"]
                    logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    loss_mb = raw_outputs[0]
                    logits_mb = raw_outputs[1:]

                loss = loss_mb.reduce_mean().detach().cpu()
                logits = smp_nested_concat(logits_mb)
            else:
                loss = None
                if isinstance(raw_outputs, dict):
                    logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                else:
                    logits_mb = raw_outputs
                logits = smp_nested_concat(logits_mb)
        elif can_compute_loss:
            with self.compute_loss_context_manager():
                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()
            # Greedy decoding is the expensive part; do not run it when only the loss is wanted.
            logits = None if prediction_loss_only else self._greedy_decode_batch(model, inputs)
        else:
            loss = None
            with self.compute_loss_context_manager():
                outputs = model(**inputs)
            if isinstance(outputs, dict):
                logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
            else:
                logits = outputs
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        # Only unwrap single-element containers. Applying `len(...) == 1` to the decoded tensor
        # would strip the batch dimension whenever a batch holds exactly one sample.
        if isinstance(logits, (tuple, list)) and len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader: Batches to evaluate.
            description: Label used in the log header; `"Prediction"` forces per-batch results to be kept
                even when `batch_eval_metrics` is on.
            prediction_loss_only: Overrides `args.prediction_loss_only` when not `None`.
            ignore_keys: Forwarded to `prediction_step`.
            metric_key_prefix: Prefix put in front of every metric name.

        Returns:
            An `EvalLoopOutput` with the gathered predictions, label ids, metrics and the sample count.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"\n***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f" Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info(" Num examples: Unknown")
        logger.info(f" Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
        all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
        all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
        all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)

        metrics = None

        # Will be useful when we have an iterable dataset so don't know its length.
        observed_num_examples = 0

        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            main_input_name = getattr(self.model, "main_input_name", "input_ids")
            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

            if is_torch_xla_available():
                # `xm` is not imported at module level in this file; import it lazily so
                # this branch cannot fail with a NameError.
                import torch_xla.core.xla_model as xm

                xm.mark_step()

            # Update containers
            if losses is not None:
                losses = self.gather_function((losses.repeat(batch_size)))
                all_losses.add(losses)
            if inputs_decode is not None:
                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                inputs_decode = self.gather_function((inputs_decode))
                if not self.args.batch_eval_metrics or description == "Prediction":
                    all_inputs.add(inputs_decode)
            if logits is not None:
                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                logits = self.gather_function((logits))
                if not self.args.batch_eval_metrics or description == "Prediction":
                    all_preds.add(logits)
            if labels is not None:
                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                labels = self.gather_function((labels))
                if not self.args.batch_eval_metrics or description == "Prediction":
                    all_labels.add(labels)

            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            if self.args.batch_eval_metrics:
                if self.compute_metrics is not None and logits is not None and labels is not None:
                    is_last_step = self.accelerator.gradient_state.end_of_dataloader
                    if args.include_inputs_for_metrics:
                        metrics = self.compute_metrics(
                            EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs),
                            compute_result=is_last_step,
                        )
                    else:
                        metrics = self.compute_metrics(
                            EvalPrediction(predictions=logits, label_ids=labels),
                            compute_result=is_last_step,
                        )

                del losses, logits, labels, inputs
                torch.cuda.empty_cache()

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                all_losses.to_cpu_and_numpy()
                all_preds.to_cpu_and_numpy()
                all_labels.to_cpu_and_numpy()
                all_inputs.to_cpu_and_numpy()

                del losses, logits, labels, inputs
                torch.cuda.empty_cache()

        # After all calls to `.gather_function`, reset to `gather_for_metrics`:
        self.gather_function = self.accelerator.gather_for_metrics
        # Clean the state at the end of the evaluation loop. The original tested the truthiness of
        # `args.past_index`, which skipped the cleanup for the valid index 0.
        if hasattr(self, "_past"):
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        all_losses = all_losses.get_arrays()
        all_preds = all_preds.get_arrays()
        all_labels = all_labels.get_arrays()
        all_inputs = all_inputs.get_arrays()

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Metrics!
        if (
            self.compute_metrics is not None
            and all_preds is not None
            and all_labels is not None
            and not self.args.batch_eval_metrics
        ):
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs),
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        elif metrics is None:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if isinstance(all_losses, list) and all_losses:
            metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
        elif isinstance(all_losses, np.ndarray):
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)