grad_norm=NaN During NFT Training on Flux1.d-dev #134

@rlustc

Description

grad_norm=NaN appears frequently during DiffusionNFT training, which makes my training completely unusable. The training YAML file is as follows:

# Environment Configuration
launcher: "accelerate"  # Options: accelerate
config_file: config/accelerate_configs/fsdp_full_shard.yaml  # Path to distributed config file (optional)
num_processes: 8  # Number of processes to launch (overrides config file)
main_process_port: 29500
mixed_precision: "bf16"  # Options: no, fp16, bf16

# Data Configuration
data:
  dataset_dir: "dataset/pickscore"  # Path to dataset folder
  preprocessing_batch_size: 8  # Batch size for preprocessing
  dataloader_num_workers: 16  # Number of workers for DataLoader
  force_reprocess: false  # Force reprocessing of the dataset
  cache_dir: "~/.cache/flow_factory/datasets" # Cache directory for preprocessed datasets
  max_dataset_size: 1024  # Limit the maximum number of samples in the dataset
  sampler_type: "auto"  # Options: auto, distributed_k_repeat, group_contiguous

# Model Configuration
model:
  finetune_type: 'full' # Options: full, lora
  target_modules: "default" # Options: all, default, or list of module names like ["to_k", "to_q", "to_v", "to_out.0"]
  model_name_or_path: "/data/aigc/liangyzh_intern/Lirui/Flux1.0-dev"  # HuggingFace model ID or local path
  model_type: "flux1"
  resume_path: null # Path to load previous checkpoint/lora adapter
  resume_type: null # Options: lora, full, state. Null to auto-detect based on `finetune_type`

log:
  run_name: null  # Run name (auto: {model_type}_{finetune_type}_{trainer_type}_{timestamp})
  project: "Flow-Factory"  # Project name for logging
  logging_backend: "tensorboard"  # Options: wandb, swanlab, none
  save_dir: "saves/"  # Directory to save model checkpoints and logs
  save_freq: 40  # Save frequency in epochs (0 to disable)
  save_model_only: true  # Save only the model weights (not optimizer, scheduler, etc.)

# Training Configuration
train:
  # Trainer settings
  trainer_type: 'nft'
  advantage_aggregation: 'sum' # Options: 'sum', 'gdpo'
  nft_beta: 1
  # `Old` Policy settings
  off_policy: true # Whether to use ema parameters for sampling off-policy data.
  ema_decay_schedule: "piecewise_linear"  # Decay schedule for EMA. Options: ['constant', 'power', 'linear', 'piecewise_linear', 'cosine', 'warmup_cosine']
  flat_steps: 0
  ramp_rate: 0.001
  ema_decay: 0.5  # EMA decay rate (0 to disable)
  ema_update_interval: 1  # EMA update interval (in epochs)
  ema_device: "cpu"  # Device to store EMA model (options: cpu, cuda)
  # Training Timestep distribution
  num_train_timesteps: 2 # Set null to all steps
  time_sampling_strategy: discrete # Options: uniform, logit_normal, discrete, discrete_with_init, discrete_wo_init
  time_shift: 3.0
  timestep_range: 0.7 # Select fraction of timesteps to train on
  # KL div
  kl_type: 'v-based'
  kl_beta: 0 # KL divergence beta, 0 to disable
  ref_param_device: 'cpu' # Options: cpu, cuda
  # Clipping
  adv_clip_range: 5.0  # Advantage clipping range

  # Sampling
  resolution: 384  # Can be int or [height, width]
  num_inference_steps: 8  # Number of timesteps
  guidance_scale: 3.5  # Guidance scale for sampling

  # Batch and sampling
  per_device_batch_size: 1  # Batch size per device
  group_size: 16  # Group size for GRPO sampling
  global_std: false  # Use global std for advantage normalization
  unique_sample_num_per_epoch: 48  # Unique samples per group
  gradient_step_per_epoch: 1  # Gradient steps per epoch. The first step is on-policy, the rest are off-policy.
  gradient_accumulation_steps: auto  # Options: auto, or positive integer. When set, `gradient_step_per_epoch` is ignored.
    
  # Optimization
  learning_rate: 1.0e-5  # Initial learning rate
  adam_weight_decay: 1.0e-4  # AdamW weight decay
  adam_betas: [0.9, 0.999]  # AdamW betas
  adam_epsilon: 1.0e-8  # AdamW epsilon
  max_grad_norm: 1.0  # Max gradient norm for clipping

  # Gradient checkpointing
  enable_gradient_checkpointing: true  # Enable gradient checkpointing to save memory with extra compute

  # Seed
  seed: 42  # Random seed

# Scheduler Configuration
scheduler:
  dynamics_type: "ODE"  # Options: Flow-SDE, Dance-SDE, CPS, ODE

# Evaluation settings
eval:
  resolution: 1024  # Evaluation resolution
  per_device_batch_size: 1  # Eval batch size
  guidance_scale: 3.5  # Guidance scale for sampling
  num_inference_steps: 28  # Number of eval timesteps
  eval_freq: 20  # Eval frequency in epochs (0 to disable)
  seed: 42  # Eval seed (defaults to training seed)

# Reward Model Configuration
rewards:
  - name: "hps"
    reward_model: "HPSv2"
    hps_ckpt_path: "/data/aigc/liangyzh_intern/zqni/DanceGRPO-main/HPSv2/ckpt_all/HPS_v2.1_compressed.pt"
    clip_pretrained_path: "/data/aigc/liangyzh_intern/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    hps_version: "v2.1"
    batch_size: 16
    dtype: bfloat16
    device: "cuda"
