
Commit 226e74b

Removes SageMakerTrainer code but keeps class as wrapper (#11587)
* removed all old code
* make quality
1 parent 084a187 commit 226e74b
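
For context, this change leaves the file as nothing more than a thin deprecation shim. A sketch of what trainer_sm.py reduces to, reconstructed from the unchanged context lines in the diff below (license header elided):

import warnings

from ..trainer import Trainer
from ..utils import logging


logger = logging.get_logger(__name__)


class SageMakerTrainer(Trainer):
    def __init__(self, args=None, **kwargs):
        # The class survives only to emit a deprecation warning and forward everything to Trainer.
        warnings.warn(
            "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
            "instead.",
            FutureWarning,
        )
        super().__init__(args=args, **kwargs)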

File tree

1 file changed: +0 -292 lines changed


src/transformers/sagemaker/trainer_sm.py

Lines changed: 0 additions & 292 deletions
@@ -11,312 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union

-import numpy as np
-import torch
-from torch import nn
-from torch.utils.data.dataset import Dataset
-from torch.utils.data.distributed import DistributedSampler
-
-from ..file_utils import WEIGHTS_NAME, is_torch_tpu_available
-from ..modeling_utils import PreTrainedModel, unwrap_model
 from ..trainer import Trainer
-from ..trainer_pt_utils import (
-    DistributedLengthGroupedSampler,
-    DistributedSamplerWithLoop,
-    SequentialDistributedSampler,
-    nested_detach,
-    nested_numpify,
-    reissue_pt_warnings,
-)
-from ..trainer_utils import PREFIX_CHECKPOINT_DIR
 from ..utils import logging
-from .training_args_sm import is_sagemaker_model_parallel_available


 logger = logging.get_logger(__name__)


-if is_sagemaker_model_parallel_available():
-    import smdistributed.modelparallel.torch as smp
-
-    @smp.step()
-    def forward_backward(model, inputs, gradient_accumulation_steps=1):
-        outputs = model(**inputs)
-        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
-        loss /= gradient_accumulation_steps
-        model.backward(loss)
-        return loss
-
-    @smp.step()
-    def forward_only(model, inputs):
-        return model(**inputs)
-
-    def smp_gather(tensor):
-        if isinstance(tensor, (list, tuple)):
-            return type(tensor)(smp_gather(t) for t in tensor)
-        elif isinstance(tensor, dict):
-            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
-        elif not isinstance(tensor, torch.Tensor):
-            raise TypeError(
-                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
-            )
-        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
-        return torch.cat([t.cpu() for t in all_tensors], dim=0)
-
-    def nested_smp_concat(tensor):
-        if isinstance(tensor, (list, tuple)):
-            return type(tensor)(nested_smp_concat(t) for t in tensor)
-        elif isinstance(tensor, dict):
-            return type(tensor)({k: nested_smp_concat(v) for k, v in tensor.items()})
-        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
-        # which is also the name of the decorator so Python is confused.
-        return tensor.concat().detach().cpu()
-
-
 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
         warnings.warn(
             "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
             "instead.",
             FutureWarning,
         )
-        self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
         super().__init__(args=args, **kwargs)
-
-    def is_world_process_zero(self) -> bool:
-        """
-        Whether or not this process is the global main process (when training in a distributed fashion on several
-        machines, this is only going to be :obj:`True` for one process).
-        """
-        if self.is_model_parallel_enabled:
-            return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
-        else:
-            return super().is_world_process_zero()
-
-    def _get_train_sampler(self):
-        if self.is_model_parallel_enabled:
-            if self.args.group_by_length:
-                return DistributedLengthGroupedSampler(
-                    self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
-                )
-            elif not self.args.dataloader_drop_last:
-                return DistributedSamplerWithLoop(
-                    self.train_dataset,
-                    self.args.per_device_train_batch_size,
-                    num_replicas=smp.dp_size(),
-                    rank=smp.dp_rank(),
-                )
-            else:
-                return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
-        else:
-            return super()._get_train_sampler()
-
-    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
-        if self.is_model_parallel_enabled:
-            return SequentialDistributedSampler(
-                eval_dataset,
-                num_replicas=smp.dp_size(),
-                rank=smp.dp_rank(),
-                batch_size=self.args.per_device_eval_batch_size,
-            )
-        else:
-            return super()._get_eval_sampler(eval_dataset)
-
-    def _wrap_model(self, model, training=True):
-        if self.is_model_parallel_enabled:
-            # Wrapping the base model twice in a DistributedModel will raise an error.
-            if isinstance(self.model_wrapped, smp.model.DistributedModel):
-                return self.model_wrapped
-            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
-        else:
-            return super()._wrap_model(model)
-
-    def create_optimizer_and_scheduler(self, num_training_steps: int):
-        super().create_optimizer_and_scheduler(num_training_steps)
-        if self.is_model_parallel_enabled:
-            self.optimizer = smp.DistributedOptimizer(self.optimizer)
-
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
-        if self.is_model_parallel_enabled:
-            model.train()
-            inputs = self._prepare_inputs(inputs)
-            loss_mb = forward_backward(model, inputs, self.args.gradient_accumulation_steps)
-            return loss_mb.reduce_mean().detach().to(self.args.device)
-        else:
-            return super().training_step(model, inputs)
-
-    def _gather_and_numpify(self, tensors, name):
-        if tensors is None:
-            return
-        if self.is_model_parallel_enabled:
-            tensors = smp_gather(tensors)
-            return nested_numpify(tensors)
-        else:
-            return super()._gather_and_numpify(tensors, name)
-
-    def save_model(self, output_dir: Optional[str] = None):
-        """
-        Will save the model, so you can reload it using :obj:`from_pretrained()`.
-
-        Will only save from the world_master process (unless in TPUs).
-        """
-        if self.is_model_parallel_enabled:
-            self._save_smp(output_dir)
-        elif is_torch_tpu_available():
-            self._save_tpu(output_dir)
-        elif self.is_world_process_zero():
-            self._save(output_dir)
-
-        # If on sagemaker and we are saving the main model (not a checkpoint so output_dir=None), save a copy to
-        # SM_MODEL_DIR for easy deployment.
-        if output_dir is None and os.getenv("SM_MODEL_DIR") is not None:
-            self.save_model(output_dir=os.getenv("SM_MODEL_DIR"))
-
-    def _save_smp(self, output_dir: Optional[str] = None):
-        if smp.dp_rank() != 0:
-            return
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        os.makedirs(output_dir, exist_ok=True)
-        logger.info(f"Saving model checkpoint to {output_dir}")
-        # Calling the state_dict needs to be done on the wrapped model
-        state_dict = self.model_wrapped.state_dict()
-
-        # Rest of the save is done for the main process only
-        if self.is_world_process_zero():
-            model = self.model
-            if not isinstance(model, PreTrainedModel):
-                model = unwrap_model(model)
-            if isinstance(model, PreTrainedModel):
-                model.save_pretrained(output_dir, state_dict=state_dict)
-            else:
-                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
-
-            if self.tokenizer is not None:
-                self.tokenizer.save_pretrained(output_dir)

-            # Good practice: save your training arguments together with the trained model
-            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
-
-    def _save_checkpoint(self, model, trial, metrics=None):
-        if self.is_model_parallel_enabled:
-            if smp.dp_rank() != 0:
-                return
-
-            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
-
-            run_dir = self.args.output_dir
-            self.store_flos()
-
-            output_dir = os.path.join(run_dir, checkpoint_folder)
-            self.save_model(output_dir)
-            # Consolidate the state dict on all processed of dp_rank 0
-            opt_state_dict = self.optimizer.state_dict()
-            # Save it and the scheduler on the main process
-            if self.is_world_process_zero():
-                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
-                with warnings.catch_warnings(record=True) as caught_warnings:
-                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                reissue_pt_warnings(caught_warnings)
-
-            # Determine the new best metric / best model checkpoint
-            if metrics is not None and self.args.metric_for_best_model is not None:
-                metric_to_check = self.args.metric_for_best_model
-                if not metric_to_check.startswith("eval_"):
-                    metric_to_check = f"eval_{metric_to_check}"
-                metric_value = metrics[metric_to_check]
-
-                operator = np.greater if self.args.greater_is_better else np.less
-                if (
-                    self.state.best_metric is None
-                    or self.state.best_model_checkpoint is None
-                    or operator(metric_value, self.state.best_metric)
-                ):
-                    self.state.best_metric = metric_value
-                    self.state.best_model_checkpoint = output_dir
-
-            # Save the Trainer state
-            if self.is_world_process_zero():
-                self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
-
-            # Maybe delete some older checkpoints.
-            if self.is_world_process_zero():
-                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
-        else:
-            super()._save_checkpoint(self, model, trial, metrics=metrics)
-
-    def _load_optimizer_and_scheduler(self, checkpoint):
-        """If optimizer and scheduler states exist, load them."""
-        if self.is_model_parallel_enabled:
-            if checkpoint is None:
-                return
-
-            if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
-                os.path.join(checkpoint, "scheduler.pt")
-            ):
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
-                )
-                with warnings.catch_warnings(record=True) as caught_warnings:
-                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
-                reissue_pt_warnings(caught_warnings)
-        else:
-            super()._load_optimizer_and_scheduler(checkpoint)
-
-    def prediction_step(
-        self,
-        model: nn.Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[List[str]] = None,
-    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        if self.is_model_parallel_enabled:
-            has_labels = all(inputs.get(k) is not None for k in self.label_names)
-            inputs = self._prepare_inputs(inputs)
-
-            if ignore_keys is None:
-                if hasattr(self.model, "config"):
-                    ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-                else:
-                    ignore_keys = []
-
-            with torch.no_grad():
-                raw_outputs = forward_only(model, inputs)
-                if has_labels:
-                    if isinstance(raw_outputs, dict):
-                        loss_mb = raw_outputs["loss"]
-                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
-                    else:
-                        loss_mb = raw_outputs[0]
-                        logits_mb = raw_outputs[1:]
-
-                    loss = loss_mb.reduce_mean().detach().cpu()
-                    logits = nested_smp_concat(logits_mb)
-                else:
-                    loss = None
-                    if isinstance(raw_outputs, dict):
-                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
-                    else:
-                        logits_mb = raw_outputs
-                    logits = nested_smp_concat(logits_mb)
-
-            if prediction_loss_only:
-                return (loss, None, None)
-
-            if len(logits) == 1:
-                logits = logits[0]
-
-            if has_labels:
-                labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
-                if len(labels) == 1:
-                    labels = labels[0]
-            else:
-                labels = None
-
-            return (loss, logits, labels)
-        else:
-            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
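
Since the deprecation warning points users at `Trainer`, migrating away from the wrapper is just a class swap. The snippet below is an illustrative sketch, not part of this commit; the checkpoint name, toy dataset, and `TrainingArguments` values are placeholders chosen to keep the example self-contained:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


class ToyDataset(torch.utils.data.Dataset):
    # Minimal in-memory dataset so the example runs on its own.
    def __init__(self, tokenizer):
        self.encodings = tokenizer(["hello world", "goodbye world"], truncation=True, padding=True)
        self.labels = [0, 1]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Previously: trainer = SageMakerTrainer(model=model, args=training_args, train_dataset=...)
# The plain Trainer now does the same job.
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./out", num_train_epochs=1),
    train_dataset=ToyDataset(tokenizer),
)
trainer.train()

Existing code that still instantiates SageMakerTrainer keeps working for now: the class only emits a FutureWarning and forwards its arguments to Trainer.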
