Commit fab61f6

Merge pull request #394 from modelscope/wan-train-update
fix swanlab after test
2 parents 91f77d2 + 6b67a11 commit fab61f6

File tree: 3 files changed (+55, -3 lines)

diffsynth/trainers/text_to_image.py
Lines changed: 27 additions & 1 deletion

@@ -250,6 +250,17 @@ def add_general_parsers(parser):
         default=None,
         help="Pretrained LoRA path. Required if the training is resumed.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     return parser

@@ -270,6 +281,20 @@ def launch_training_task(model, args):
         num_workers=args.dataloader_num_workers
     )
     # train
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="diffsynth_studio",
+            name="diffsynth_studio",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
@@ -278,7 +303,8 @@
         strategy=args.training_strategy,
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
-        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model=model, train_dataloaders=train_loader)

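For context, the two new flags follow standard argparse semantics: --use_swanlab is a boolean switch that stays False unless it is passed, and --swanlab_mode is a plain string (intended values "cloud" or "local"). A minimal, self-contained sketch of that behavior, separate from the training code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_swanlab", default=False, action="store_true",
                    help="Whether to use SwanLab logger.")
parser.add_argument("--swanlab_mode", default=None,
                    help="SwanLab mode (cloud or local).")

args = parser.parse_args([])                                    # no flags passed
print(args.use_swanlab, args.swanlab_mode)                      # False None
args = parser.parse_args(["--use_swanlab", "--swanlab_mode", "local"])
print(args.use_swanlab, args.swanlab_mode)                      # True local
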
examples/wanvideo/README.md
Lines changed: 2 additions & 2 deletions

@@ -132,8 +132,8 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
   --steps_per_epoch 500 \
   --max_epochs 10 \
   --learning_rate 1e-4 \
-  --lora_rank 4 \
-  --lora_alpha 4 \
+  --lora_rank 16 \
+  --lora_alpha 16 \
   --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
   --accumulate_grad_batches 1 \
   --use_gradient_checkpointing

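Combined with the flags added in examples/wanvideo/train_wan_t2v.py below, the updated README command can also enable local SwanLab logging. The sketch below shows only the flags visible in this hunk; any dataset and model-path arguments from the rest of the README command still need to be supplied:

CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
  --steps_per_epoch 500 \
  --max_epochs 10 \
  --learning_rate 1e-4 \
  --lora_rank 16 \
  --lora_alpha 16 \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --accumulate_grad_batches 1 \
  --use_gradient_checkpointing \
  --use_swanlab \
  --swanlab_mode local
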
examples/wanvideo/train_wan_t2v.py
Lines changed: 26 additions & 0 deletions

@@ -423,6 +423,17 @@ def parse_args():
         default=None,
         help="Pretrained LoRA path. Required if the training is resumed.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     args = parser.parse_args()
     return args

@@ -481,6 +492,20 @@ def train(args):
         use_gradient_checkpointing=args.use_gradient_checkpointing,
         pretrained_lora_path=args.pretrained_lora_path,
     )
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="wan",
+            name="wan",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
@@ -489,6 +514,7 @@
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
         callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model, dataloader)

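Pulled out of the diff, the logger wiring in train() reduces to the pattern below. This is a minimal sketch rather than the repository's full training script: it assumes swanlab is installed (pip install swanlab), that Lightning is imported as `import lightning as pl` (consistent with the pl.pytorch.callbacks usage above), and that args is the namespace returned by parse_args(); build_trainer is an illustrative helper name, not a function in the repo.

import lightning as pl  # assumption: Lightning is imported under the alias `pl`
from swanlab.integration.pytorch_lightning import SwanLabLogger

def build_trainer(args):
    # Only construct a SwanLab logger when --use_swanlab was passed.
    if args.use_swanlab:
        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))  # record every CLI argument with the run
        logger = [SwanLabLogger(
            project="wan",
            name="wan",
            config=swanlab_config,
            mode=args.swanlab_mode,    # "cloud" or "local"
            logdir=args.output_path,   # keep SwanLab logs next to the checkpoints
        )]
    else:
        logger = None  # no SwanLab logger is attached when the flag is absent
    return pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=logger,
    )
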