From c37810952c1e07cfb044524da733daba5a7593b8 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Wed, 13 May 2026 08:43:00 +0000 Subject: [PATCH 01/17] qwen-image dmd-lora --- .../lightx2v_train/model_zoo/qwen_image.py | 14 +- .../lightx2v_train/trainers/__init__.py | 3 +- .../lightx2v_train/trainers/dmd_lora.py | 370 ++++++++++++++++++ 3 files changed, 384 insertions(+), 3 deletions(-) create mode 100644 lightx2v_train/lightx2v_train/trainers/dmd_lora.py diff --git a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py index e70b3962a..10566647f 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py @@ -24,9 +24,13 @@ def load_components(self): torch_dtype=self.running_dtype, ).to(self.device) self.vae = AutoencoderKLQwenImage.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype) - self.transformer = QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype) + self.transformer = self.load_transformer() self.vae.requires_grad_(False) + def load_transformer(self): + model_path = self.config["model"]["pretrained_model_name_or_path"] + return QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype) + def build_pipeline(self): pipe = QwenImagePipeline( scheduler=self.flow_matching, @@ -53,6 +57,9 @@ def encode_to_latent(self, sample): def encode_condition(self, sample): prompt = sample["prompt"] + return self.encode_prompt_condition(prompt) + + def encode_prompt_condition(self, prompt): prompt_embed, prompt_embed_mask = self.text_pipeline.encode_prompt( prompt=prompt, device=self.device, @@ -84,7 +91,10 @@ def prepare_denoiser_input(self, noisy_latent, sample, condition): ) def denoise(self, denoiser_input, timestep_or_sigma, condition): - return self.transformer( + return self.denoise_with_transformer(self.transformer, denoiser_input, timestep_or_sigma, condition) + + def denoise_with_transformer(self, transformer, denoiser_input, timestep_or_sigma, condition): + return transformer( hidden_states=denoiser_input.hidden_states, timestep=timestep_or_sigma, # timestep_or_sigma is in [0, 1] not [0, 1000] guidance=None, diff --git a/lightx2v_train/lightx2v_train/trainers/__init__.py b/lightx2v_train/lightx2v_train/trainers/__init__.py index ee9795f3d..a8c7f8fbc 100644 --- a/lightx2v_train/lightx2v_train/trainers/__init__.py +++ b/lightx2v_train/lightx2v_train/trainers/__init__.py @@ -1,5 +1,6 @@ from lightx2v_train.utils.registry import build_trainer +from .dmd_lora import DmdLoraTrainer from .lora import LoraTrainer -__all__ = ["build_trainer", "LoraTrainer"] +__all__ = ["build_trainer", "DmdLoraTrainer", "LoraTrainer"] diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py new file mode 100644 index 000000000..910b9eb7f --- /dev/null +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -0,0 +1,370 @@ +import os + +import torch +import torch.nn.functional as F +from diffusers.optimization import get_scheduler +from diffusers.utils import convert_state_dict_to_diffusers +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from tqdm.auto import tqdm + +from lightx2v_train.runtime.checkpoint import prune_checkpoints +from lightx2v_train.utils.registry import TRAINER_REGISTER +from lightx2v_train.utils.utils import get_running_dtype + +from .base import BaseTrainer + + +def _linear_shift(mu, t): + return mu / (mu + (1 / t - 1)) + + +def _add_noise(x0, noise, sigma): + sigma = _expand_to(sigma, x0).to(dtype=torch.float32) + return ((1.0 - sigma) * x0.float() + sigma * noise.float()).to(dtype=x0.dtype) + + +def _euler_step(x, velocity, sigma, target_sigma): + sigma = _expand_to(sigma, x).to(dtype=torch.float32) + target_sigma = _expand_to(target_sigma, x).to(dtype=torch.float32) + return x.float() + (target_sigma - sigma) * velocity.float() + + +def _expand_to(value, target): + value = value.to(device=target.device) + while value.ndim < target.ndim: + value = value.view(*value.shape, 1) + return value + + +def _do_cfg(cond_pred, uncond_pred, cfg_scale, cfg_norm): + pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred) + if cfg_norm in (None, "none"): + return pred + if cfg_norm == "layer_norm": + cond_norm = torch.norm(cond_pred, dim=-1, keepdim=True) + pred_norm = torch.norm(pred, dim=-1, keepdim=True) + return pred * (cond_norm / torch.clamp(pred_norm, min=1e-12)) + if cfg_norm == "scalar": + cond_norm = torch.norm(cond_pred) + pred_norm = torch.norm(pred) + return pred * min(1.0, (cond_norm / torch.clamp(pred_norm, min=1e-12)).item()) + raise ValueError(f"Unsupported cfg_norm: {cfg_norm}") + + +def _dmd_loss(latents, x_pred_fake_flow, x_pred_teacher): + with torch.no_grad(): + grad = x_pred_fake_flow - x_pred_teacher + dims = tuple(range(1, latents.ndim)) + normalizer = torch.abs(latents - x_pred_teacher).mean(dim=dims, keepdim=True) + grad = torch.nan_to_num(grad / normalizer) + return 0.5 * F.mse_loss(latents.float(), (latents.float() - grad.float()).detach(), reduction="mean") + + +class _DMDEulerScheduler: + def __init__(self, shift=3.0, device="cuda"): + self.shift = float(shift) + self.device = torch.device(device) + self.num_train_timesteps = 1000 + + def set_timesteps(self, num_inference_steps): + timesteps = torch.linspace( + 1000, + 0, + int(num_inference_steps) + 1, + dtype=torch.float32, + device=self.device, + ) + self.sigmas = _linear_shift(self.shift, timesteps / self.num_train_timesteps) + + def step(self, model_output, step_idx, sample): + sigma = self.sigmas[step_idx].expand(sample.shape[0]).to(sample.device) + sigma_next = self.sigmas[step_idx + 1].expand(sample.shape[0]).to(sample.device) + x0 = sample.float() - _expand_to(sigma, sample).float() * model_output.float() + next_sample = _euler_step(sample, model_output, sigma, sigma_next) + return next_sample.to(sample.dtype), x0.to(sample.dtype) + + +@TRAINER_REGISTER("dmd_lora") +class DmdLoraTrainer(BaseTrainer): + def get_configs(self): + model_config = self.config["model"] + if model_config.get("name") != "qwen_image": + raise ValueError("dmd_lora currently supports model.name: qwen_image only.") + self.running_dtype = get_running_dtype(model_config["running_dtype"]) + + training_config = self.config["training"] + lora_config = training_config.get("lora", {}) + self.lora_rank = lora_config.get("rank", 16) + self.lora_alpha = lora_config.get("alpha", self.lora_rank) + self.lora_target_modules = lora_config.get("target_modules") + + fake_config = training_config.get("fake", {}) + fake_lora_config = fake_config.get("lora", lora_config) + self.fake_lora_rank = fake_lora_config.get("rank", self.lora_rank) + self.fake_lora_alpha = fake_lora_config.get("alpha", self.fake_lora_rank) + self.fake_lora_target_modules = fake_lora_config.get("target_modules", self.lora_target_modules) + + self.gradient_checkpointing = training_config.get("gradient_checkpointing", True) + + optimizer_config = training_config.get("optimizer", {}) + self.optimizer_learning_rate = optimizer_config.get("learning_rate", 1e-4) + self.optimizer_adam_beta1 = optimizer_config.get("adam_beta1", 0.9) + self.optimizer_adam_beta2 = optimizer_config.get("adam_beta2", 0.999) + self.optimizer_weight_decay = optimizer_config.get("weight_decay", 0.01) + self.optimizer_adam_epsilon = optimizer_config.get("adam_epsilon", 1e-8) + + fake_optimizer_config = fake_config.get("optimizer", {}) + self.fake_optimizer_learning_rate = fake_optimizer_config.get("learning_rate", self.optimizer_learning_rate) + self.fake_optimizer_adam_beta1 = fake_optimizer_config.get("adam_beta1", self.optimizer_adam_beta1) + self.fake_optimizer_adam_beta2 = fake_optimizer_config.get("adam_beta2", self.optimizer_adam_beta2) + self.fake_optimizer_weight_decay = fake_optimizer_config.get("weight_decay", self.optimizer_weight_decay) + self.fake_optimizer_adam_epsilon = fake_optimizer_config.get("adam_epsilon", self.optimizer_adam_epsilon) + + self.lr_scheduler_name = training_config.get("lr_scheduler", "constant") + self.lr_warmup_iters = training_config.get("lr_warmup_iters", 0) + self.max_train_iters = training_config["max_train_iters"] + + self.output_dir = training_config["output_dir"] + self.gradient_accumulation_iters = training_config.get("gradient_accumulation_iters", 1) + self.max_grad_norm = training_config.get("max_grad_norm", 1.0) + self.save_every_iters = training_config.get("save_every_iters", 0) + self.save_total_limit = training_config.get("save_total_limit") + self.save_fake_lora = fake_config.get("save_lora", False) + + dmd_config = training_config.get("dmd", {}) + self.num_inference_steps = int(dmd_config.get("num_inference_steps", 4)) + self.fake_update_ratio = int(dmd_config.get("fake_update_ratio", 1)) + self.guidance_scale = float(dmd_config.get("guidance_scale", 3.0)) + self.negative_prompt = dmd_config.get("negative_prompt", " ") + self.cfg_norm = dmd_config.get("cfg_norm", "layer_norm") + self.min_sigma = float(dmd_config.get("sigma_min", 0.02)) + self.max_sigma = float(dmd_config.get("sigma_max", 1.0)) + self.discrete_samples = int(dmd_config.get("discrete_samples", 1000)) + self.renoise_shift = float(dmd_config.get("renoise_shift", 5.0)) + self.inference_shift = float(dmd_config.get("inference_shift", 3.0)) + + def setup(self): + self.get_configs() + print("[dmd_lora] single-GPU resident mode: student/fake/teacher transformers stay on CUDA") + + self.model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) + self.model.set_lora_trainable() + if self.gradient_checkpointing: + self.model.enable_gradient_checkpointing() + + self.fake_transformer = self.model.load_transformer() + self._add_lora_to_transformer( + self.fake_transformer, + self.fake_lora_rank, + self.fake_lora_alpha, + self.fake_lora_target_modules, + ) + self._set_lora_trainable(self.fake_transformer) + if self.gradient_checkpointing and hasattr(self.fake_transformer, "enable_gradient_checkpointing"): + self.fake_transformer.enable_gradient_checkpointing() + + self.teacher_transformer = self.model.load_transformer() + self.teacher_transformer.requires_grad_(False) + self.teacher_transformer.eval() + + self.optimizer = torch.optim.AdamW( + self.model.trainable_parameters(), + lr=self.optimizer_learning_rate, + betas=(self.optimizer_adam_beta1, self.optimizer_adam_beta2), + weight_decay=self.optimizer_weight_decay, + eps=self.optimizer_adam_epsilon, + ) + self.fake_optimizer = torch.optim.AdamW( + (p for p in self.fake_transformer.parameters() if p.requires_grad), + lr=self.fake_optimizer_learning_rate, + betas=(self.fake_optimizer_adam_beta1, self.fake_optimizer_adam_beta2), + weight_decay=self.fake_optimizer_weight_decay, + eps=self.fake_optimizer_adam_epsilon, + ) + self.lr_scheduler = get_scheduler( + self.lr_scheduler_name, + optimizer=self.optimizer, + num_warmup_steps=self.lr_warmup_iters, + num_training_steps=self.max_train_iters, + ) + self.fake_lr_scheduler = get_scheduler( + self.lr_scheduler_name, + optimizer=self.fake_optimizer, + num_warmup_steps=0, + num_training_steps=max(1, self.max_train_iters * self.fake_update_ratio), + ) + self.scheduler = _DMDEulerScheduler(shift=self.inference_shift, device=self.model.device) + + print(f"[dmd_lora] student trainable params={self._count_trainable(self.model.transformer)}") + print(f"[dmd_lora] fake trainable params={self._count_trainable(self.fake_transformer)}") + + @staticmethod + def _add_lora_to_transformer(transformer, rank, alpha, target_modules): + transformer.add_adapter( + LoraConfig( + r=rank, + lora_alpha=alpha, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + ) + + @staticmethod + def _set_lora_trainable(transformer): + transformer.requires_grad_(False) + transformer.train() + for name, param in transformer.named_parameters(): + param.requires_grad = "lora" in name + + @staticmethod + def _count_trainable(module): + return sum(1 for param in module.parameters() if param.requires_grad) + + def _latent_shape(self, sample): + image = sample["target_image"] + batch_size = image.shape[0] + latent_channels = getattr(self.model.vae.config, "z_dim", None) + if latent_channels is None: + latent_channels = self.model.transformer.config.in_channels // 4 + return ( + batch_size, + 1, + int(latent_channels), + image.shape[-2] // self.model.vae_scale_factor, + image.shape[-1] // self.model.vae_scale_factor, + ) + + def _encode_conditions(self, sample): + prompt = sample["prompt"] + if isinstance(prompt, str): + negative_prompt = self.negative_prompt + else: + negative_prompt = [self.negative_prompt] * len(prompt) + with torch.no_grad(): + condition = self.model.encode_prompt_condition(prompt) + negative_condition = self.model.encode_prompt_condition(negative_prompt) + return condition, negative_condition + + def _predict_velocity(self, transformer, latents, sigma, condition): + denoiser_input = self.model.prepare_denoiser_input(latents, {}, condition) + prediction = self.model.denoise_with_transformer(transformer, denoiser_input, sigma, condition) + prediction = self.model.postprocess_denoiser_output(prediction, denoiser_input) + return self.model.prepare_flow_matching_target(prediction) + + def sample_initial_latents(self, latent_shape): + return torch.randn(latent_shape, device=self.model.device, dtype=self.running_dtype) + + def sample_end_step(self): + return int(torch.randint(0, self.num_inference_steps, (1,), device=self.model.device).item()) + + def sample_renoise_sigma(self, batch_size): + raw = torch.rand((batch_size,), device=self.model.device, dtype=torch.float32) + if self.discrete_samples > 0: + raw = torch.ceil(raw * self.discrete_samples) / self.discrete_samples + raw = torch.clamp(raw, 1e-7, 1 - 1e-7) + return torch.clamp(_linear_shift(self.renoise_shift, raw), self.min_sigma, self.max_sigma).to(self.running_dtype) + + def run_back_simulation(self, condition, latent_shape, end_step_idx, grad_enabled, xt=None): + self.scheduler.set_timesteps(self.num_inference_steps) + if xt is None: + xt = self.sample_initial_latents(latent_shape) + x0 = None + self.model.transformer.train() + for idx in range(end_step_idx + 1): + sigma = self.scheduler.sigmas[idx].expand(latent_shape[0]).to(self.model.device, self.running_dtype) + context = torch.enable_grad if (grad_enabled and idx == end_step_idx) else torch.no_grad + with context(): + velocity = self._predict_velocity(self.model.transformer, xt, sigma, condition) + xt, x0 = self.scheduler.step(velocity, idx, xt) + return x0 + + def forward_loss(self, sample, stage): + condition, negative_condition = self._encode_conditions(sample) + latent_shape = self._latent_shape(sample) + end_step_idx = self.sample_end_step() + xt_start = self.sample_initial_latents(latent_shape) + x0_ref = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=False, xt=xt_start) + + sigma = self.sample_renoise_sigma(latent_shape[0]) + noise = torch.randn(latent_shape, device=self.model.device, dtype=torch.float32) + renoised_xt = _add_noise(x0_ref, noise, sigma) + velocity_gt = noise - x0_ref.float() + + if stage == "fake": + self.fake_transformer.train() + velocity_fake = self._predict_velocity(self.fake_transformer, renoised_xt, sigma, condition) + return F.mse_loss(velocity_fake.float(), velocity_gt.float(), reduction="mean") + + with torch.no_grad(): + self.fake_transformer.eval() + velocity_fake = self._predict_velocity(self.fake_transformer, renoised_xt, sigma, condition) + velocity_teacher_cond = self._predict_velocity(self.teacher_transformer, renoised_xt, sigma, condition) + velocity_teacher_uncond = self._predict_velocity(self.teacher_transformer, renoised_xt, sigma, negative_condition) + velocity_teacher = _do_cfg(velocity_teacher_cond, velocity_teacher_uncond, self.guidance_scale, self.cfg_norm) + + zeros = torch.zeros_like(sigma) + x_pred_fake = _euler_step(renoised_xt, velocity_fake, sigma, zeros) + x_pred_teacher = _euler_step(renoised_xt, velocity_teacher, sigma, zeros) + x0 = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=True, xt=xt_start) + return _dmd_loss(x0, x_pred_fake, x_pred_teacher) + + def train(self): + self.setup() + os.makedirs(self.output_dir, exist_ok=True) + + current_iter = 0 + running_dmd = 0.0 + running_fake = 0.0 + progress = tqdm(total=self.max_train_iters, desc="DMD-LoRA iterations") + + while current_iter < self.max_train_iters: + for sample in self.dataloader: + loss_dmd = self.forward_loss(sample, stage="generator") + loss_dmd.backward() + torch.nn.utils.clip_grad_norm_(self.model.transformer.parameters(), self.max_grad_norm) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad(set_to_none=True) + running_dmd += loss_dmd.detach().float().item() + + fake_losses = [] + for _ in range(self.fake_update_ratio): + loss_fake = self.forward_loss(sample, stage="fake") + loss_fake.backward() + torch.nn.utils.clip_grad_norm_(self.fake_transformer.parameters(), self.max_grad_norm) + self.fake_optimizer.step() + self.fake_lr_scheduler.step() + self.fake_optimizer.zero_grad(set_to_none=True) + fake_losses.append(loss_fake.detach()) + if fake_losses: + running_fake += torch.stack(fake_losses).mean().float().item() + + current_iter += 1 + progress.update(1) + progress.set_postfix( + dmd=running_dmd, + fake=running_fake, + lr=self.lr_scheduler.get_last_lr()[0], + ) + running_dmd = 0.0 + running_fake = 0.0 + + if self.save_every_iters and current_iter % self.save_every_iters == 0: + self.save_checkpoint(current_iter, self.save_total_limit) + + if current_iter >= self.max_train_iters: + break + + progress.close() + + def save_checkpoint(self, iteration, save_total_limit): + prune_checkpoints(self.output_dir, save_total_limit) + save_dir = os.path.join(self.output_dir, f"checkpoint-{iteration}") + os.makedirs(save_dir, exist_ok=True) + self.model.save_lora_weights(save_dir) + if self.save_fake_lora: + fake_dir = os.path.join(save_dir, "fake") + os.makedirs(fake_dir, exist_ok=True) + fake_state = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.fake_transformer)) + self.model.pipeline_cls.save_lora_weights(fake_dir, fake_state, safe_serialization=True) From 05a3e6dde338cb04b0071fdeb43102b7561146f3 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Wed, 13 May 2026 09:20:26 +0000 Subject: [PATCH 02/17] refactor --- .../lightx2v_train/schedulers/__init__.py | 4 + .../lightx2v_train/schedulers/dmd.py | 71 +++++++++ .../lightx2v_train/trainers/dmd_lora.py | 139 ++++++------------ 3 files changed, 116 insertions(+), 98 deletions(-) create mode 100644 lightx2v_train/lightx2v_train/schedulers/dmd.py diff --git a/lightx2v_train/lightx2v_train/schedulers/__init__.py b/lightx2v_train/lightx2v_train/schedulers/__init__.py index e69de29bb..5e3574ddc 100644 --- a/lightx2v_train/lightx2v_train/schedulers/__init__.py +++ b/lightx2v_train/lightx2v_train/schedulers/__init__.py @@ -0,0 +1,4 @@ +from .dmd import DMDFlowMatchingScheduler +from .flow_matching import RectifiedFlowMatchingScheduler + +__all__ = ["DMDFlowMatchingScheduler", "RectifiedFlowMatchingScheduler"] diff --git a/lightx2v_train/lightx2v_train/schedulers/dmd.py b/lightx2v_train/lightx2v_train/schedulers/dmd.py new file mode 100644 index 000000000..c28440856 --- /dev/null +++ b/lightx2v_train/lightx2v_train/schedulers/dmd.py @@ -0,0 +1,71 @@ +import torch + +from .flow_matching import RectifiedFlowMatchingScheduler + + +class DMDFlowMatchingScheduler(RectifiedFlowMatchingScheduler): + def __init__(self, config, dmd_config=None): + super().__init__(config) + dmd_config = dmd_config or {} + self.inference_shift = float(dmd_config.get("inference_shift", 3.0)) + self.renoise_shift = float(dmd_config.get("renoise_shift", 5.0)) + self.min_sigma = float(dmd_config.get("sigma_min", 0.02)) + self.max_sigma = float(dmd_config.get("sigma_max", 1.0)) + self.discrete_samples = int(dmd_config.get("discrete_samples", 1000)) + + @staticmethod + def expand_to(value, target): + value = value.to(device=target.device) + while value.ndim < target.ndim: + value = value.view(*value.shape, 1) + return value + + @staticmethod + def linear_shift(mu, t): + return mu / (mu + (1 / t - 1)) + + def set_timesteps(self, num_inference_steps, device=None): + self.num_inference_steps = int(num_inference_steps) + device = device or self.device + timesteps = torch.linspace( + self.num_train_timesteps, + 0, + self.num_inference_steps + 1, + dtype=torch.float32, + device=device, + ) + self.sigmas = self.linear_shift(self.inference_shift, timesteps / self.num_train_timesteps) + self.timesteps = self.sigmas * self.num_train_timesteps + + def sigma_at(self, step_idx, batch_size, device=None, dtype=None): + sigma = self.sigmas[int(step_idx)].expand(int(batch_size)) + if device is not None or dtype is not None: + sigma = sigma.to(device=device, dtype=dtype) + return sigma + + def sample_renoise_sigma(self, batch_size, device=None, dtype=None): + device = device or self.device + raw = torch.rand((int(batch_size),), device=device, dtype=torch.float32) + if self.discrete_samples > 0: + raw = torch.ceil(raw * self.discrete_samples) / self.discrete_samples + raw = torch.clamp(raw, 1e-7, 1 - 1e-7) + sigma = torch.clamp(self.linear_shift(self.renoise_shift, raw), self.min_sigma, self.max_sigma) + if dtype is not None: + sigma = sigma.to(dtype=dtype) + return sigma + + def add_noise(self, latent, noise, sigmas): + sigmas = self.expand_to(sigmas, latent).to(dtype=torch.float32) + return ((1.0 - sigmas) * latent.float() + sigmas * noise.float()).to(dtype=latent.dtype) + + def euler_step(self, sample, velocity, sigma, target_sigma): + sigma = self.expand_to(sigma, sample).to(dtype=torch.float32) + target_sigma = self.expand_to(target_sigma, sample).to(dtype=torch.float32) + return sample.float() + (target_sigma - sigma) * velocity.float() + + def step_by_index(self, model_output, step_idx, sample): + sigma = self.sigma_at(step_idx, sample.shape[0], device=sample.device) + sigma_next = self.sigma_at(int(step_idx) + 1, sample.shape[0], device=sample.device) + x0 = sample.float() - self.expand_to(sigma, sample).float() * model_output.float() + next_sample = self.euler_step(sample, model_output, sigma, sigma_next) + return next_sample.to(sample.dtype), x0.to(sample.dtype) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 910b9eb7f..989a1e9fd 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -9,82 +9,13 @@ from tqdm.auto import tqdm from lightx2v_train.runtime.checkpoint import prune_checkpoints +from lightx2v_train.schedulers import DMDFlowMatchingScheduler from lightx2v_train.utils.registry import TRAINER_REGISTER from lightx2v_train.utils.utils import get_running_dtype from .base import BaseTrainer -def _linear_shift(mu, t): - return mu / (mu + (1 / t - 1)) - - -def _add_noise(x0, noise, sigma): - sigma = _expand_to(sigma, x0).to(dtype=torch.float32) - return ((1.0 - sigma) * x0.float() + sigma * noise.float()).to(dtype=x0.dtype) - - -def _euler_step(x, velocity, sigma, target_sigma): - sigma = _expand_to(sigma, x).to(dtype=torch.float32) - target_sigma = _expand_to(target_sigma, x).to(dtype=torch.float32) - return x.float() + (target_sigma - sigma) * velocity.float() - - -def _expand_to(value, target): - value = value.to(device=target.device) - while value.ndim < target.ndim: - value = value.view(*value.shape, 1) - return value - - -def _do_cfg(cond_pred, uncond_pred, cfg_scale, cfg_norm): - pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred) - if cfg_norm in (None, "none"): - return pred - if cfg_norm == "layer_norm": - cond_norm = torch.norm(cond_pred, dim=-1, keepdim=True) - pred_norm = torch.norm(pred, dim=-1, keepdim=True) - return pred * (cond_norm / torch.clamp(pred_norm, min=1e-12)) - if cfg_norm == "scalar": - cond_norm = torch.norm(cond_pred) - pred_norm = torch.norm(pred) - return pred * min(1.0, (cond_norm / torch.clamp(pred_norm, min=1e-12)).item()) - raise ValueError(f"Unsupported cfg_norm: {cfg_norm}") - - -def _dmd_loss(latents, x_pred_fake_flow, x_pred_teacher): - with torch.no_grad(): - grad = x_pred_fake_flow - x_pred_teacher - dims = tuple(range(1, latents.ndim)) - normalizer = torch.abs(latents - x_pred_teacher).mean(dim=dims, keepdim=True) - grad = torch.nan_to_num(grad / normalizer) - return 0.5 * F.mse_loss(latents.float(), (latents.float() - grad.float()).detach(), reduction="mean") - - -class _DMDEulerScheduler: - def __init__(self, shift=3.0, device="cuda"): - self.shift = float(shift) - self.device = torch.device(device) - self.num_train_timesteps = 1000 - - def set_timesteps(self, num_inference_steps): - timesteps = torch.linspace( - 1000, - 0, - int(num_inference_steps) + 1, - dtype=torch.float32, - device=self.device, - ) - self.sigmas = _linear_shift(self.shift, timesteps / self.num_train_timesteps) - - def step(self, model_output, step_idx, sample): - sigma = self.sigmas[step_idx].expand(sample.shape[0]).to(sample.device) - sigma_next = self.sigmas[step_idx + 1].expand(sample.shape[0]).to(sample.device) - x0 = sample.float() - _expand_to(sigma, sample).float() * model_output.float() - next_sample = _euler_step(sample, model_output, sigma, sigma_next) - return next_sample.to(sample.dtype), x0.to(sample.dtype) - - @TRAINER_REGISTER("dmd_lora") class DmdLoraTrainer(BaseTrainer): def get_configs(self): @@ -132,17 +63,12 @@ def get_configs(self): self.save_total_limit = training_config.get("save_total_limit") self.save_fake_lora = fake_config.get("save_lora", False) - dmd_config = training_config.get("dmd", {}) - self.num_inference_steps = int(dmd_config.get("num_inference_steps", 4)) - self.fake_update_ratio = int(dmd_config.get("fake_update_ratio", 1)) - self.guidance_scale = float(dmd_config.get("guidance_scale", 3.0)) - self.negative_prompt = dmd_config.get("negative_prompt", " ") - self.cfg_norm = dmd_config.get("cfg_norm", "layer_norm") - self.min_sigma = float(dmd_config.get("sigma_min", 0.02)) - self.max_sigma = float(dmd_config.get("sigma_max", 1.0)) - self.discrete_samples = int(dmd_config.get("discrete_samples", 1000)) - self.renoise_shift = float(dmd_config.get("renoise_shift", 5.0)) - self.inference_shift = float(dmd_config.get("inference_shift", 3.0)) + self.dmd_config = training_config.get("dmd", {}) + self.num_inference_steps = int(self.dmd_config.get("num_inference_steps", 4)) + self.fake_update_ratio = int(self.dmd_config.get("fake_update_ratio", 1)) + self.guidance_scale = float(self.dmd_config.get("guidance_scale", 3.0)) + self.negative_prompt = self.dmd_config.get("negative_prompt", " ") + self.cfg_norm = self.dmd_config.get("cfg_norm", "layer_norm") def setup(self): self.get_configs() @@ -194,7 +120,7 @@ def setup(self): num_warmup_steps=0, num_training_steps=max(1, self.max_train_iters * self.fake_update_ratio), ) - self.scheduler = _DMDEulerScheduler(shift=self.inference_shift, device=self.model.device) + self.scheduler = DMDFlowMatchingScheduler(self.config, self.dmd_config) print(f"[dmd_lora] student trainable params={self._count_trainable(self.model.transformer)}") print(f"[dmd_lora] fake trainable params={self._count_trainable(self.fake_transformer)}") @@ -221,6 +147,30 @@ def _set_lora_trainable(transformer): def _count_trainable(module): return sum(1 for param in module.parameters() if param.requires_grad) + @staticmethod + def _do_cfg(cond_pred, uncond_pred, cfg_scale, cfg_norm): + pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred) + if cfg_norm in (None, "none"): + return pred + if cfg_norm == "layer_norm": + cond_norm = torch.norm(cond_pred, dim=-1, keepdim=True) + pred_norm = torch.norm(pred, dim=-1, keepdim=True) + return pred * (cond_norm / torch.clamp(pred_norm, min=1e-12)) + if cfg_norm == "scalar": + cond_norm = torch.norm(cond_pred) + pred_norm = torch.norm(pred) + return pred * min(1.0, (cond_norm / torch.clamp(pred_norm, min=1e-12)).item()) + raise ValueError(f"Unsupported cfg_norm: {cfg_norm}") + + @staticmethod + def _dmd_loss(latents, x_pred_fake_flow, x_pred_teacher): + with torch.no_grad(): + grad = x_pred_fake_flow - x_pred_teacher + dims = tuple(range(1, latents.ndim)) + normalizer = torch.abs(latents - x_pred_teacher).mean(dim=dims, keepdim=True) + grad = torch.nan_to_num(grad / normalizer) + return 0.5 * F.mse_loss(latents.float(), (latents.float() - grad.float()).detach(), reduction="mean") + def _latent_shape(self, sample): image = sample["target_image"] batch_size = image.shape[0] @@ -258,13 +208,6 @@ def sample_initial_latents(self, latent_shape): def sample_end_step(self): return int(torch.randint(0, self.num_inference_steps, (1,), device=self.model.device).item()) - def sample_renoise_sigma(self, batch_size): - raw = torch.rand((batch_size,), device=self.model.device, dtype=torch.float32) - if self.discrete_samples > 0: - raw = torch.ceil(raw * self.discrete_samples) / self.discrete_samples - raw = torch.clamp(raw, 1e-7, 1 - 1e-7) - return torch.clamp(_linear_shift(self.renoise_shift, raw), self.min_sigma, self.max_sigma).to(self.running_dtype) - def run_back_simulation(self, condition, latent_shape, end_step_idx, grad_enabled, xt=None): self.scheduler.set_timesteps(self.num_inference_steps) if xt is None: @@ -272,11 +215,11 @@ def run_back_simulation(self, condition, latent_shape, end_step_idx, grad_enable x0 = None self.model.transformer.train() for idx in range(end_step_idx + 1): - sigma = self.scheduler.sigmas[idx].expand(latent_shape[0]).to(self.model.device, self.running_dtype) + sigma = self.scheduler.sigma_at(idx, latent_shape[0], device=self.model.device, dtype=self.running_dtype) context = torch.enable_grad if (grad_enabled and idx == end_step_idx) else torch.no_grad with context(): velocity = self._predict_velocity(self.model.transformer, xt, sigma, condition) - xt, x0 = self.scheduler.step(velocity, idx, xt) + xt, x0 = self.scheduler.step_by_index(velocity, idx, xt) return x0 def forward_loss(self, sample, stage): @@ -286,10 +229,10 @@ def forward_loss(self, sample, stage): xt_start = self.sample_initial_latents(latent_shape) x0_ref = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=False, xt=xt_start) - sigma = self.sample_renoise_sigma(latent_shape[0]) + sigma = self.scheduler.sample_renoise_sigma(latent_shape[0], device=self.model.device, dtype=self.running_dtype) noise = torch.randn(latent_shape, device=self.model.device, dtype=torch.float32) - renoised_xt = _add_noise(x0_ref, noise, sigma) - velocity_gt = noise - x0_ref.float() + renoised_xt = self.scheduler.add_noise(x0_ref, noise, sigma) + velocity_gt = self.scheduler.build_train_gt(x0_ref.float(), noise) if stage == "fake": self.fake_transformer.train() @@ -301,13 +244,13 @@ def forward_loss(self, sample, stage): velocity_fake = self._predict_velocity(self.fake_transformer, renoised_xt, sigma, condition) velocity_teacher_cond = self._predict_velocity(self.teacher_transformer, renoised_xt, sigma, condition) velocity_teacher_uncond = self._predict_velocity(self.teacher_transformer, renoised_xt, sigma, negative_condition) - velocity_teacher = _do_cfg(velocity_teacher_cond, velocity_teacher_uncond, self.guidance_scale, self.cfg_norm) + velocity_teacher = self._do_cfg(velocity_teacher_cond, velocity_teacher_uncond, self.guidance_scale, self.cfg_norm) zeros = torch.zeros_like(sigma) - x_pred_fake = _euler_step(renoised_xt, velocity_fake, sigma, zeros) - x_pred_teacher = _euler_step(renoised_xt, velocity_teacher, sigma, zeros) + x_pred_fake = self.scheduler.euler_step(renoised_xt, velocity_fake, sigma, zeros) + x_pred_teacher = self.scheduler.euler_step(renoised_xt, velocity_teacher, sigma, zeros) x0 = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=True, xt=xt_start) - return _dmd_loss(x0, x_pred_fake, x_pred_teacher) + return self._dmd_loss(x0, x_pred_fake, x_pred_teacher) def train(self): self.setup() From ef053f4f6f2895d6bece6e55c2495ae7f7c23aeb Mon Sep 17 00:00:00 2001 From: Musisoul Date: Wed, 13 May 2026 09:23:43 +0000 Subject: [PATCH 03/17] refactor --- lightx2v_train/lightx2v_train/schedulers/__init__.py | 2 +- .../lightx2v_train/schedulers/{dmd.py => dmd_scheduler.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename lightx2v_train/lightx2v_train/schedulers/{dmd.py => dmd_scheduler.py} (100%) diff --git a/lightx2v_train/lightx2v_train/schedulers/__init__.py b/lightx2v_train/lightx2v_train/schedulers/__init__.py index 5e3574ddc..e33e2c0a2 100644 --- a/lightx2v_train/lightx2v_train/schedulers/__init__.py +++ b/lightx2v_train/lightx2v_train/schedulers/__init__.py @@ -1,4 +1,4 @@ -from .dmd import DMDFlowMatchingScheduler +from .dmd_scheduler import DMDFlowMatchingScheduler from .flow_matching import RectifiedFlowMatchingScheduler __all__ = ["DMDFlowMatchingScheduler", "RectifiedFlowMatchingScheduler"] diff --git a/lightx2v_train/lightx2v_train/schedulers/dmd.py b/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py similarity index 100% rename from lightx2v_train/lightx2v_train/schedulers/dmd.py rename to lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py From 70ee08326845188525c57956881b030aed1c9589 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Wed, 13 May 2026 09:49:43 +0000 Subject: [PATCH 04/17] refactor --- .../schedulers/dmd_scheduler.py | 3 +-- .../lightx2v_train/trainers/dmd_lora.py | 21 ++++--------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py b/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py index c28440856..fdc0feb53 100644 --- a/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py +++ b/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py @@ -4,9 +4,8 @@ class DMDFlowMatchingScheduler(RectifiedFlowMatchingScheduler): - def __init__(self, config, dmd_config=None): + def __init__(self, config, dmd_config={}): super().__init__(config) - dmd_config = dmd_config or {} self.inference_shift = float(dmd_config.get("inference_shift", 3.0)) self.renoise_shift = float(dmd_config.get("renoise_shift", 5.0)) self.min_sigma = float(dmd_config.get("sigma_min", 0.02)) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 989a1e9fd..1d0aef6b6 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -20,8 +20,6 @@ class DmdLoraTrainer(BaseTrainer): def get_configs(self): model_config = self.config["model"] - if model_config.get("name") != "qwen_image": - raise ValueError("dmd_lora currently supports model.name: qwen_image only.") self.running_dtype = get_running_dtype(model_config["running_dtype"]) training_config = self.config["training"] @@ -30,12 +28,6 @@ def get_configs(self): self.lora_alpha = lora_config.get("alpha", self.lora_rank) self.lora_target_modules = lora_config.get("target_modules") - fake_config = training_config.get("fake", {}) - fake_lora_config = fake_config.get("lora", lora_config) - self.fake_lora_rank = fake_lora_config.get("rank", self.lora_rank) - self.fake_lora_alpha = fake_lora_config.get("alpha", self.fake_lora_rank) - self.fake_lora_target_modules = fake_lora_config.get("target_modules", self.lora_target_modules) - self.gradient_checkpointing = training_config.get("gradient_checkpointing", True) optimizer_config = training_config.get("optimizer", {}) @@ -45,6 +37,7 @@ def get_configs(self): self.optimizer_weight_decay = optimizer_config.get("weight_decay", 0.01) self.optimizer_adam_epsilon = optimizer_config.get("adam_epsilon", 1e-8) + fake_config = training_config.get("fake", {}) fake_optimizer_config = fake_config.get("optimizer", {}) self.fake_optimizer_learning_rate = fake_optimizer_config.get("learning_rate", self.optimizer_learning_rate) self.fake_optimizer_adam_beta1 = fake_optimizer_config.get("adam_beta1", self.optimizer_adam_beta1) @@ -61,7 +54,6 @@ def get_configs(self): self.max_grad_norm = training_config.get("max_grad_norm", 1.0) self.save_every_iters = training_config.get("save_every_iters", 0) self.save_total_limit = training_config.get("save_total_limit") - self.save_fake_lora = fake_config.get("save_lora", False) self.dmd_config = training_config.get("dmd", {}) self.num_inference_steps = int(self.dmd_config.get("num_inference_steps", 4)) @@ -82,9 +74,9 @@ def setup(self): self.fake_transformer = self.model.load_transformer() self._add_lora_to_transformer( self.fake_transformer, - self.fake_lora_rank, - self.fake_lora_alpha, - self.fake_lora_target_modules, + self.lora_rank, + self.lora_alpha, + self.lora_target_modules, ) self._set_lora_trainable(self.fake_transformer) if self.gradient_checkpointing and hasattr(self.fake_transformer, "enable_gradient_checkpointing"): @@ -306,8 +298,3 @@ def save_checkpoint(self, iteration, save_total_limit): save_dir = os.path.join(self.output_dir, f"checkpoint-{iteration}") os.makedirs(save_dir, exist_ok=True) self.model.save_lora_weights(save_dir) - if self.save_fake_lora: - fake_dir = os.path.join(save_dir, "fake") - os.makedirs(fake_dir, exist_ok=True) - fake_state = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.fake_transformer)) - self.model.pipeline_cls.save_lora_weights(fake_dir, fake_state, safe_serialization=True) From 7983b574b97ff78426e2c44042849bbea8366df0 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Wed, 13 May 2026 11:02:42 +0000 Subject: [PATCH 05/17] update dmd_lora --- .../lightx2v_train/model_zoo/qwen_image.py | 9 +- .../lightx2v_train/trainers/dmd_lora.py | 181 ++++++------------ .../lightx2v_train/trainers/lora.py | 40 +++- 3 files changed, 95 insertions(+), 135 deletions(-) diff --git a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py index 10566647f..ec0934fc1 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py @@ -15,8 +15,15 @@ class QwenImageModel(BaseModel): pipeline_cls = QwenImagePipeline - def load_components(self): + def load_components(self, transformer_only=False, reference_model=None): model_path = self.config["model"]["pretrained_model_name_or_path"] + if transformer_only: + if reference_model is not None: + self.text_pipeline = reference_model.text_pipeline + self.vae = reference_model.vae + self.transformer = self.load_transformer() + return + self.text_pipeline = QwenImagePipeline.from_pretrained( model_path, transformer=None, diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 1d0aef6b6..33a809119 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -2,41 +2,20 @@ import torch import torch.nn.functional as F -from diffusers.optimization import get_scheduler -from diffusers.utils import convert_state_dict_to_diffusers -from peft import LoraConfig -from peft.utils import get_peft_model_state_dict from tqdm.auto import tqdm -from lightx2v_train.runtime.checkpoint import prune_checkpoints from lightx2v_train.schedulers import DMDFlowMatchingScheduler from lightx2v_train.utils.registry import TRAINER_REGISTER -from lightx2v_train.utils.utils import get_running_dtype -from .base import BaseTrainer +from .lora import LoraTrainer @TRAINER_REGISTER("dmd_lora") -class DmdLoraTrainer(BaseTrainer): +class DmdLoraTrainer(LoraTrainer): def get_configs(self): - model_config = self.config["model"] - self.running_dtype = get_running_dtype(model_config["running_dtype"]) + super().get_configs() training_config = self.config["training"] - lora_config = training_config.get("lora", {}) - self.lora_rank = lora_config.get("rank", 16) - self.lora_alpha = lora_config.get("alpha", self.lora_rank) - self.lora_target_modules = lora_config.get("target_modules") - - self.gradient_checkpointing = training_config.get("gradient_checkpointing", True) - - optimizer_config = training_config.get("optimizer", {}) - self.optimizer_learning_rate = optimizer_config.get("learning_rate", 1e-4) - self.optimizer_adam_beta1 = optimizer_config.get("adam_beta1", 0.9) - self.optimizer_adam_beta2 = optimizer_config.get("adam_beta2", 0.999) - self.optimizer_weight_decay = optimizer_config.get("weight_decay", 0.01) - self.optimizer_adam_epsilon = optimizer_config.get("adam_epsilon", 1e-8) - fake_config = training_config.get("fake", {}) fake_optimizer_config = fake_config.get("optimizer", {}) self.fake_optimizer_learning_rate = fake_optimizer_config.get("learning_rate", self.optimizer_learning_rate) @@ -45,16 +24,6 @@ def get_configs(self): self.fake_optimizer_weight_decay = fake_optimizer_config.get("weight_decay", self.optimizer_weight_decay) self.fake_optimizer_adam_epsilon = fake_optimizer_config.get("adam_epsilon", self.optimizer_adam_epsilon) - self.lr_scheduler_name = training_config.get("lr_scheduler", "constant") - self.lr_warmup_iters = training_config.get("lr_warmup_iters", 0) - self.max_train_iters = training_config["max_train_iters"] - - self.output_dir = training_config["output_dir"] - self.gradient_accumulation_iters = training_config.get("gradient_accumulation_iters", 1) - self.max_grad_norm = training_config.get("max_grad_norm", 1.0) - self.save_every_iters = training_config.get("save_every_iters", 0) - self.save_total_limit = training_config.get("save_total_limit") - self.dmd_config = training_config.get("dmd", {}) self.num_inference_steps = int(self.dmd_config.get("num_inference_steps", 4)) self.fake_update_ratio = int(self.dmd_config.get("fake_update_ratio", 1)) @@ -66,74 +35,39 @@ def setup(self): self.get_configs() print("[dmd_lora] single-GPU resident mode: student/fake/teacher transformers stay on CUDA") - self.model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) - self.model.set_lora_trainable() + self.setup_lora() + + self.fake_model = self.model.__class__(self.config) + self.fake_model.load_components(transformer_only=True, reference_model=self.model) + self.fake_model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) + self.fake_model.set_lora_trainable() if self.gradient_checkpointing: - self.model.enable_gradient_checkpointing() - - self.fake_transformer = self.model.load_transformer() - self._add_lora_to_transformer( - self.fake_transformer, - self.lora_rank, - self.lora_alpha, - self.lora_target_modules, - ) - self._set_lora_trainable(self.fake_transformer) - if self.gradient_checkpointing and hasattr(self.fake_transformer, "enable_gradient_checkpointing"): - self.fake_transformer.enable_gradient_checkpointing() - - self.teacher_transformer = self.model.load_transformer() - self.teacher_transformer.requires_grad_(False) - self.teacher_transformer.eval() - - self.optimizer = torch.optim.AdamW( - self.model.trainable_parameters(), - lr=self.optimizer_learning_rate, - betas=(self.optimizer_adam_beta1, self.optimizer_adam_beta2), - weight_decay=self.optimizer_weight_decay, - eps=self.optimizer_adam_epsilon, - ) - self.fake_optimizer = torch.optim.AdamW( - (p for p in self.fake_transformer.parameters() if p.requires_grad), - lr=self.fake_optimizer_learning_rate, - betas=(self.fake_optimizer_adam_beta1, self.fake_optimizer_adam_beta2), + self.fake_model.enable_gradient_checkpointing() + + self.teacher_model = self.model.__class__(self.config) + self.teacher_model.load_components(transformer_only=True, reference_model=self.model) + self.teacher_model.transformer.requires_grad_(False) + self.teacher_model.transformer.eval() + + self.optimizer = self.build_optimizer(self.model.trainable_parameters()) + self.fake_optimizer = self.build_optimizer( + self.fake_model.trainable_parameters(), + learning_rate=self.fake_optimizer_learning_rate, + adam_beta1=self.fake_optimizer_adam_beta1, + adam_beta2=self.fake_optimizer_adam_beta2, weight_decay=self.fake_optimizer_weight_decay, - eps=self.fake_optimizer_adam_epsilon, + adam_epsilon=self.fake_optimizer_adam_epsilon, ) - self.lr_scheduler = get_scheduler( - self.lr_scheduler_name, - optimizer=self.optimizer, - num_warmup_steps=self.lr_warmup_iters, - num_training_steps=self.max_train_iters, - ) - self.fake_lr_scheduler = get_scheduler( - self.lr_scheduler_name, - optimizer=self.fake_optimizer, + self.lr_scheduler = self.build_lr_scheduler(self.optimizer) + self.fake_lr_scheduler = self.build_lr_scheduler( + self.fake_optimizer, num_warmup_steps=0, num_training_steps=max(1, self.max_train_iters * self.fake_update_ratio), ) self.scheduler = DMDFlowMatchingScheduler(self.config, self.dmd_config) print(f"[dmd_lora] student trainable params={self._count_trainable(self.model.transformer)}") - print(f"[dmd_lora] fake trainable params={self._count_trainable(self.fake_transformer)}") - - @staticmethod - def _add_lora_to_transformer(transformer, rank, alpha, target_modules): - transformer.add_adapter( - LoraConfig( - r=rank, - lora_alpha=alpha, - init_lora_weights="gaussian", - target_modules=target_modules, - ) - ) - - @staticmethod - def _set_lora_trainable(transformer): - transformer.requires_grad_(False) - transformer.train() - for name, param in transformer.named_parameters(): - param.requires_grad = "lora" in name + print(f"[dmd_lora] fake trainable params={self._count_trainable(self.fake_model.transformer)}") @staticmethod def _count_trainable(module): @@ -188,11 +122,11 @@ def _encode_conditions(self, sample): negative_condition = self.model.encode_prompt_condition(negative_prompt) return condition, negative_condition - def _predict_velocity(self, transformer, latents, sigma, condition): - denoiser_input = self.model.prepare_denoiser_input(latents, {}, condition) - prediction = self.model.denoise_with_transformer(transformer, denoiser_input, sigma, condition) - prediction = self.model.postprocess_denoiser_output(prediction, denoiser_input) - return self.model.prepare_flow_matching_target(prediction) + def _predict_velocity(self, model, latents, sigma, condition): + denoiser_input = model.prepare_denoiser_input(latents, {}, condition) + prediction = model.denoise(denoiser_input, sigma, condition) + prediction = model.postprocess_denoiser_output(prediction, denoiser_input) + return model.prepare_flow_matching_target(prediction) def sample_initial_latents(self, latent_shape): return torch.randn(latent_shape, device=self.model.device, dtype=self.running_dtype) @@ -210,7 +144,7 @@ def run_back_simulation(self, condition, latent_shape, end_step_idx, grad_enable sigma = self.scheduler.sigma_at(idx, latent_shape[0], device=self.model.device, dtype=self.running_dtype) context = torch.enable_grad if (grad_enabled and idx == end_step_idx) else torch.no_grad with context(): - velocity = self._predict_velocity(self.model.transformer, xt, sigma, condition) + velocity = self._predict_velocity(self.model, xt, sigma, condition) xt, x0 = self.scheduler.step_by_index(velocity, idx, xt) return x0 @@ -227,15 +161,15 @@ def forward_loss(self, sample, stage): velocity_gt = self.scheduler.build_train_gt(x0_ref.float(), noise) if stage == "fake": - self.fake_transformer.train() - velocity_fake = self._predict_velocity(self.fake_transformer, renoised_xt, sigma, condition) + self.fake_model.transformer.train() + velocity_fake = self._predict_velocity(self.fake_model, renoised_xt, sigma, condition) return F.mse_loss(velocity_fake.float(), velocity_gt.float(), reduction="mean") with torch.no_grad(): - self.fake_transformer.eval() - velocity_fake = self._predict_velocity(self.fake_transformer, renoised_xt, sigma, condition) - velocity_teacher_cond = self._predict_velocity(self.teacher_transformer, renoised_xt, sigma, condition) - velocity_teacher_uncond = self._predict_velocity(self.teacher_transformer, renoised_xt, sigma, negative_condition) + self.fake_model.transformer.eval() + velocity_fake = self._predict_velocity(self.fake_model, renoised_xt, sigma, condition) + velocity_teacher_cond = self._predict_velocity(self.teacher_model, renoised_xt, sigma, condition) + velocity_teacher_uncond = self._predict_velocity(self.teacher_model, renoised_xt, sigma, negative_condition) velocity_teacher = self._do_cfg(velocity_teacher_cond, velocity_teacher_uncond, self.guidance_scale, self.cfg_norm) zeros = torch.zeros_like(sigma) @@ -248,32 +182,37 @@ def train(self): self.setup() os.makedirs(self.output_dir, exist_ok=True) + max_train_iters = self.max_train_iters + fake_update_ratio = self.fake_update_ratio + max_grad_norm = self.max_grad_norm + save_every_iters = self.save_every_iters + save_total_limit = self.save_total_limit current_iter = 0 running_dmd = 0.0 running_fake = 0.0 - progress = tqdm(total=self.max_train_iters, desc="DMD-LoRA iterations") - while current_iter < self.max_train_iters: + progress = tqdm(total=max_train_iters, desc="DMD-LoRA iterations") + + while current_iter < max_train_iters: for sample in self.dataloader: loss_dmd = self.forward_loss(sample, stage="generator") loss_dmd.backward() - torch.nn.utils.clip_grad_norm_(self.model.transformer.parameters(), self.max_grad_norm) + torch.nn.utils.clip_grad_norm_(self.model.transformer.parameters(), max_grad_norm) self.optimizer.step() self.lr_scheduler.step() self.optimizer.zero_grad(set_to_none=True) - running_dmd += loss_dmd.detach().float().item() + running_dmd += loss_dmd.item() - fake_losses = [] - for _ in range(self.fake_update_ratio): + fake_loss = 0.0 + for _ in range(fake_update_ratio): loss_fake = self.forward_loss(sample, stage="fake") loss_fake.backward() - torch.nn.utils.clip_grad_norm_(self.fake_transformer.parameters(), self.max_grad_norm) + torch.nn.utils.clip_grad_norm_(self.fake_model.transformer.parameters(), max_grad_norm) self.fake_optimizer.step() self.fake_lr_scheduler.step() self.fake_optimizer.zero_grad(set_to_none=True) - fake_losses.append(loss_fake.detach()) - if fake_losses: - running_fake += torch.stack(fake_losses).mean().float().item() + fake_loss += loss_fake.item() + running_fake += fake_loss / fake_update_ratio current_iter += 1 progress.update(1) @@ -285,16 +224,10 @@ def train(self): running_dmd = 0.0 running_fake = 0.0 - if self.save_every_iters and current_iter % self.save_every_iters == 0: - self.save_checkpoint(current_iter, self.save_total_limit) + if save_every_iters and current_iter % save_every_iters == 0: + self.save_checkpoint(current_iter, save_total_limit) - if current_iter >= self.max_train_iters: + if current_iter >= max_train_iters: break progress.close() - - def save_checkpoint(self, iteration, save_total_limit): - prune_checkpoints(self.output_dir, save_total_limit) - save_dir = os.path.join(self.output_dir, f"checkpoint-{iteration}") - os.makedirs(save_dir, exist_ok=True) - self.model.save_lora_weights(save_dir) diff --git a/lightx2v_train/lightx2v_train/trainers/lora.py b/lightx2v_train/lightx2v_train/trainers/lora.py index 7fd6d2f9e..34fb50337 100644 --- a/lightx2v_train/lightx2v_train/trainers/lora.py +++ b/lightx2v_train/lightx2v_train/trainers/lora.py @@ -45,23 +45,43 @@ def get_configs(self): def setup(self): self.get_configs() + self.setup_lora() + + self.optimizer = self.build_optimizer(self.model.trainable_parameters()) + self.lr_scheduler = self.build_lr_scheduler(self.optimizer) + + def setup_lora(self): self.model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) self.model.set_lora_trainable() if self.gradient_checkpointing: self.model.enable_gradient_checkpointing() - self.optimizer = torch.optim.AdamW( - self.model.trainable_parameters(), - lr=self.optimizer_learning_rate, - betas=(self.optimizer_adam_beta1, self.optimizer_adam_beta2), - weight_decay=self.optimizer_weight_decay, - eps=self.optimizer_adam_epsilon, + def build_optimizer( + self, + parameters, + learning_rate=None, + adam_beta1=None, + adam_beta2=None, + weight_decay=None, + adam_epsilon=None, + ): + return torch.optim.AdamW( + parameters, + lr=self.optimizer_learning_rate if learning_rate is None else learning_rate, + betas=( + self.optimizer_adam_beta1 if adam_beta1 is None else adam_beta1, + self.optimizer_adam_beta2 if adam_beta2 is None else adam_beta2, + ), + weight_decay=self.optimizer_weight_decay if weight_decay is None else weight_decay, + eps=self.optimizer_adam_epsilon if adam_epsilon is None else adam_epsilon, ) - self.lr_scheduler = get_scheduler( + + def build_lr_scheduler(self, optimizer, num_warmup_steps=None, num_training_steps=None): + return get_scheduler( self.lr_scheduler_name, - optimizer=self.optimizer, - num_warmup_steps=self.lr_warmup_iters, - num_training_steps=self.max_train_iters, + optimizer=optimizer, + num_warmup_steps=self.lr_warmup_iters if num_warmup_steps is None else num_warmup_steps, + num_training_steps=self.max_train_iters if num_training_steps is None else num_training_steps, ) def compute_loss_on_sample(self, sample): From 39882695cbf8e8ddfcf368c2c5ae756807eeec6e Mon Sep 17 00:00:00 2001 From: Musisoul Date: Thu, 14 May 2026 12:22:31 +0000 Subject: [PATCH 06/17] update build_model --- lightx2v_train/lightx2v_train/trainers/dmd_lora.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 33a809119..ab07a2d1d 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from tqdm.auto import tqdm +from lightx2v_train.model_zoo import build_model from lightx2v_train.schedulers import DMDFlowMatchingScheduler from lightx2v_train.utils.registry import TRAINER_REGISTER @@ -37,14 +38,14 @@ def setup(self): self.setup_lora() - self.fake_model = self.model.__class__(self.config) + self.fake_model = build_model(self.config) self.fake_model.load_components(transformer_only=True, reference_model=self.model) self.fake_model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) self.fake_model.set_lora_trainable() if self.gradient_checkpointing: self.fake_model.enable_gradient_checkpointing() - self.teacher_model = self.model.__class__(self.config) + self.teacher_model = build_model(self.config) self.teacher_model.load_components(transformer_only=True, reference_model=self.model) self.teacher_model.transformer.requires_grad_(False) self.teacher_model.transformer.eval() From 05208c9ec13f2f7465a589b21fc3f5d4791b48de Mon Sep 17 00:00:00 2001 From: Musisoul Date: Fri, 15 May 2026 09:11:05 +0000 Subject: [PATCH 07/17] merge main --- .gitignore | 3 +- .../configs/lora/longcat_image_lora.yaml | 68 ----------- lightx2v_train/infer.py | 33 ++++++ .../lightx2v_train/infer/__init__.py | 6 + lightx2v_train/lightx2v_train/infer/base.py | 45 ++++++++ lightx2v_train/lightx2v_train/infer/image.py | 69 +++++++++++ .../lightx2v_train/infer/image_native.py | 50 ++++++++ .../lightx2v_train/model_zoo/base.py | 62 ++++++---- .../lightx2v_train/model_zoo/longcat_image.py | 93 +++++++++------ .../lightx2v_train/model_zoo/qwen_image.py | 107 +++++++++++------- .../schedulers/flow_matching.py | 52 +++++---- .../lightx2v_train/trainers/base.py | 8 +- .../lightx2v_train/trainers/dmd_lora.py | 10 +- .../lightx2v_train/trainers/lora.py | 65 +++++++---- .../lightx2v_train/utils/registry.py | 9 ++ 15 files changed, 463 insertions(+), 217 deletions(-) delete mode 100644 lightx2v_train/configs/lora/longcat_image_lora.yaml create mode 100644 lightx2v_train/infer.py create mode 100644 lightx2v_train/lightx2v_train/infer/__init__.py create mode 100644 lightx2v_train/lightx2v_train/infer/base.py create mode 100644 lightx2v_train/lightx2v_train/infer/image.py create mode 100644 lightx2v_train/lightx2v_train/infer/image_native.py diff --git a/.gitignore b/.gitignore index 1f959e2a9..7004def94 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,6 @@ .log *.pid *.ipynb* -*.mp4 build/ dist/ .cache/ @@ -31,3 +30,5 @@ app/.gradio/ save_results/* *.egg-info/ lightx2v_train/train_output/* +lightx2v_train/output_train/* +lightx2v_train/output_infer/* diff --git a/lightx2v_train/configs/lora/longcat_image_lora.yaml b/lightx2v_train/configs/lora/longcat_image_lora.yaml deleted file mode 100644 index 32bbc93ed..000000000 --- a/lightx2v_train/configs/lora/longcat_image_lora.yaml +++ /dev/null @@ -1,68 +0,0 @@ -model: - name: longcat_image - pretrained_model_name_or_path: meituan-longcat/LongCat-Image - enable_prompt_rewrite_training: false - running_dtype: bf16 - -data: - train: - name: image_dataset - num_workers: 8 - prompt_dropout_rate: 0.1 - image_size: 1024 - random_ratio: false - shuffle: true - data_path: - - /path/to/image_dataset/train.jsonl - val: - name: image_dataset - num_workers: 8 - image_size: 1024 - shuffle: false - data_path: - - /path/to/image_dataset/val.jsonl - -scheduler: - training: - num_train_timesteps: 1000 - timestep_distribution: logitnormal - logitnormal_mean: 0.0 - logitnormal_std: 1.0 - min_t: 0.001 - max_t: 1.0 - do_time_shift: true - time_shift_mu: 5.0 - time_shift_power: 1.0 - -training: - method: lora - max_train_iters: 3000 - gradient_accumulation_iters: 1 - gradient_checkpointing: true - max_grad_norm: 1.0 - lr_scheduler: constant - lr_warmup_iters: 10 - save_every_iters: 250 - save_total_limit: 10 - lora: - rank: 16 - alpha: 16 - target_modules: - - to_k - - to_q - - to_v - - to_out.0 - optimizer: - learning_rate: 0.0001 - adam_beta1: 0.9 - adam_beta2: 0.999 - weight_decay: 0.01 - adam_epsilon: 0.00000001 - output_dir: ./train_output/longcat_image_lora - -inference: - width: 1024 - height: 1024 - num_inference_steps: 50 - cfg_guidance_scale: 4.0 - negative_prompt: " " diff --git a/lightx2v_train/infer.py b/lightx2v_train/infer.py new file mode 100644 index 000000000..2bc54a0b3 --- /dev/null +++ b/lightx2v_train/infer.py @@ -0,0 +1,33 @@ +import argparse + +from lightx2v_train.data import build_data +from lightx2v_train.model_zoo import build_model +from lightx2v_train.runtime import load_config + +from lightx2v_train.infer import build_inferencer + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run inference with a trained LightX2V model.") + parser.add_argument("--config", required=True, help="Path to a YAML config file.") + return parser.parse_args() + + +def main(): + args = parse_args() + config = load_config(args.config) + + model = build_model(config) + model.load_components() + + dataloader_val = build_data(config, train_or_val="val") + + inferencer = build_inferencer(config) + inferencer.set_model(model) + inferencer.set_data(dataloader_val) + + inferencer.infer() + + +if __name__ == "__main__": + main() diff --git a/lightx2v_train/lightx2v_train/infer/__init__.py b/lightx2v_train/lightx2v_train/infer/__init__.py new file mode 100644 index 000000000..e2e1d191c --- /dev/null +++ b/lightx2v_train/lightx2v_train/infer/__init__.py @@ -0,0 +1,6 @@ +from lightx2v_train.utils.registry import build_inferencer + +from .image import ImageInferencer +from .image_native import NativePipelineInferencer + +__all__ = ["build_inferencer", "ImageInferencer", "NativePipelineInferencer"] diff --git a/lightx2v_train/lightx2v_train/infer/base.py b/lightx2v_train/lightx2v_train/infer/base.py new file mode 100644 index 000000000..bbd072a06 --- /dev/null +++ b/lightx2v_train/lightx2v_train/infer/base.py @@ -0,0 +1,45 @@ +import os + +import torch + +from lightx2v_train.schedulers.flow_matching import RectifiedFlowMatchingScheduler + + +class BaseInferencer: + def __init__(self, config): + self.config = config + self.infer_config = config.get("inference", {}) + self.output_infer_dir = self.infer_config.get("output_dir", None) + if self.output_infer_dir is not None: + os.makedirs(self.output_infer_dir, exist_ok=True) + + self.model = None + self.dataloader_eval = None + self.enable_cfg = True + self.guidance_scale = None + + self.scheduler = RectifiedFlowMatchingScheduler(config) + + def set_data(self, dataloader_val): + self.dataloader_eval = dataloader_val + + def set_model(self, model): + self.model = model + + def cfg_guided_denoise(self, latents, timestep_or_sigma, pos_cond, neg_cond): + denoiser_input = self.model.prepare_denoiser_input(latents) + + pred_pos = self.model.denoise(denoiser_input, timestep_or_sigma, pos_cond) + pred_pos = self.model.postprocess_denoiser_output(pred_pos, denoiser_input) + + if self.enable_cfg: + pred_neg = self.model.denoise(denoiser_input, timestep_or_sigma, neg_cond) + pred_neg = self.model.postprocess_denoiser_output(pred_neg, denoiser_input) + pred = pred_neg + self.guidance_scale * (pred_pos - pred_neg) + else: + pred = pred_pos + return self.model.postprocess_infer_step_output(pred) + + @torch.no_grad() + def infer(self): + raise NotImplementedError diff --git a/lightx2v_train/lightx2v_train/infer/image.py b/lightx2v_train/lightx2v_train/infer/image.py new file mode 100644 index 000000000..48fa85f06 --- /dev/null +++ b/lightx2v_train/lightx2v_train/infer/image.py @@ -0,0 +1,69 @@ +from pathlib import Path + +import torch +from tqdm.auto import tqdm + +from lightx2v_train.utils.registry import INFERENCER_REGISTER + +from .base import BaseInferencer + + +@INFERENCER_REGISTER("image") +class ImageInferencer(BaseInferencer): + @torch.no_grad() + def infer(self): + prompts = [sample["prompt"] for sample in self.dataloader_eval.dataset.samples] + + height = self.infer_config.get("height", 1024) + width = self.infer_config.get("width", 1024) + num_inference_steps = self.infer_config.get("num_inference_steps", 50) + + base_seed = self.infer_config.get("seed", 42) + # self.lora_path = self.infer_config.get("lora_path", None) + + # if self.lora_path: + # self.model.load_lora_for_infer(self.lora_path) + + self.scheduler.set_timesteps(num_inference_steps) + + self.enable_cfg = self.infer_config.get("enable_cfg", True) + if self.enable_cfg: + self.guidance_scale = self.infer_config.get("cfg_guidance_scale", 4.0) + negative_prompt = self.infer_config.get("negative_prompt", " ") + neg_cond = self.model.encode_condition({"prompt": negative_prompt}) + else: + self.guidance_scale = None + neg_cond = None + + saved_paths = [] + self.model.transformer.eval() + with torch.no_grad(): + for i, prompt in enumerate(prompts): + generator = torch.Generator(device=self.model.device).manual_seed(base_seed + i) + pos_cond = self.model.encode_condition({"prompt": prompt}) + latent = self.model.prepare_infer_latents(height, width, generator) + + for step_idx, current_timestep in enumerate(tqdm(self.scheduler.timesteps, desc=f"[{i + 1}/{len(prompts)}] Denoising")): + # current_timestep is in [0, 1000] + sigma = self.scheduler.sigmas[step_idx].unsqueeze(0) # shape (1,) required by diffusers + # sigma is in [0, 1] + model_output = self.cfg_guided_denoise( + latents=latent, + timestep_or_sigma=sigma, + pos_cond=pos_cond, + neg_cond=neg_cond, + ) + latent = self.scheduler.step(model_output, current_timestep, latent) + + images = self.model.decode_latent(latent) + + if self.output_infer_dir is not None: + save_path = Path(self.output_infer_dir) / f"{i:05d}.png" + images[0].save(save_path) + print(f"Saved to {save_path}") + saved_paths.append(str(save_path)) + + # if self.lora_path: + # self.model.unload_lora_for_infer() + + return saved_paths diff --git a/lightx2v_train/lightx2v_train/infer/image_native.py b/lightx2v_train/lightx2v_train/infer/image_native.py new file mode 100644 index 000000000..782a8cad0 --- /dev/null +++ b/lightx2v_train/lightx2v_train/infer/image_native.py @@ -0,0 +1,50 @@ +from pathlib import Path + +import torch + +from lightx2v_train.utils.registry import INFERENCER_REGISTER + +from .base import BaseInferencer + + +@INFERENCER_REGISTER("native_pipeline") +class NativePipelineInferencer(BaseInferencer): + @torch.no_grad() + def infer(self): + prompts = [sample["prompt"] for sample in self.dataloader_eval.dataset.samples] + enable_cfg = self.infer_config.get("enable_cfg", False) + negative_prompt = self.infer_config.get("negative_prompt", " ") if enable_cfg else None + base_seed = self.infer_config.get("seed", 42) + + # Model-specific kwargs (e.g. QwenImage uses `true_cfg_scale` instead of `guidance_scale`) + pipeline_kwargs = self.model.get_pipeline_infer_kwargs(self.infer_config) + + # Use the pipeline's original pretrained scheduler for bit-exact alignment with diffusers + pipe = self.model.assemble_pipeline() + + # self.lora_path = self.infer_config.get("lora_path", None) + # if self.lora_path: + # pipe.load_lora_weights(self.lora_path) + + saved_paths = [] + self.model.transformer.eval() + with torch.no_grad(): + for i, prompt in enumerate(prompts): + generator = torch.Generator(device=self.model.device).manual_seed(base_seed + i) + result = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + generator=generator, + **pipeline_kwargs, + ) + + if self.output_infer_dir is not None: + save_path = Path(self.output_infer_dir) / f"{i:05d}.png" + result.images[0].save(save_path) + print(f"Saved to {save_path}") + saved_paths.append(str(save_path)) + + # if self.lora_path: + # self.model.unload_lora_for_infer() + + return saved_paths diff --git a/lightx2v_train/lightx2v_train/model_zoo/base.py b/lightx2v_train/lightx2v_train/model_zoo/base.py index 16401750b..4213616a3 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/base.py +++ b/lightx2v_train/lightx2v_train/model_zoo/base.py @@ -7,36 +7,17 @@ from lightx2v_train.utils.utils import get_running_dtype -class DenoiserInput: - def __init__(self, hidden_states, extra): - self.hidden_states = hidden_states - self.extra = extra - - class BaseModel: def __init__(self, config): self.config = config self.running_dtype = get_running_dtype(config["model"]["running_dtype"]) self.device = torch.device("cuda") - self.pipeline = None - self.flow_matching = None self.transformer = None self.vae = None def load_components(self): raise NotImplementedError - def build_pipeline(self): - raise NotImplementedError - - def generate(self, **kwargs): - lora_path = kwargs.pop("lora_path", None) - pipe = self.build_pipeline() - if lora_path is not None: - pipe.load_lora_weights(lora_path) - pipe.to(self.device) - return pipe(**kwargs) - def add_lora(self, rank, alpha, target_modules): lora_config = LoraConfig( r=rank, @@ -67,13 +48,20 @@ def prepare_flow_matching_target(self, velocity): """Layout/format alignment between flow-matching velocity and denoiser output. Override when needed.""" return velocity + def postprocess_infer_step_output(self, pred): + """Convert denoiser prediction to the latent format expected by scheduler.step(). + + Override when postprocess_denoiser_output returns a different layout than encode_to_latent. + """ + return pred + def encode_to_latent(self, sample): raise NotImplementedError def encode_condition(self, sample): raise NotImplementedError - def prepare_denoiser_input(self, noisy_latent, sample, condition): + def prepare_denoiser_input(self, noisy_latent): raise NotImplementedError def denoise(self, denoiser_input, timesteps, condition): @@ -82,6 +70,40 @@ def denoise(self, denoiser_input, timesteps, condition): def postprocess_denoiser_output(self, prediction, denoiser_input): raise NotImplementedError + def prepare_infer_latents(self, height, width, generator=None): + raise NotImplementedError + + def decode_latent(self, latent): + """Decode a latent tensor into a list of PIL images.""" + raise NotImplementedError + + def assemble_pipeline(self, scheduler=None): + """Assemble a full diffusers pipeline from loaded components for pipeline-based inference. + + Args: + scheduler: The scheduler to inject into the pipeline. If None, the pipeline's + original pretrained scheduler is used. Pass the framework's + RectifiedFlowMatchingScheduler for training-inference alignment. + """ + raise NotImplementedError + + def get_pipeline_infer_kwargs(self, infer_config): + """Return kwargs to pass to pipeline.__call__. Override to adapt model-specific parameter names.""" + return { + "height": infer_config.get("height", 1024), + "width": infer_config.get("width", 1024), + "num_inference_steps": infer_config.get("num_inference_steps", 50), + "guidance_scale": infer_config.get("cfg_guidance_scale", 4.0), + } + + def load_lora_for_infer(self, lora_path): + pipe = self.assemble_pipeline() + pipe.load_lora_weights(lora_path) + + def unload_lora_for_infer(self): + pipe = self.assemble_pipeline() + pipe.unload_lora_weights() + def save_lora_weights(self, save_dir): lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.transformer)) if hasattr(self.pipeline_cls, "save_lora_weights"): diff --git a/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py b/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py index 7752fd9b0..5fdb76bc2 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py @@ -1,10 +1,23 @@ +from dataclasses import dataclass + +import numpy as np +import torch +from PIL import Image from diffusers import AutoencoderKL, LongCatImagePipeline from diffusers.models.transformers import LongCatImageTransformer2DModel from diffusers.pipelines.longcat_image.pipeline_longcat_image import prepare_pos_ids from lightx2v_train.utils.registry import MODEL_REGISTER -from .base import BaseModel, DenoiserInput +from .base import BaseModel + + +@dataclass +class LongCatImageDenoiserInput: + hidden_states: torch.Tensor + img_ids: torch.Tensor + height: int + width: int @MODEL_REGISTER("longcat_image") @@ -23,17 +36,6 @@ def load_components(self): self.transformer = LongCatImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype) self.vae.requires_grad_(False) - def build_pipeline(self): - pipe = LongCatImagePipeline( - scheduler=self.flow_matching, - vae=self.vae, - text_encoder=self.text_pipeline.text_encoder, - tokenizer=self.text_pipeline.tokenizer, - text_processor=self.text_pipeline.text_processor, - transformer=self.transformer, - ) - return pipe - @property def vae_scale_factor(self): return 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -56,32 +58,22 @@ def encode_condition(self, sample): ) return {"prompt_embed": prompt_embed, "text_ids": text_ids} - def prepare_denoiser_input(self, noisy_latent, sample, condition): + def prepare_denoiser_input(self, noisy_latent): n = noisy_latent.shape[0] - packed = LongCatImagePipeline._pack_latents( - noisy_latent, - n, - noisy_latent.shape[1], - noisy_latent.shape[2], - noisy_latent.shape[3], - ) - latent_image_ids = prepare_pos_ids( + h, w = noisy_latent.shape[2], noisy_latent.shape[3] + packed = LongCatImagePipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w) + img_ids = prepare_pos_ids( modality_id=1, type="image", - start=( - self.text_pipeline.tokenizer_max_length, - self.text_pipeline.tokenizer_max_length, - ), - height=noisy_latent.shape[2] // 2, - width=noisy_latent.shape[3] // 2, + start=(self.text_pipeline.tokenizer_max_length, self.text_pipeline.tokenizer_max_length), + height=h // 2, + width=w // 2, ).to(self.device) - return DenoiserInput( + return LongCatImageDenoiserInput( hidden_states=packed, - extra={ - "img_ids": latent_image_ids, - "height": noisy_latent.shape[2], - "width": noisy_latent.shape[3], - }, + img_ids=img_ids, + height=h, + width=w, ) def denoise(self, denoiser_input, timestep_or_sigma, condition): @@ -91,14 +83,43 @@ def denoise(self, denoiser_input, timestep_or_sigma, condition): guidance=None, encoder_hidden_states=condition["prompt_embed"], txt_ids=condition["text_ids"], - img_ids=denoiser_input.extra["img_ids"], + img_ids=denoiser_input.img_ids, return_dict=False, )[0] def postprocess_denoiser_output(self, prediction, denoiser_input): return LongCatImagePipeline._unpack_latents( prediction, - height=denoiser_input.extra["height"] * self.vae_scale_factor, - width=denoiser_input.extra["width"] * self.vae_scale_factor, + height=denoiser_input.height * self.vae_scale_factor, + width=denoiser_input.width * self.vae_scale_factor, vae_scale_factor=self.vae_scale_factor, ) + + def prepare_infer_latents(self, height, width, generator=None): + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + # latent shape: (batch=1, latent_channels, latent_h, latent_w) + shape = (1, self.vae.config.latent_channels, latent_h, latent_w) + return torch.randn(shape, generator=generator, device=self.device, dtype=self.running_dtype) + + def decode_latent(self, latent): + # Reverse the normalization from encode_to_latent: + # encode: normalized = (raw - shift) * scale + # decode: raw = normalized / scale + shift + shift = getattr(self.vae.config, "shift_factor", 0.0) + scale = getattr(self.vae.config, "scaling_factor", 1.0) + latent = latent / scale + shift + + image = self.vae.decode(latent).sample # (B, C, H, W) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.permute(0, 2, 3, 1).float().cpu().numpy() + return [Image.fromarray((img * 255).round().astype(np.uint8)) for img in image] + + def assemble_pipeline(self, scheduler=None): + return LongCatImagePipeline( + tokenizer=self.text_pipeline.tokenizer, + text_encoder=self.text_pipeline.text_encoder, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler or self.text_pipeline.scheduler, + ).to(self.device) diff --git a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py index ec0934fc1..08b11a47e 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py @@ -1,9 +1,20 @@ +from dataclasses import dataclass + import torch from diffusers import AutoencoderKLQwenImage, QwenImagePipeline, QwenImageTransformer2DModel +from diffusers.image_processor import VaeImageProcessor from lightx2v_train.utils.registry import MODEL_REGISTER -from .base import BaseModel, DenoiserInput +from .base import BaseModel + + +@dataclass +class QwenImageDenoiserInput: + hidden_states: torch.Tensor + img_shapes: list + height: int + width: int @MODEL_REGISTER("qwen_image") @@ -21,6 +32,8 @@ def load_components(self, transformer_only=False, reference_model=None): if reference_model is not None: self.text_pipeline = reference_model.text_pipeline self.vae = reference_model.vae + self.vae_scale_factor = reference_model.vae_scale_factor + self.image_processor = reference_model.image_processor self.transformer = self.load_transformer() return @@ -32,34 +45,23 @@ def load_components(self, transformer_only=False, reference_model=None): ).to(self.device) self.vae = AutoencoderKLQwenImage.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype) self.transformer = self.load_transformer() + + self.text_pipeline.text_encoder.requires_grad_(False) self.vae.requires_grad_(False) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) def load_transformer(self): model_path = self.config["model"]["pretrained_model_name_or_path"] return QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype) - def build_pipeline(self): - pipe = QwenImagePipeline( - scheduler=self.flow_matching, - vae=self.vae, - text_encoder=self.text_pipeline.text_encoder, - tokenizer=self.text_pipeline.tokenizer, - transformer=self.transformer, - ) - return pipe - - @property - def vae_scale_factor(self): - return 2 ** len(self.vae.temperal_downsample) - def encode_to_latent(self, sample): image = sample["target_image"].to(device=self.device, dtype=self.running_dtype) pixel_values = image.unsqueeze(2) - latent = self.vae.encode(pixel_values).latent_dist.sample() - latent = latent.permute(0, 2, 1, 3, 4) + latent = self.vae.encode(pixel_values).latent_dist.sample() # (B, C, T, H, W) - latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.running_dtype).view(1, 1, self.vae.config.z_dim, 1, 1) - latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.running_dtype).view(1, 1, self.vae.config.z_dim, 1, 1) + latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.running_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) + latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.running_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) return (latent - latent_mean) * latent_std def encode_condition(self, sample): @@ -78,23 +80,16 @@ def encode_prompt_condition(self, prompt): "prompt_embed_mask": prompt_embed_mask, } - def prepare_denoiser_input(self, noisy_latent, sample, condition): + def prepare_denoiser_input(self, noisy_latent): + # noisy_latent: (B, C, T, H, W) n = noisy_latent.shape[0] - packed = QwenImagePipeline._pack_latents( - noisy_latent, - n, - noisy_latent.shape[2], - noisy_latent.shape[3], - noisy_latent.shape[4], - ) - img_shapes = [(1, noisy_latent.shape[3] // 2, noisy_latent.shape[4] // 2)] * n - return DenoiserInput( + h, w = noisy_latent.shape[3], noisy_latent.shape[4] + packed = QwenImagePipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w) + return QwenImageDenoiserInput( hidden_states=packed, - extra={ - "img_shapes": img_shapes, - "height": noisy_latent.shape[3], - "width": noisy_latent.shape[4], - }, + img_shapes=[(1, h // 2, w // 2)] * n, + height=h, + width=w, ) def denoise(self, denoiser_input, timestep_or_sigma, condition): @@ -107,17 +102,51 @@ def denoise_with_transformer(self, transformer, denoiser_input, timestep_or_sigm guidance=None, encoder_hidden_states_mask=condition["prompt_embed_mask"], encoder_hidden_states=condition["prompt_embed"], - img_shapes=denoiser_input.extra["img_shapes"], + img_shapes=denoiser_input.img_shapes, return_dict=False, )[0] def postprocess_denoiser_output(self, prediction, denoiser_input): return QwenImagePipeline._unpack_latents( prediction, - height=denoiser_input.extra["height"] * self.vae_scale_factor, - width=denoiser_input.extra["width"] * self.vae_scale_factor, + height=denoiser_input.height * self.vae_scale_factor, + width=denoiser_input.width * self.vae_scale_factor, vae_scale_factor=self.vae_scale_factor, ) - def prepare_flow_matching_target(self, velocity): - return velocity.permute(0, 2, 1, 3, 4) + def prepare_infer_latents(self, height, width, generator=None): + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + shape = (1, self.vae.config.z_dim, 1, latent_h, latent_w) + return torch.randn(shape, generator=generator, device=self.device, dtype=self.running_dtype) + + def decode_latent(self, latent): + # Reverse the normalization from encode_to_latent: + # encode: normalized = (raw - mean) / latents_std + # decode: raw = normalized * latents_std + mean + latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.running_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) + latent_std = torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.running_dtype).view(1, self.vae.config.z_dim, 1, 1, 1) + latent = latent * latent_std + latent_mean # (B, C, T, H, W), C == z_dim + + image = self.vae.decode(latent).sample # (B, C, T, H, W) + image = image[:, :, 0, :, :] # drop temporal dim -> (B, C, H, W), T == 1 + + return self.image_processor.postprocess(image, output_type="pil") + + def assemble_pipeline(self, scheduler=None): + return QwenImagePipeline( + tokenizer=self.text_pipeline.tokenizer, + text_encoder=self.text_pipeline.text_encoder, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler or self.text_pipeline.scheduler, # use the original scheduler for bit-exact alignment with diffusers + ).to(self.device) + + def get_pipeline_infer_kwargs(self, infer_config): + # QwenImagePipeline uses `true_cfg_scale` instead of the standard `guidance_scale` + return { + "height": infer_config.get("height", 1024), + "width": infer_config.get("width", 1024), + "num_inference_steps": infer_config.get("num_inference_steps", 50), + "true_cfg_scale": infer_config.get("cfg_guidance_scale", 4.0), + } diff --git a/lightx2v_train/lightx2v_train/schedulers/flow_matching.py b/lightx2v_train/lightx2v_train/schedulers/flow_matching.py index 7ccce33b2..0c1fb67cb 100644 --- a/lightx2v_train/lightx2v_train/schedulers/flow_matching.py +++ b/lightx2v_train/lightx2v_train/schedulers/flow_matching.py @@ -1,5 +1,4 @@ import torch -from diffusers.schedulers.scheduling_utils import SchedulerOutput from lightx2v_train.utils.utils import get_running_dtype @@ -23,12 +22,8 @@ def __init__(self, config): self.time_shift_mu = scheduler_training_config.get("time_shift_mu", 5.0) self.time_shift_power = scheduler_training_config.get("time_shift_power", 1.0) - _sigmas = torch.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps) - self._train_sigmas = _sigmas - self._train_timesteps = _sigmas * self.num_train_timesteps - - self.sigmas = torch.cat([_sigmas, torch.zeros(1)]) - self.timesteps = self._train_timesteps + self.sigmas = None + self.timesteps = None self.num_inference_steps = None self.running_dtype = get_running_dtype(config["model"]["running_dtype"]) @@ -56,24 +51,33 @@ def add_noise(self, latent, noise, sigmas): def build_train_gt(self, latent, noise): return noise - latent - def set_timesteps(self, num_inference_steps, device=None): + def set_timesteps(self, num_inference_steps, sigmas=None): self.num_inference_steps = num_inference_steps - device = device or self.device - - sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) - self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device) - self.timesteps = (sigmas * self.num_train_timesteps).to(device) - def step(self, model_output, timestep, sample, return_dict=True): - step_index = (self.timesteps == timestep).nonzero()[0].item() + if sigmas is None: + sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) + if self.do_time_shift: + sigmas = self.time_shift(sigmas) + else: + sigmas = torch.tensor(sigmas, dtype=torch.float32) + self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(self.device) + self.timesteps = (sigmas * self.num_train_timesteps).to(self.device) + + def step(self, model_output, current_timestep, latent): + f""" + ADD NOISE: + x_t = (1 - sigma_t) * x_0 + sigma_t * N ------ self.add_noise(...) + => x_t = sigma_t * (N - x_0) + x_0 + => x_t = sigma_t * v + x_0 + REMOVE NOISE: + x_t = sigma_t * v + x_0 + x_t-1 = sigma_t-1 * v + x_0 + => x_t - x_t-1 = (sigma_t - sigma_t-1) * v + => x_t-1 = x_t + (sigma_t-1 - sigma_t) * v + => x_t-1 = x_t + (sigma_next - sigma) * model_output ------ (*) + """ + step_index = (self.timesteps == current_timestep).nonzero()[0].item() sigma = self.sigmas[step_index] sigma_next = self.sigmas[step_index + 1] - - prev_sample = sample + (sigma_next - sigma) * model_output - - if not return_dict: - return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample, timestep=None): - return sample + prev_sample = latent + (sigma_next - sigma) * model_output # ------ (*) from above + return prev_sample diff --git a/lightx2v_train/lightx2v_train/trainers/base.py b/lightx2v_train/lightx2v_train/trainers/base.py index 62e915a03..963ed31de 100644 --- a/lightx2v_train/lightx2v_train/trainers/base.py +++ b/lightx2v_train/lightx2v_train/trainers/base.py @@ -4,13 +4,17 @@ class BaseTrainer: def __init__(self, config): self.config = config + self.model_config = self.config["model"] + self.training_config = self.config["training"] + self.infer_config = self.config["inference"] + self.noise_scheduler = RectifiedFlowMatchingScheduler(config) def set_model(self, model): self.model = model - def set_data(self, dataloader, dataloader_eval=None): - self.dataloader = dataloader + def set_data(self, dataloader_train, dataloader_eval=None): + self.dataloader_train = dataloader_train self.dataloader_eval = dataloader_eval def train(self): diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index ab07a2d1d..1cc50b9cd 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -106,8 +106,8 @@ def _latent_shape(self, sample): latent_channels = self.model.transformer.config.in_channels // 4 return ( batch_size, - 1, int(latent_channels), + 1, image.shape[-2] // self.model.vae_scale_factor, image.shape[-1] // self.model.vae_scale_factor, ) @@ -124,10 +124,10 @@ def _encode_conditions(self, sample): return condition, negative_condition def _predict_velocity(self, model, latents, sigma, condition): - denoiser_input = model.prepare_denoiser_input(latents, {}, condition) + denoiser_input = model.prepare_denoiser_input(latents) prediction = model.denoise(denoiser_input, sigma, condition) prediction = model.postprocess_denoiser_output(prediction, denoiser_input) - return model.prepare_flow_matching_target(prediction) + return prediction def sample_initial_latents(self, latent_shape): return torch.randn(latent_shape, device=self.model.device, dtype=self.running_dtype) @@ -181,7 +181,7 @@ def forward_loss(self, sample, stage): def train(self): self.setup() - os.makedirs(self.output_dir, exist_ok=True) + os.makedirs(self.output_train_dir, exist_ok=True) max_train_iters = self.max_train_iters fake_update_ratio = self.fake_update_ratio @@ -195,7 +195,7 @@ def train(self): progress = tqdm(total=max_train_iters, desc="DMD-LoRA iterations") while current_iter < max_train_iters: - for sample in self.dataloader: + for sample in self.dataloader_train: loss_dmd = self.forward_loss(sample, stage="generator") loss_dmd.backward() torch.nn.utils.clip_grad_norm_(self.model.transformer.parameters(), max_grad_norm) diff --git a/lightx2v_train/lightx2v_train/trainers/lora.py b/lightx2v_train/lightx2v_train/trainers/lora.py index 34fb50337..776e1c470 100644 --- a/lightx2v_train/lightx2v_train/trainers/lora.py +++ b/lightx2v_train/lightx2v_train/trainers/lora.py @@ -4,6 +4,7 @@ from diffusers.optimization import get_scheduler from tqdm.auto import tqdm +from lightx2v_train.infer import build_inferencer from lightx2v_train.runtime.checkpoint import prune_checkpoints from lightx2v_train.utils.registry import TRAINER_REGISTER from lightx2v_train.utils.utils import get_running_dtype @@ -14,39 +15,43 @@ @TRAINER_REGISTER("lora") class LoraTrainer(BaseTrainer): def get_configs(self): - model_config = self.config["model"] - self.running_dtype = get_running_dtype(model_config["running_dtype"]) + self.running_dtype = get_running_dtype(self.model_config["running_dtype"]) - training_config = self.config["training"] - - lora_config = training_config.get("lora", {}) + lora_config = self.training_config.get("lora", {}) self.lora_rank = lora_config.get("rank", 16) self.lora_alpha = lora_config.get("alpha", self.lora_rank) self.lora_target_modules = lora_config.get("target_modules") - self.gradient_checkpointing = training_config.get("gradient_checkpointing", True) + self.gradient_checkpointing = self.training_config.get("gradient_checkpointing", True) - optimizer_config = training_config.get("optimizer", {}) + optimizer_config = self.training_config.get("optimizer", {}) self.optimizer_learning_rate = optimizer_config.get("learning_rate", 1e-4) self.optimizer_adam_beta1 = optimizer_config.get("adam_beta1", 0.9) self.optimizer_adam_beta2 = optimizer_config.get("adam_beta2", 0.999) self.optimizer_weight_decay = optimizer_config.get("weight_decay", 0.01) self.optimizer_adam_epsilon = optimizer_config.get("adam_epsilon", 1e-8) - self.lr_scheduler_name = training_config.get("lr_scheduler", "constant") - self.lr_warmup_iters = training_config["lr_warmup_iters"] - self.max_train_iters = training_config["max_train_iters"] + self.lr_scheduler_name = self.training_config.get("lr_scheduler", "constant") + self.lr_warmup_iters = self.training_config["lr_warmup_iters"] + self.max_train_iters = self.training_config["max_train_iters"] + + self.output_train_dir = self.training_config["output_dir"] + self.gradient_accumulation_iters = self.training_config["gradient_accumulation_iters"] + self.max_grad_norm = self.training_config.get("max_grad_norm", 1.0) + self.save_every_iters = self.training_config["save_every_iters"] + self.save_total_limit = self.training_config["save_total_limit"] - self.output_dir = training_config["output_dir"] - self.gradient_accumulation_iters = training_config["gradient_accumulation_iters"] - self.max_grad_norm = training_config.get("max_grad_norm", 1.0) - self.save_every_iters = training_config["save_every_iters"] - self.save_total_limit = training_config["save_total_limit"] + self.infer_every_iters = self.infer_config.get("infer_every_iters", None) def setup(self): self.get_configs() self.setup_lora() + if self.infer_every_iters: + self.inferencer = build_inferencer(self.config) + self.inferencer.set_model(self.model) + # set_data is deferred to train() when dataloader_eval is available + self.optimizer = self.build_optimizer(self.model.trainable_parameters()) self.lr_scheduler = self.build_lr_scheduler(self.optimizer) @@ -93,17 +98,17 @@ def compute_loss_on_sample(self, sample): noisy_latent = self.noise_scheduler.add_noise(latent, noise, timestep_or_sigma) condition = self.model.encode_condition(sample) - denoiser_input = self.model.prepare_denoiser_input(noisy_latent, sample, condition) + denoiser_input = self.model.prepare_denoiser_input(noisy_latent) prediction = self.model.denoise(denoiser_input, timestep_or_sigma, condition) prediction = self.model.postprocess_denoiser_output(prediction, denoiser_input) - target = self.model.prepare_flow_matching_target(self.noise_scheduler.build_train_gt(latent, noise)) + target = self.noise_scheduler.build_train_gt(latent, noise) loss = torch.mean(((prediction.float() - target.float()) ** 2).reshape(target.shape[0], -1), dim=1) return loss.mean() def train(self): self.setup() - os.makedirs(self.output_dir, exist_ok=True) + os.makedirs(self.output_train_dir, exist_ok=True) max_train_iters = self.max_train_iters grad_accum_iters = self.gradient_accumulation_iters @@ -115,8 +120,12 @@ def train(self): running_loss = 0.0 progress = tqdm(total=max_train_iters, desc="Training iterations") + if self.infer_every_iters: + self.inferencer.set_data(self.dataloader_eval) + self.run_inference(current_iter) + while current_iter < max_train_iters: - for sample in self.dataloader: + for sample in self.dataloader_train: loss = self.compute_loss_on_sample(sample) (loss / grad_accum_iters).backward() running_loss += loss.item() / grad_accum_iters @@ -138,15 +147,27 @@ def train(self): if save_every_iters and current_iter % save_every_iters == 0: self.save_checkpoint(current_iter, save_total_limit) + if self.infer_every_iters and current_iter % self.infer_every_iters == 0: + self.run_inference(current_iter) + if current_iter >= max_train_iters: break progress.close() + def run_inference(self, current_iter): + base_output_dir = self.infer_config.get("output_dir", "./output_infer") + iter_output_dir = os.path.join(base_output_dir, f"iter-{current_iter:09d}") + + self.inferencer.output_infer_dir = iter_output_dir + os.makedirs(iter_output_dir, exist_ok=True) + self.inferencer.infer() + + self.model.set_lora_trainable() + def save_checkpoint(self, iteration, save_total_limit): - output_dir = self.output_dir - prune_checkpoints(output_dir, save_total_limit) + prune_checkpoints(self.output_train_dir, save_total_limit) - save_dir = os.path.join(output_dir, f"checkpoint-{iteration}") + save_dir = os.path.join(self.output_train_dir, f"checkpoint-{iteration:09d}") os.makedirs(save_dir, exist_ok=True) self.model.save_lora_weights(save_dir) diff --git a/lightx2v_train/lightx2v_train/utils/registry.py b/lightx2v_train/lightx2v_train/utils/registry.py index a03bbf0ec..ef129e83e 100644 --- a/lightx2v_train/lightx2v_train/utils/registry.py +++ b/lightx2v_train/lightx2v_train/utils/registry.py @@ -55,6 +55,7 @@ def merge(self, other_register): MODEL_REGISTER = Register() TRAINER_REGISTER = Register() +INFERENCER_REGISTER = Register() DATA_REGISTER = Register() @@ -74,6 +75,14 @@ def build_trainer(config): return TRAINER_REGISTER[name](config) +def build_inferencer(config): + name = config["inference"]["method"] + if name not in INFERENCER_REGISTER: + available = ", ".join(sorted(INFERENCER_REGISTER.keys())) + raise ValueError(f"Unknown inferencer {name!r}. Available inferencers: {available}") + return INFERENCER_REGISTER[name](config) + + def build_data(config, train_or_val): data_config = config.get("data", {}) if train_or_val not in data_config: From d705f4d34f8834e59be8085fd6ef94c724e65de8 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Fri, 15 May 2026 09:21:40 +0000 Subject: [PATCH 08/17] revert --- .../lightx2v_train/model_zoo/base.py | 2 +- .../lightx2v_train/model_zoo/longcat_image.py | 2 +- .../lightx2v_train/trainers/lora.py | 21 ++++--------------- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/lightx2v_train/lightx2v_train/model_zoo/base.py b/lightx2v_train/lightx2v_train/model_zoo/base.py index 4213616a3..2e1c5c0be 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/base.py +++ b/lightx2v_train/lightx2v_train/model_zoo/base.py @@ -61,7 +61,7 @@ def encode_to_latent(self, sample): def encode_condition(self, sample): raise NotImplementedError - def prepare_denoiser_input(self, noisy_latent): + def prepare_denoiser_input(self, noisy_latent, sample, condition): raise NotImplementedError def denoise(self, denoiser_input, timesteps, condition): diff --git a/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py b/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py index 5fdb76bc2..3158d3aca 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py @@ -58,7 +58,7 @@ def encode_condition(self, sample): ) return {"prompt_embed": prompt_embed, "text_ids": text_ids} - def prepare_denoiser_input(self, noisy_latent): + def prepare_denoiser_input(self, noisy_latent, sample, condition): n = noisy_latent.shape[0] h, w = noisy_latent.shape[2], noisy_latent.shape[3] packed = LongCatImagePipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w) diff --git a/lightx2v_train/lightx2v_train/trainers/lora.py b/lightx2v_train/lightx2v_train/trainers/lora.py index 0c4817e60..2217eeda7 100644 --- a/lightx2v_train/lightx2v_train/trainers/lora.py +++ b/lightx2v_train/lightx2v_train/trainers/lora.py @@ -45,17 +45,6 @@ def get_configs(self): def setup(self): self.get_configs() - self.setup_lora() - - if self.infer_every_iters: - self.inferencer = build_inferencer(self.config) - self.inferencer.set_model(self.model) - # set_data is deferred to train() when dataloader_eval is available - - self.optimizer = self.build_optimizer(self.model.trainable_parameters()) - self.lr_scheduler = self.build_lr_scheduler(self.optimizer) - - def setup_lora(self): self.model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) self.model.set_lora_trainable() if self.gradient_checkpointing: @@ -73,13 +62,11 @@ def setup_lora(self): weight_decay=self.optimizer_weight_decay, eps=self.optimizer_adam_epsilon, ) - - def build_lr_scheduler(self, optimizer, num_warmup_steps=None, num_training_steps=None): - return get_scheduler( + self.lr_scheduler = get_scheduler( self.lr_scheduler_name, - optimizer=optimizer, - num_warmup_steps=self.lr_warmup_iters if num_warmup_steps is None else num_warmup_steps, - num_training_steps=self.max_train_iters if num_training_steps is None else num_training_steps, + optimizer=self.optimizer, + num_warmup_steps=self.lr_warmup_iters, + num_training_steps=self.max_train_iters, ) def compute_loss_on_sample(self, sample): From a07ad174363a3a1bc2b0a5f2cecff7edcb043488 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Fri, 15 May 2026 09:42:54 +0000 Subject: [PATCH 09/17] update --- .../lightx2v_train/trainers/dmd_lora.py | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 1cc50b9cd..4adca0563 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F +from diffusers.optimization import get_scheduler from tqdm.auto import tqdm from lightx2v_train.model_zoo import build_model @@ -15,9 +16,7 @@ class DmdLoraTrainer(LoraTrainer): def get_configs(self): super().get_configs() - - training_config = self.config["training"] - fake_config = training_config.get("fake", {}) + fake_config = self.training_config.get("fake", {}) fake_optimizer_config = fake_config.get("optimizer", {}) self.fake_optimizer_learning_rate = fake_optimizer_config.get("learning_rate", self.optimizer_learning_rate) self.fake_optimizer_adam_beta1 = fake_optimizer_config.get("adam_beta1", self.optimizer_adam_beta1) @@ -25,7 +24,7 @@ def get_configs(self): self.fake_optimizer_weight_decay = fake_optimizer_config.get("weight_decay", self.optimizer_weight_decay) self.fake_optimizer_adam_epsilon = fake_optimizer_config.get("adam_epsilon", self.optimizer_adam_epsilon) - self.dmd_config = training_config.get("dmd", {}) + self.dmd_config = self.training_config.get("dmd", {}) self.num_inference_steps = int(self.dmd_config.get("num_inference_steps", 4)) self.fake_update_ratio = int(self.dmd_config.get("fake_update_ratio", 1)) self.guidance_scale = float(self.dmd_config.get("guidance_scale", 3.0)) @@ -33,11 +32,7 @@ def get_configs(self): self.cfg_norm = self.dmd_config.get("cfg_norm", "layer_norm") def setup(self): - self.get_configs() - print("[dmd_lora] single-GPU resident mode: student/fake/teacher transformers stay on CUDA") - - self.setup_lora() - + super().setup() self.fake_model = build_model(self.config) self.fake_model.load_components(transformer_only=True, reference_model=self.model) self.fake_model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) @@ -50,21 +45,20 @@ def setup(self): self.teacher_model.transformer.requires_grad_(False) self.teacher_model.transformer.eval() - self.optimizer = self.build_optimizer(self.model.trainable_parameters()) - self.fake_optimizer = self.build_optimizer( + self.fake_optimizer = torch.optim.AdamW( self.fake_model.trainable_parameters(), - learning_rate=self.fake_optimizer_learning_rate, - adam_beta1=self.fake_optimizer_adam_beta1, - adam_beta2=self.fake_optimizer_adam_beta2, + lr=self.fake_optimizer_learning_rate, + betas=(self.fake_optimizer_adam_beta1, self.fake_optimizer_adam_beta2), weight_decay=self.fake_optimizer_weight_decay, - adam_epsilon=self.fake_optimizer_adam_epsilon, + eps=self.fake_optimizer_adam_epsilon, ) - self.lr_scheduler = self.build_lr_scheduler(self.optimizer) - self.fake_lr_scheduler = self.build_lr_scheduler( - self.fake_optimizer, + self.fake_lr_scheduler = get_scheduler( + self.lr_scheduler_name, + optimizer=self.fake_optimizer, num_warmup_steps=0, num_training_steps=max(1, self.max_train_iters * self.fake_update_ratio), ) + self.scheduler = DMDFlowMatchingScheduler(self.config, self.dmd_config) print(f"[dmd_lora] student trainable params={self._count_trainable(self.model.transformer)}") From b9f0b3bdd7caec8141576b5c2843ebbf0f72f70d Mon Sep 17 00:00:00 2001 From: Musisoul Date: Fri, 15 May 2026 10:10:07 +0000 Subject: [PATCH 10/17] u --- lightx2v_train/lightx2v_train/model_zoo/qwen_image.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py index 08b11a47e..652b2d0ca 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py @@ -27,7 +27,6 @@ class QwenImageModel(BaseModel): pipeline_cls = QwenImagePipeline def load_components(self, transformer_only=False, reference_model=None): - model_path = self.config["model"]["pretrained_model_name_or_path"] if transformer_only: if reference_model is not None: self.text_pipeline = reference_model.text_pipeline @@ -36,6 +35,7 @@ def load_components(self, transformer_only=False, reference_model=None): self.image_processor = reference_model.image_processor self.transformer = self.load_transformer() return + model_path = self.config["model"]["pretrained_model_name_or_path"] self.text_pipeline = QwenImagePipeline.from_pretrained( model_path, @@ -93,10 +93,7 @@ def prepare_denoiser_input(self, noisy_latent): ) def denoise(self, denoiser_input, timestep_or_sigma, condition): - return self.denoise_with_transformer(self.transformer, denoiser_input, timestep_or_sigma, condition) - - def denoise_with_transformer(self, transformer, denoiser_input, timestep_or_sigma, condition): - return transformer( + return self.transformer( hidden_states=denoiser_input.hidden_states, timestep=timestep_or_sigma, # timestep_or_sigma is in [0, 1] not [0, 1000] guidance=None, From b972e96cbf23d1ca7b04cb6b04ee3a52589417b3 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Fri, 15 May 2026 10:22:53 +0000 Subject: [PATCH 11/17] u --- lightx2v_train/lightx2v_train/trainers/dmd_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 4adca0563..ca0ef2b0f 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -14,8 +14,8 @@ @TRAINER_REGISTER("dmd_lora") class DmdLoraTrainer(LoraTrainer): - def get_configs(self): - super().get_configs() + def __init__(self, config): + super().__init__(config) fake_config = self.training_config.get("fake", {}) fake_optimizer_config = fake_config.get("optimizer", {}) self.fake_optimizer_learning_rate = fake_optimizer_config.get("learning_rate", self.optimizer_learning_rate) From 77b8fe755caca271ef178acbf67227a03ac1519f Mon Sep 17 00:00:00 2001 From: Musisoul Date: Fri, 15 May 2026 10:24:29 +0000 Subject: [PATCH 12/17] u --- lightx2v_train/lightx2v_train/model_zoo/qwen_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py index 652b2d0ca..f8f4ee947 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py @@ -36,7 +36,6 @@ def load_components(self, transformer_only=False, reference_model=None): self.transformer = self.load_transformer() return model_path = self.config["model"]["pretrained_model_name_or_path"] - self.text_pipeline = QwenImagePipeline.from_pretrained( model_path, transformer=None, From 2cabca3fbfcfb63f64a74594c61ac7ac6a8ccb05 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Fri, 15 May 2026 10:48:15 +0000 Subject: [PATCH 13/17] resume --- .../lightx2v_train/trainers/dmd_lora.py | 63 +++++++++++++++++-- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index ca0ef2b0f..1f71502fb 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -1,4 +1,5 @@ import os +import shutil import torch import torch.nn.functional as F @@ -6,6 +7,7 @@ from tqdm.auto import tqdm from lightx2v_train.model_zoo import build_model +from lightx2v_train.runtime.checkpoint import prune_checkpoints from lightx2v_train.schedulers import DMDFlowMatchingScheduler from lightx2v_train.utils.registry import TRAINER_REGISTER @@ -31,8 +33,8 @@ def __init__(self, config): self.negative_prompt = self.dmd_config.get("negative_prompt", " ") self.cfg_norm = self.dmd_config.get("cfg_norm", "layer_norm") - def setup(self): - super().setup() + def setup(self, resume_ckpt_path=None): + super().setup(resume_ckpt_path=resume_ckpt_path) self.fake_model = build_model(self.config) self.fake_model.load_components(transformer_only=True, reference_model=self.model) self.fake_model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules) @@ -61,6 +63,9 @@ def setup(self): self.scheduler = DMDFlowMatchingScheduler(self.config, self.dmd_config) + if resume_ckpt_path is not None: + self.load_resume_ckpt(resume_ckpt_path) + print(f"[dmd_lora] student trainable params={self._count_trainable(self.model.transformer)}") print(f"[dmd_lora] fake trainable params={self._count_trainable(self.fake_model.transformer)}") @@ -174,7 +179,8 @@ def forward_loss(self, sample, stage): return self._dmd_loss(x0, x_pred_fake, x_pred_teacher) def train(self): - self.setup() + resume_ckpt_path, current_iter = self._resolve_resume() + self.setup(resume_ckpt_path=resume_ckpt_path) os.makedirs(self.output_train_dir, exist_ok=True) max_train_iters = self.max_train_iters @@ -182,11 +188,10 @@ def train(self): max_grad_norm = self.max_grad_norm save_every_iters = self.save_every_iters save_total_limit = self.save_total_limit - current_iter = 0 running_dmd = 0.0 running_fake = 0.0 - progress = tqdm(total=max_train_iters, desc="DMD-LoRA iterations") + progress = tqdm(total=max_train_iters, desc="DMD-LoRA iterations", initial=current_iter) while current_iter < max_train_iters: for sample in self.dataloader_train: @@ -226,3 +231,51 @@ def train(self): break progress.close() + + def load_resume_ckpt(self, resume_ckpt_path): + training_state_path = os.path.join(resume_ckpt_path, "training_state.pt") + fake_lora_path = os.path.join(resume_ckpt_path, "fake_lora") + fake_lora_weights_path = os.path.join(fake_lora_path, "pytorch_lora_weights.safetensors") + + if os.path.exists(fake_lora_weights_path): + self.fake_model.load_lora_weights_for_resume(fake_lora_path) + else: + print(f"Warning: fake LoRA weights not found in {fake_lora_path}. Fake model not restored.") + + if not os.path.exists(training_state_path): + return + + state = torch.load(training_state_path, map_location="cpu", weights_only=False) + if "fake_optimizer" in state: + self.fake_optimizer.load_state_dict(state["fake_optimizer"]) + else: + print(f"Warning: fake optimizer state not found in {training_state_path}.") + + if "fake_lr_scheduler" in state: + self.fake_lr_scheduler.load_state_dict(state["fake_lr_scheduler"]) + else: + print(f"Warning: fake lr scheduler state not found in {training_state_path}.") + + def save_checkpoint(self, iteration, save_total_limit): + prune_checkpoints(self.output_train_dir, save_total_limit) + + save_dir = os.path.join(self.output_train_dir, f"checkpoint-{iteration:09d}") + os.makedirs(save_dir, exist_ok=True) + self.model.save_lora_weights(save_dir) + + fake_save_dir = os.path.join(save_dir, "fake_lora") + os.makedirs(fake_save_dir, exist_ok=True) + self.fake_model.save_lora_weights(fake_save_dir) + + config_path = self.config.get("config_path") + if config_path is not None: + shutil.copy2(config_path, os.path.join(save_dir, "config.yaml")) + + training_state = { + "iteration": iteration, + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "fake_optimizer": self.fake_optimizer.state_dict(), + "fake_lr_scheduler": self.fake_lr_scheduler.state_dict(), + } + torch.save(training_state, os.path.join(save_dir, "training_state.pt")) From d96ae23e5f79be5aa177641d355865ffe7279055 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Mon, 18 May 2026 10:31:18 +0000 Subject: [PATCH 14/17] refactor scheduler --- .../schedulers/dmd_scheduler.py | 21 +++---------- .../lightx2v_train/trainers/dmd_lora.py | 31 ++++++++++++------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py b/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py index fdc0feb53..756542fe1 100644 --- a/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py +++ b/lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py @@ -12,13 +12,6 @@ def __init__(self, config, dmd_config={}): self.max_sigma = float(dmd_config.get("sigma_max", 1.0)) self.discrete_samples = int(dmd_config.get("discrete_samples", 1000)) - @staticmethod - def expand_to(value, target): - value = value.to(device=target.device) - while value.ndim < target.ndim: - value = value.view(*value.shape, 1) - return value - @staticmethod def linear_shift(mu, t): return mu / (mu + (1 / t - 1)) @@ -54,17 +47,11 @@ def sample_renoise_sigma(self, batch_size, device=None, dtype=None): return sigma def add_noise(self, latent, noise, sigmas): - sigmas = self.expand_to(sigmas, latent).to(dtype=torch.float32) - return ((1.0 - sigmas) * latent.float() + sigmas * noise.float()).to(dtype=latent.dtype) - - def euler_step(self, sample, velocity, sigma, target_sigma): - sigma = self.expand_to(sigma, sample).to(dtype=torch.float32) - target_sigma = self.expand_to(target_sigma, sample).to(dtype=torch.float32) - return sample.float() + (target_sigma - sigma) * velocity.float() + return ((1.0 - sigmas) * latent + sigmas * noise).to(dtype=latent.dtype) - def step_by_index(self, model_output, step_idx, sample): + def step_by_index(self, velocity, step_idx, sample): sigma = self.sigma_at(step_idx, sample.shape[0], device=sample.device) sigma_next = self.sigma_at(int(step_idx) + 1, sample.shape[0], device=sample.device) - x0 = sample.float() - self.expand_to(sigma, sample).float() * model_output.float() - next_sample = self.euler_step(sample, model_output, sigma, sigma_next) + next_sample = sample + (sigma_next - sigma) * velocity + x0 = sample - sigma * velocity return next_sample.to(sample.dtype), x0.to(sample.dtype) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 1f71502fb..911800d51 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -32,6 +32,7 @@ def __init__(self, config): self.guidance_scale = float(self.dmd_config.get("guidance_scale", 3.0)) self.negative_prompt = self.dmd_config.get("negative_prompt", " ") self.cfg_norm = self.dmd_config.get("cfg_norm", "layer_norm") + self.image_sizes = self.dmd_config.get("image_sizes", []) def setup(self, resume_ckpt_path=None): super().setup(resume_ckpt_path=resume_ckpt_path) @@ -100,6 +101,13 @@ def _dmd_loss(latents, x_pred_fake_flow, x_pred_teacher): def _latent_shape(self, sample): image = sample["target_image"] batch_size = image.shape[0] + if self.image_sizes: + height, width = self.image_sizes[ + torch.randint(0, len(self.image_sizes), (1,), device=self.model.device).item() + ] + else: + height, width = image.shape[-2], image.shape[-1] + latent_channels = getattr(self.model.vae.config, "z_dim", None) if latent_channels is None: latent_channels = self.model.transformer.config.in_channels // 4 @@ -107,8 +115,8 @@ def _latent_shape(self, sample): batch_size, int(latent_channels), 1, - image.shape[-2] // self.model.vae_scale_factor, - image.shape[-1] // self.model.vae_scale_factor, + height // self.model.vae_scale_factor, + width // self.model.vae_scale_factor, ) def _encode_conditions(self, sample): @@ -148,9 +156,8 @@ def run_back_simulation(self, condition, latent_shape, end_step_idx, grad_enable xt, x0 = self.scheduler.step_by_index(velocity, idx, xt) return x0 - def forward_loss(self, sample, stage): - condition, negative_condition = self._encode_conditions(sample) - latent_shape = self._latent_shape(sample) + def forward_loss(self, latent_shape, conditions, stage): + condition, negative_condition = conditions end_step_idx = self.sample_end_step() xt_start = self.sample_initial_latents(latent_shape) x0_ref = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=False, xt=xt_start) @@ -158,11 +165,11 @@ def forward_loss(self, sample, stage): sigma = self.scheduler.sample_renoise_sigma(latent_shape[0], device=self.model.device, dtype=self.running_dtype) noise = torch.randn(latent_shape, device=self.model.device, dtype=torch.float32) renoised_xt = self.scheduler.add_noise(x0_ref, noise, sigma) - velocity_gt = self.scheduler.build_train_gt(x0_ref.float(), noise) if stage == "fake": self.fake_model.transformer.train() velocity_fake = self._predict_velocity(self.fake_model, renoised_xt, sigma, condition) + velocity_gt = self.scheduler.build_train_gt(x0_ref.float(), noise) return F.mse_loss(velocity_fake.float(), velocity_gt.float(), reduction="mean") with torch.no_grad(): @@ -172,9 +179,8 @@ def forward_loss(self, sample, stage): velocity_teacher_uncond = self._predict_velocity(self.teacher_model, renoised_xt, sigma, negative_condition) velocity_teacher = self._do_cfg(velocity_teacher_cond, velocity_teacher_uncond, self.guidance_scale, self.cfg_norm) - zeros = torch.zeros_like(sigma) - x_pred_fake = self.scheduler.euler_step(renoised_xt, velocity_fake, sigma, zeros) - x_pred_teacher = self.scheduler.euler_step(renoised_xt, velocity_teacher, sigma, zeros) + x_pred_fake = renoised_xt - sigma * velocity_fake + x_pred_teacher = renoised_xt - sigma * velocity_teacher x0 = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=True, xt=xt_start) return self._dmd_loss(x0, x_pred_fake, x_pred_teacher) @@ -195,7 +201,10 @@ def train(self): while current_iter < max_train_iters: for sample in self.dataloader_train: - loss_dmd = self.forward_loss(sample, stage="generator") + conditions = self._encode_conditions(sample) + latent_shape = self._latent_shape(sample) + + loss_dmd = self.forward_loss(latent_shape, conditions, stage="student") loss_dmd.backward() torch.nn.utils.clip_grad_norm_(self.model.transformer.parameters(), max_grad_norm) self.optimizer.step() @@ -205,7 +214,7 @@ def train(self): fake_loss = 0.0 for _ in range(fake_update_ratio): - loss_fake = self.forward_loss(sample, stage="fake") + loss_fake = self.forward_loss(latent_shape, conditions, stage="fake") loss_fake.backward() torch.nn.utils.clip_grad_norm_(self.fake_model.transformer.parameters(), max_grad_norm) self.fake_optimizer.step() From cc776389b0a69641dd9cdacc838d4bd1a68031ec Mon Sep 17 00:00:00 2001 From: Musisoul Date: Mon, 18 May 2026 10:31:50 +0000 Subject: [PATCH 15/17] lint --- lightx2v_train/lightx2v_train/trainers/dmd_lora.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 911800d51..3ba421c9a 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -102,9 +102,7 @@ def _latent_shape(self, sample): image = sample["target_image"] batch_size = image.shape[0] if self.image_sizes: - height, width = self.image_sizes[ - torch.randint(0, len(self.image_sizes), (1,), device=self.model.device).item() - ] + height, width = self.image_sizes[torch.randint(0, len(self.image_sizes), (1,), device=self.model.device).item()] else: height, width = image.shape[-2], image.shape[-1] From d7e36d98f54c27cc4b987f7a023e9af0a7d16087 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Mon, 18 May 2026 10:33:13 +0000 Subject: [PATCH 16/17] update yaml --- .../configs/dmd_lora/qwen_image_dmd_lora.yaml | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml diff --git a/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml b/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml new file mode 100644 index 000000000..831372175 --- /dev/null +++ b/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml @@ -0,0 +1,106 @@ +model: + name: qwen_image + pretrained_model_name_or_path: /path/to/Qwen/Qwen-Image + max_sequence_length: 1024 + running_dtype: bf16 + +data: + train: + name: image_dataset + num_workers: 8 + prompt_dropout_rate: 0.0 + target_area: 1048576 # 1024 * 1024 + shuffle: true + data_path: + - /path/to/LightX2V_train_data_examples/dataset_v1/train.jsonl + val: + name: image_dataset + num_workers: 8 + shuffle: false + data_path: + - /path/to/LightX2V_train_data_examples/dataset_v1/val.jsonl + +scheduler: + num_train_timesteps: 1000 + timestep_distribution: uniform + min_t: 0.001 + max_t: 1.0 + time_shift_settings: + do_time_shift: true + shift_type: exponential + # shift function: "linear" => mu/(mu+(1/t-1)^p), "exponential" => exp(mu)/(exp(mu)+(1/t-1)^p) + time_shift_power: 1.0 + dynamic_shift: true + patch_size: [2, 2] # [H, W] + # https://github.com/huggingface/diffusers/blob/v0.38.0/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py#L59 + shift_x1: 256 + shift_x2: 4096 + shift_y1: 0.5 + shift_y2: 1.15 + +training: + method: dmd_lora + max_train_iters: 10 + gradient_accumulation_iters: 1 + gradient_checkpointing: true + max_grad_norm: 1.0 + lr_scheduler: constant + lr_warmup_iters: 10 + save_every_iters: 5 + save_total_limit: 10 + lora: + rank: 32 + alpha: 32 + target_modules: + - to_k + - to_q + - to_v + - to_out.0 + # - add_q_proj + # - add_k_proj + # - add_v_proj + # - to_add_out + # - img_mlp.net.0.proj + # - img_mlp.net.2 + # - txt_mlp.net.0.proj + # - txt_mlp.net.2 + optimizer: + learning_rate: 0.0001 + adam_beta1: 0.9 + adam_beta2: 0.999 + weight_decay: 0.001 + adam_epsilon: 0.00000001 + fake: + optimizer: + learning_rate: 0.00002 + adam_beta1: 0.9 + adam_beta2: 0.999 + weight_decay: 0.001 + adam_epsilon: 0.00000001 + dmd: + num_inference_steps: 4 + fake_update_ratio: 2 + guidance_scale: 4.0 + negative_prompt: " " + cfg_norm: layer_norm + image_sizes: + - [1024, 1024] + - [768, 1344] + - [1344, 768] + sigma_min: 0.02 + sigma_max: 1.0 + discrete_samples: 1000 + renoise_shift: 5.0 + inference_shift: 3.0 + output_dir: ./output_train/qwen_image_dmd_lora + +inference: + method: image_infer + default_width: 1024 + default_height: 1024 + num_inference_steps: 4 + cfg_guidance_scale: 4.0 + negative_prompt: " " + +resume: + auto_resume: true From 7fd2a36d935aea26616dd56edc986420b72cf586 Mon Sep 17 00:00:00 2001 From: Musisoul Date: Wed, 27 May 2026 09:18:23 +0000 Subject: [PATCH 17/17] inference in training --- lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml | 7 +++++-- lightx2v_train/lightx2v_train/trainers/dmd_lora.py | 7 +++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml b/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml index 831372175..1f35e2015 100644 --- a/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml +++ b/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml @@ -40,13 +40,13 @@ scheduler: training: method: dmd_lora - max_train_iters: 10 + max_train_iters: 1000 gradient_accumulation_iters: 1 gradient_checkpointing: true max_grad_norm: 1.0 lr_scheduler: constant lr_warmup_iters: 10 - save_every_iters: 5 + save_every_iters: 100 save_total_limit: 10 lora: rank: 32 @@ -101,6 +101,9 @@ inference: num_inference_steps: 4 cfg_guidance_scale: 4.0 negative_prompt: " " + enable_cfg: false + output_dir: ./output_infer/qwen_image_dmd_lora + infer_every_iters: ${training.save_every_iters} resume: auto_resume: true diff --git a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py index 3ba421c9a..df0e8c793 100644 --- a/lightx2v_train/lightx2v_train/trainers/dmd_lora.py +++ b/lightx2v_train/lightx2v_train/trainers/dmd_lora.py @@ -196,6 +196,10 @@ def train(self): running_fake = 0.0 progress = tqdm(total=max_train_iters, desc="DMD-LoRA iterations", initial=current_iter) + if self.infer_every_iters: + self.inferencer.set_data(self.dataloader_eval) + if current_iter == 0: + self.run_inference(current_iter) while current_iter < max_train_iters: for sample in self.dataloader_train: @@ -234,6 +238,9 @@ def train(self): if save_every_iters and current_iter % save_every_iters == 0: self.save_checkpoint(current_iter, save_total_limit) + if self.infer_every_iters and current_iter % self.infer_every_iters == 0: + self.run_inference(current_iter) + if current_iter >= max_train_iters: break