This document provides a detailed explanation of the training arguments used in our scripts.
- `--model_paths`: Path to the base model directory (e.g., `./FLUX.1-Kontext-dev`).
- `--output_path`: Directory in which to save model checkpoints and logs.
- `--learning_rate`: Learning rate for the optimizer (e.g., `1e-5`).
- `--num_epochs`: Total number of training epochs.
- `--batch_size`: Per-GPU batch size; adjust based on your GPU memory.
- `--resume`: Resume training from the latest checkpoint in `--output_path`.
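As a rough illustration, the core arguments could be declared with `argparse` roughly as below. This is a hypothetical sketch, not the actual script: defaults, types, and required-ness may differ in the real code.

```python
import argparse

# Hypothetical declaration of the core arguments described above;
# the real training script may define them differently.
parser = argparse.ArgumentParser()
parser.add_argument("--model_paths", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--resume", action="store_true")

# Parse an example command line (paths here are placeholders).
args = parser.parse_args([
    "--model_paths", "./FLUX.1-Kontext-dev",
    "--output_path", "./checkpoints",
    "--learning_rate", "1e-5",
    "--num_epochs", "10",
    "--batch_size", "2",
])
print(args.learning_rate, args.num_epochs, args.resume)
```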
- `--dataset_base_path`: Comma-separated absolute paths to the training datasets. The order matters.
- `--dataset_metadata_path`: Path to the metadata file (CSV or Parquet) containing relative file paths.
- `--data_file_keys`: Column names in the metadata file to be used, e.g., `kontext_images,image`.
- `--dataset_repeat`: Number of times to repeat a dataset within an epoch; useful for balancing datasets of different sizes.
- `--height`, `--width`: Target resolution for training images.
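To make the metadata layout concrete, here is a toy CSV whose columns match the `--data_file_keys` example (`kontext_images,image`) and whose paths are resolved against one `--dataset_base_path` entry. The base path and file names are invented for illustration.

```python
import csv
import io
import os

# Toy metadata file: column names match the --data_file_keys example,
# and each cell is a path relative to --dataset_base_path.
metadata_csv = """kontext_images,image
cond/0001.png,gt/0001.png
cond/0002.png,gt/0002.png
"""

base_path = "/data/normal_estimation"  # hypothetical --dataset_base_path entry
rows = list(csv.DictReader(io.StringIO(metadata_csv)))

# Resolve relative paths to absolute paths, one dict per training sample.
samples = [
    {key: os.path.join(base_path, row[key]) for key in ("kontext_images", "image")}
    for row in rows
]
print(samples[0]["image"])  # /data/normal_estimation/gt/0001.png
```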
- `--trainable_models`: Specifies which parts of the model to train. For fine-tuning, set to `"dit"`.
- `--extra_inputs`: Additional input keys besides the main image and prompt; in our case, `"kontext_images"`.
- `--default_caption`: The default text prompt used for training (e.g., `"Transform to normal map..."`).
- `--multi_res_noise`: (Flag) Use multi-resolution noise for potentially faster convergence, inspired by Marigold.
- `--with_mask`: (Flag) Compute the loss only on valid masked areas (e.g., where ground-truth depth is available).
- `--using_sqrt`: (Flag, depth only) Use our theoretically optimal square-root normalization for depth.
- `--extra_loss`: Name of the pixel-space consistency loss to apply (e.g., `"cycle_consistency_normal_estimation"`).
- `--deterministic_flow`: (Flag) Use a fixed random seed for the initial noise to create a pseudo-deterministic path.
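Two of the behaviors above can be sketched in a few lines of plain Python (the real training code works on tensors): `--with_mask` restricts the loss to pixels with valid ground truth, and `--using_sqrt` applies a square-root normalization to depth values. The toy pixel values are made up.

```python
def masked_l2(pred, target, mask):
    """Mean squared error computed only where mask is 1."""
    num = sum(m * (p - t) ** 2 for p, t, m in zip(pred, target, mask))
    den = sum(mask)
    return num / den if den else 0.0

def sqrt_normalize(depth):
    """Scale depth to [0, 1], then take the square root."""
    lo, hi = min(depth), max(depth)
    return [((d - lo) / (hi - lo)) ** 0.5 for d in depth]

pred   = [0.2, 0.8, 0.5]
target = [0.0, 1.0, 0.5]
mask   = [1, 1, 0]  # third pixel has no ground truth, so it is ignored
print(masked_l2(pred, target, mask))
print(sqrt_normalize([0.0, 1.0, 4.0]))
```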
- `--lora_base_model`: The base model to which LoRA is applied, typically `"dit"`.
- `--lora_target_modules`: Comma-separated list of modules to apply LoRA to.
- `--lora_rank`: Rank of the LoRA decomposition matrices (e.g., `64`).
- `--align_to_opensource_format`: (Flag) Save LoRA weights in a community-standard format.
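A back-of-the-envelope calculation shows why `--lora_rank` controls the trainable parameter count: LoRA replaces an update to a frozen `d_out x d_in` weight with a low-rank product `B @ A`, where `B` is `d_out x r` and `A` is `r x d_in`. The hidden size below is a stand-in, not the actual DiT dimension.

```python
def lora_params(d_out, d_in, rank):
    """Trainable parameters added by one LoRA pair (B: d_out x r, A: r x d_in)."""
    return rank * (d_out + d_in)

d_out = d_in = 3072          # hypothetical hidden size for illustration
full = d_out * d_in          # parameters in the frozen full weight
lora = lora_params(d_out, d_in, rank=64)
print(full, lora, f"{lora / full:.1%}")  # LoRA trains only a few percent
```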
- `--use_gradient_checkpointing`: (Flag) Enable to save GPU memory at the cost of a small slowdown.
- `--adamw8bit`: (Flag) Use the 8-bit AdamW optimizer to reduce memory usage.
- `--save_steps`: Save a full model checkpoint every N steps.
- `--eval_steps`: Run a quick evaluation on a small subset every N steps to monitor progress.
- `--eval_file_list`: Path to a text file listing the images used for evaluation during training.
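The interaction of `--save_steps` and `--eval_steps` amounts to simple modulo checks inside the training loop; a minimal sketch is below, with interval values chosen only for illustration.

```python
# Hypothetical intervals; pass your own via --save_steps / --eval_steps.
save_steps, eval_steps = 500, 100

def actions_at(step):
    """Return which periodic actions fire at a given global step."""
    acts = []
    if step % eval_steps == 0:
        acts.append("eval")
    if step % save_steps == 0:
        acts.append("save")
    return acts

print(actions_at(100))  # ['eval']
print(actions_at(500))  # ['eval', 'save']
print(actions_at(123))  # []
```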