Skip to content

Commit 82f9af9

Browse files
committed
Update lm.py
float32 cast during norm
1 parent c9b0107 commit 82f9af9

File tree

1 file changed

+1
-1
lines changed
  • stable_audio_tools/training

1 file changed

+1
-1
lines changed

stable_audio_tools/training/lm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper,
236236

237237
filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
238238
fakes = fakes / fakes.abs().max()
239-
fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
239+
fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu()
240240
torchaudio.save(filename, fakes, self.sample_rate)
241241

242242
log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,

0 commit comments

Comments
 (0)