# Training Arguments Explanation

This document provides a detailed explanation of the training arguments used in our scripts.

## Core Arguments

- `--model_paths`: Path to the base model directory (e.g., `./FLUX.1-Kontext-dev`).
- `--output_path`: Directory in which to save model checkpoints and logs.
- `--learning_rate`: Learning rate for the optimizer (e.g., `1e-5`).
- `--num_epochs`: Total number of training epochs.
- `--batch_size`: Per-GPU batch size; adjust to fit your GPU memory.
- `--resume`: Resume training from the latest checkpoint in `--output_path`.
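As a minimal sketch of how these core arguments could be declared, here is a hypothetical `argparse` parser (the actual training scripts may define them differently; defaults shown are illustrative, not the scripts' real defaults):

```python
import argparse

# Hypothetical parser mirroring the core arguments above (sketch only).
parser = argparse.ArgumentParser(description="Training arguments (sketch)")
parser.add_argument("--model_paths", type=str, required=True,
                    help="Path to the base model directory")
parser.add_argument("--output_path", type=str, default="./outputs",
                    help="Where checkpoints and logs are written")
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--num_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=1,
                    help="Per-GPU batch size")
parser.add_argument("--resume", action="store_true",
                    help="Resume from the latest checkpoint in --output_path")

# Example invocation, parsed from an explicit argument list.
args = parser.parse_args([
    "--model_paths", "./FLUX.1-Kontext-dev",
    "--learning_rate", "1e-5",
    "--batch_size", "2",
    "--resume",
])
print(args.model_paths, args.learning_rate, args.batch_size, args.resume)
```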

## Dataset Arguments

- `--dataset_base_path`: Comma-separated absolute paths to the training datasets; the order matters.
- `--dataset_metadata_path`: Path to the metadata file (CSV or Parquet) containing relative file paths.
- `--data_file_keys`: Column names in the metadata file to use, e.g., `kontext_images,image`.
- `--dataset_repeat`: Number of times to repeat a dataset within an epoch; useful for balancing datasets of different sizes.
- `--height`, `--width`: Target resolution for training images.
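To illustrate how a metadata file and the dataset base path fit together, here is a hedged sketch (the column names follow the `kontext_images,image` example above; the CSV rows and the `/data/normal_maps` base path are invented for illustration):

```python
import csv
import io
import os

# Hypothetical metadata file: relative paths under a dataset base path.
metadata_csv = io.StringIO(
    "kontext_images,image\n"
    "inputs/0001.png,targets/0001.png\n"
    "inputs/0002.png,targets/0002.png\n"
)

dataset_base_path = "/data/normal_maps"          # assumed absolute dataset root
data_file_keys = ["kontext_images", "image"]     # columns to resolve

samples = []
for row in csv.DictReader(metadata_csv):
    # Join each configured column's relative path onto the dataset base path.
    samples.append({k: os.path.join(dataset_base_path, row[k])
                    for k in data_file_keys})

print(samples[0])
```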

## Model & Training Strategy

- `--trainable_models`: Which parts of the model to train. For fine-tuning, set to `"dit"`.
- `--extra_inputs`: Additional input keys besides the main image and prompt; in our case, `"kontext_images"`.
- `--default_caption`: Default text prompt used for training (e.g., `"Transform to normal map..."`).
- `--multi_res_noise`: (Flag) Use multi-resolution noise for potentially faster convergence, inspired by Marigold.
- `--with_mask`: (Flag) Compute the loss only on valid masked regions (e.g., where ground-truth depth is available).
- `--using_sqrt`: (Flag, depth only) Use our theoretically optimal square-root normalization for depth.
- `--extra_loss`: Name of the pixel-space consistency loss to apply (e.g., `"cycle_consistency_normal_estimation"`).
- `--deterministic_flow`: (Flag) Use a fixed random seed for the initial noise to create a pseudo-deterministic path.
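The multi-resolution (pyramid) noise behind `--multi_res_noise` can be sketched as follows: Gaussian noise is drawn at progressively coarser scales, upsampled back to the target resolution, and summed with decaying weights. This is a generic NumPy sketch of the technique, not the scripts' actual implementation; `strength` and `levels` are assumed hyperparameters.

```python
import numpy as np

def multi_res_noise(height, width, strength=0.9, levels=4, rng=None):
    """Sketch of pyramid noise: per-scale Gaussian noise, nearest-neighbor
    upsampled and summed with geometrically decaying weights, then
    renormalized to unit standard deviation."""
    rng = np.random.default_rng(rng)
    noise = rng.standard_normal((height, width))
    for i in range(1, levels):
        h, w = max(1, height >> i), max(1, width >> i)
        coarse = rng.standard_normal((h, w))
        # Nearest-neighbor upsample back to the target resolution, then crop.
        up = np.repeat(np.repeat(coarse, height // h + 1, axis=0),
                       width // w + 1, axis=1)[:height, :width]
        noise = noise + (strength ** i) * up
    return noise / noise.std()

n = multi_res_noise(64, 64, rng=0)
print(n.shape)
```

The coarse levels inject low-frequency structure into the noise, which is the property Marigold-style training exploits for faster convergence.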

## LoRA-Specific Arguments

- `--lora_base_model`: The base model to which LoRA is applied, typically `"dit"`.
- `--lora_target_modules`: Comma-separated list of modules to apply LoRA to.
- `--lora_rank`: Rank of the LoRA decomposition matrices (e.g., `64`).
- `--align_to_opensource_format`: (Flag) Save LoRA weights in a community-standard format.
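To show what `--lora_rank` controls, here is a minimal NumPy sketch of a LoRA-adapted linear layer (class and attribute names are illustrative, not from the codebase): the frozen weight `W` is augmented with a low-rank update `(alpha / rank) * B @ A`, where `B` is zero-initialized so the adapter starts as a no-op.

```python
import numpy as np

class LoRALinear:
    """Sketch: y = x W^T + (alpha / rank) * x A^T B^T, with B zero-initialized."""
    def __init__(self, in_features, out_features, rank=64, alpha=64, rng=0):
        g = np.random.default_rng(rng)
        self.weight = g.standard_normal((out_features, in_features)) * 0.02
        self.lora_A = g.standard_normal((rank, in_features)) * 0.01
        self.lora_B = np.zeros((out_features, rank))  # zero-init: no initial change
        self.scale = alpha / rank

    def __call__(self, x):
        base = x @ self.weight.T                       # frozen base projection
        delta = (x @ self.lora_A.T) @ self.lora_B.T    # low-rank update
        return base + self.scale * delta

layer = LoRALinear(8, 4, rank=2)
x = np.ones((1, 8))
y0 = layer(x)        # identical to the frozen base output at initialization
layer.lora_B += 0.1  # pretend a training step updated B
y1 = layer(x)
```

Only `lora_A` and `lora_B` would be trained, which is why a higher `--lora_rank` increases adapter capacity (and checkpoint size) without touching the base weights.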

## Performance & Logging

- `--use_gradient_checkpointing`: (Flag) Enable to save GPU memory at the cost of a small slowdown.
- `--adamw8bit`: (Flag) Use the 8-bit AdamW optimizer to reduce memory usage.
- `--save_steps`: Save a full model checkpoint every N steps.
- `--eval_steps`: Run a quick evaluation on a small subset every N steps to monitor progress.
- `--eval_file_list`: Path to a text file listing the images used for evaluation during training.
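The `--save_steps` / `--eval_steps` cadence can be sketched as a simple modulus check on the global step counter (a generic illustration, assuming the scripts use step-based scheduling as described above):

```python
def checkpoint_schedule(total_steps, save_steps, eval_steps):
    """Sketch: return the global steps that trigger a checkpoint save
    and those that trigger a quick evaluation."""
    saves, evals = [], []
    for step in range(1, total_steps + 1):
        if step % save_steps == 0:
            saves.append(step)   # full checkpoint every save_steps
        if step % eval_steps == 0:
            evals.append(step)   # quick eval every eval_steps
    return saves, evals

saves, evals = checkpoint_schedule(total_steps=1000, save_steps=500, eval_steps=250)
print(saves, evals)
```

Keeping `--eval_steps` smaller than `--save_steps`, as in this example, gives progress signals between the (more expensive) full checkpoints.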