Skip to content

Commit

Permalink
✨ add --train_mse
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jul 9, 2024
1 parent 7eff3ec commit 7e54bb7
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions modules/finetune/train_speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def train_speaker_embeddings(
gpt,
batch_size=16,
epochs=10,
train_text=True,
train_text=False,
train_mse=False,
speaker_embeds=None,
):
tokenizer = chat.pretrain_models["tokenizer"]
Expand Down Expand Up @@ -180,15 +181,16 @@ def train_speaker_embeddings(
audio_hidden_states[:, :-1].transpose(1, 2)
).transpose(1, 2)
mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
loss += 0.01 * mse_loss

optimizer.zero_grad()

if train_mse:
loss += 0.01 * mse_loss

if train_text:
# just for test
text_loss.backward()
else:
loss.backward()
loss += 0.01 * text_loss

loss.backward()
torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
optimizer.step()
logger.meters["loss"].update(loss.item(), n=batch_size)
Expand Down Expand Up @@ -220,6 +222,7 @@ def train_speaker_embeddings(
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--train_text", action="store_true", help="train text loss")
parser.add_argument("--train_mse", action="store_true", help="train mse loss")
# 初始化 speaker
parser.add_argument("--init_speaker", type=str)
parser.add_argument(
Expand All @@ -239,6 +242,7 @@ def train_speaker_embeddings(
tar_path: str | None = args.tar_path
tar_in_memory: bool = args.tar_in_memory
train_text: bool = args.train_text
train_mse: bool = args.train_mse
# gpt_lora: bool = args.gpt_lora
# gpt_kbit: int = args.gpt_kbit
save_folder: str = args.save_folder
Expand Down Expand Up @@ -301,6 +305,5 @@ def train_speaker_embeddings(
--save_folder ./data \
--init_speaker ./data/speakers/Bob.pt \
--epochs 100 \
--batch_size 6 \
--train_text
--batch_size 6
"""

0 comments on commit 7e54bb7

Please sign in to comment.