Skip to content

Commit e56421f

Browse files
authored
Update miner.py
1 parent ea16987 commit e56421f

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

minetorch/miner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -499,12 +499,11 @@ def run_train_iteration(self, index, data, train_iters):
499499
total_iters=train_iters,
500500
iteration=self.current_train_iteration,
501501
)
502-
if self.amp:
502+
if self.amp and self.amp_scaler:
503503
with torch.cuda.amp.autocast():
504504
_, loss = self._forward(data)
505505
seperate_loss = loss / self.accumulated_iter
506-
if self.amp_scaler:
507-
seperate_loss = self.scaler.scale(seperate_loss)
506+
seperate_loss = self.scaler.scale(seperate_loss)
508507
else:
509508
_, loss = self._forward(data)
510509
seperate_loss = loss / self.accumulated_iter

0 commit comments

Comments
 (0)