diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8a81beca9748..6c0cebc8de09 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -399,6 +399,8 @@ else: _import_structure["modular_pipelines"].extend( [ + "Flux2AutoBlocks", + "Flux2ModularPipeline", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -1091,6 +1093,8 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .modular_pipelines import ( + Flux2AutoBlocks, + Flux2ModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 252b9f33dfe8..b792a7923cb9 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -52,6 +52,10 @@ "FluxKontextAutoBlocks", "FluxKontextModularPipeline", ] + _import_structure["flux2"] = [ + "Flux2AutoBlocks", + "Flux2ModularPipeline", + ] _import_structure["qwenimage"] = [ "QwenImageAutoBlocks", "QwenImageModularPipeline", @@ -71,6 +75,7 @@ else: from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline + from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py new file mode 100644 index 000000000000..0af7819dd0cf --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -0,0 +1,123 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = [ + "Flux2TextEncoderStep", + "Flux2RemoteTextEncoderStep", + "Flux2ProcessImagesInputStep", + "Flux2VaeEncoderStep", + ] + _import_structure["before_denoise"] = [ + "Flux2SetTimestepsStep", + "Flux2PrepareLatentsStep", + "Flux2RoPEInputsStep", + "Flux2PrepareImageLatentsStep", + ] + _import_structure["denoise"] = [ + "Flux2LoopDenoiser", + "Flux2LoopAfterDenoiser", + "Flux2DenoiseLoopWrapper", + "Flux2DenoiseStep", + ] + _import_structure["decoders"] = ["Flux2DecodeStep"] + _import_structure["inputs"] = [ + "Flux2TextInputStep", + "Flux2ImageInputStep", + ] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "TEXT2IMAGE_BLOCKS", + "IMAGE_CONDITIONED_BLOCKS", + "Flux2AutoBeforeDenoiseStep", + "Flux2AutoBlocks", + "Flux2AutoDecodeStep", + "Flux2AutoDenoiseStep", + "Flux2AutoImageInputStep", + "Flux2AutoTextEncoderStep", + "Flux2AutoTextInputStep", + "Flux2AutoVaeEncoderStep", + "Flux2BeforeDenoiseStep", + "Flux2VaeEncoderSequentialStep", + ] + _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, + ) + from .decoders import Flux2DecodeStep + from .denoise import ( + Flux2DenoiseLoopWrapper, + Flux2DenoiseStep, + Flux2LoopAfterDenoiser, + Flux2LoopDenoiser, + ) + from .encoders import ( + Flux2ProcessImagesInputStep, + Flux2RemoteTextEncoderStep, + Flux2TextEncoderStep, + Flux2VaeEncoderStep, + ) + from .inputs import ( + Flux2ImageInputStep, + Flux2TextInputStep, + ) + from .modular_blocks import ( + ALL_BLOCKS, + AUTO_BLOCKS, + IMAGE_CONDITIONED_BLOCKS, + TEXT2IMAGE_BLOCKS, + Flux2AutoBeforeDenoiseStep, + Flux2AutoBlocks, + Flux2AutoDecodeStep, + Flux2AutoDenoiseStep, + Flux2AutoImageInputStep, + Flux2AutoTextEncoderStep, + Flux2AutoTextInputStep, + Flux2AutoVaeEncoderStep, + Flux2BeforeDenoiseStep, + Flux2VaeEncoderSequentialStep, + ) + from .modular_pipeline import Flux2ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py new file mode 100644 index 000000000000..b4aed76945a1 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py @@ -0,0 +1,505 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + """Compute empirical mu for Flux2 timestep scheduling.""" + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Flux2SetTimestepsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("guidance_scale", default=4.0), + InputParam("latents", type_hint=torch.Tensor), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + scheduler = components.scheduler + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + vae_scale_factor = components.vae_scale_factor + + latent_height = 2 * (int(height) // (vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + num_inference_steps = block_state.num_inference_steps + sigmas = block_state.sigmas + timesteps = block_state.timesteps + + if timesteps is None and sigmas is None: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + block_state.device, + timesteps=timesteps, + sigmas=sigmas, + mu=mu, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + batch_size = block_state.batch_size * block_state.num_images_per_prompt + guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + block_state.guidance = guidance + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"), + ] + + @staticmethod + def check_inputs(components, block_state): + vae_scale_factor = components.vae_scale_factor + if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def _prepare_latent_ids(latents: torch.Tensor): + """ + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents: Latent tensor of shape (B, C, H, W) + + Returns: + Position IDs tensor of shape (B, H*W, 4) + """ + batch_size, _, height, width = latents.shape + + t = torch.arange(1) + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) + + latent_ids = torch.cartesian_prod(t, h, w, l) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + latents = self.prepare_latents( + components, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(block_state.device) + + latents = self._pack_latents(latents) + + block_state.latents = latents + block_state.latent_ids = latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + InputParam(name="latent_ids"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + OutputParam( + name="latent_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareImageLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares image latents and their position IDs for Flux2 image conditioning." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image_latents", type_hint=List[torch.Tensor]), + InputParam("batch_size", required=True, type_hint=int), + InputParam("num_images_per_prompt", default=1, type_hint=int), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning", + ), + OutputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents", + ), + ] + + @staticmethod + def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10): + """ + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + Args: + image_latents: A list of image latent feature tensors of shape (1, C, H, W). + scale: Factor used to define the time separation between latents. + + Returns: + Combined coordinate tensor of shape (1, N_total, 4) + """ + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image_latents = block_state.image_latents + + if image_latents is None: + block_state.image_latents = None + block_state.image_latent_ids = None + else: + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + image_latent_ids = self._prepare_image_ids(image_latents) + + packed_latents = [] + for latent in image_latents: + packed = self._pack_latents(latent) + packed = packed.squeeze(0) + packed_latents.append(packed) + + image_latents = torch.cat(packed_latents, dim=0) + image_latents = image_latents.unsqueeze(0) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + block_state.image_latents = image_latents + block_state.image_latent_ids = image_latent_ids + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/decoders.py b/src/diffusers/modular_pipelines/flux2/decoders.py new file mode 100644 index 000000000000..b769d9119891 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/decoders.py @@ -0,0 +1,146 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLFlux2 +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2DecodeStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="Position IDs for the latents, used for unpacking", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor: + """ + Unpack latents using position IDs to scatter tokens into place. + + Args: + x: Packed latents tensor of shape (B, seq_len, C) + x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates + + Returns: + Unpacked latents tensor of shape (B, C, H, W) + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + @staticmethod + def _unpatchify_latents(latents): + """Convert patchified latents back to regular format.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if block_state.output_type == "latent": + block_state.images = block_state.latents + else: + latents = block_state.latents + latent_ids = block_state.latent_ids + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + + latents = self._unpatchify_latents(latents) + + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py new file mode 100644 index 000000000000..c12eca65c6a9 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -0,0 +1,252 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple + +import torch + +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2LoopDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "guidance", + required=True, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +class Flux2LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that updates the latents after denoising. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> List[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoises the latents over `timesteps`. " + "The specific steps within each iteration can be customized with `sub_blocks` attribute" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def loop_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process.", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2LoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py new file mode 100644 index 000000000000..861b875d77c0 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -0,0 +1,420 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLFlux2 +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def format_text_input(prompts: List[str], system_message: str = None): + """Format prompts for Mistral3 chat template.""" + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2TextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + # fmt: off + DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Mistral3ForConditionalGeneration), + ComponentSpec("tokenizer", AutoProcessor), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False), + InputParam("joint_attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from Mistral3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + prompt_embeds = getattr(block_state, "prompt_embeds", None) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " + "Please make sure to only forward one of the two." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_mistral_3_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: Tuple[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + if block_state.prompt_embeds is not None: + self.set_block_state(state, block_state) + return components, state + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_mistral_3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=block_state.device, + max_sequence_length=block_state.max_sequence_length, + system_message=self.DEFAULT_SYSTEM_MESSAGE, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RemoteTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using a remote API endpoint" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from remote API used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + prompt_embeds = getattr(block_state, "prompt_embeds", None) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " + "Please make sure to only forward one of the two." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + import io + + import requests + from huggingface_hub import get_token + + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + if block_state.prompt_embeds is not None: + self.set_block_state(state, block_state) + return components, state + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + response = requests.post( + self.REMOTE_URL, + json={"prompt": prompt}, + headers={ + "Authorization": f"Bearer {get_token()}", + "Content-Type": "application/json", + }, + ) + response.raise_for_status() + + block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True) + block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Image preprocess step for Flux2. Validates and preprocesses reference images." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + images = block_state.image + + if images is None: + block_state.condition_images = None + else: + if not isinstance(images, list): + images = [images] + + condition_images = [] + for img in images: + components.image_processor.check_image_input(img) + + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = components.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = components.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + condition_img = components.image_processor.preprocess( + img, height=image_height, width=image_width, resize_mode="crop" + ) + condition_images.append(condition_img) + + if block_state.height is None: + block_state.height = image_height + if block_state.width is None: + block_state.width = image_width + + block_state.condition_images = condition_images + + self.set_block_state(state, block_state) + return components, state + + +class Flux2VaeEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLFlux2)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("condition_images", type_hint=List[torch.Tensor]), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=List[torch.Tensor], + description="List of latent representations for each reference image", + ), + ] + + @staticmethod + def _patchify_latents(latents): + """Convert latents to patchified format for Flux2.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator): + """Encode a single image using Flux2 VAE with batch norm normalization.""" + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps) + latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + condition_images = block_state.condition_images + + if condition_images is None: + block_state.image_latents = None + else: + device = components._execution_device + dtype = components.vae.dtype + + image_latents = [] + for image in condition_images: + image = image.to(device=device, dtype=dtype) + latent = self._encode_vae_image( + vae=components.vae, + image=image, + generator=block_state.generator, + ) + image_latents.append(latent) + + block_state.image_latents = image_latents + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py new file mode 100644 index 000000000000..fccaefe147a0 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -0,0 +1,140 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch + +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) + + +class Flux2TextInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "Text input processing step that standardizes text embeddings for Flux2 pipeline.\n" + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2ImageInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "Image input processing step that prepares image latents for Flux2 conditioning.\n" + "This step expands image latents to match the batch size." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam("batch_size", required=True, type_hint=int), + InputParam("image_latents", type_hint=torch.Tensor), + InputParam("image_latent_ids", type_hint=torch.Tensor), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents expanded to batch size", + ), + OutputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Image latent position IDs expanded to batch size", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + image_latents = block_state.image_latents + image_latent_ids = block_state.image_latent_ids + target_batch_size = block_state.batch_size * block_state.num_images_per_prompt + + if image_latents is not None: + block_state.image_latents = image_latents.repeat(target_batch_size, 1, 1) + + if image_latent_ids is not None: + block_state.image_latent_ids = image_latent_ids.repeat(target_batch_size, 1, 1) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks.py b/src/diffusers/modular_pipelines/flux2/modular_blocks.py new file mode 100644 index 000000000000..e8814fba31fc --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks.py @@ -0,0 +1,237 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep +from .denoise import Flux2DenoiseStep +from .encoders import ( + Flux2ProcessImagesInputStep, + Flux2RemoteTextEncoderStep, + Flux2TextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ImageInputStep, + Flux2TextInputStep, +) + + +class Flux2AutoTextInputStep(AutoPipelineBlocks): + block_classes = [Flux2TextInputStep] + block_names = ["text_input"] + block_trigger_inputs = [None] + + @property + def description(self): + return ( + "Text input step that processes text embeddings and determines batch size.\n" + " - `Flux2TextInputStep` is always used." + ) + + +class Flux2AutoImageInputStep(AutoPipelineBlocks): + block_classes = [Flux2ImageInputStep] + block_names = ["image_input"] + block_trigger_inputs = ["image_latents"] + + @property + def description(self): + return ( + "Image input step that expands image latents to match batch size.\n" + " - `Flux2ImageInputStep` is used when `image_latents` is provided.\n" + " - Skipped when no image conditioning is used." + ) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +Flux2VaeEncoderBlocks = InsertableDict( + [ + ("preprocess", Flux2ProcessImagesInputStep()), + ("encode", Flux2VaeEncoderStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ] +) + + +class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2VaeEncoderBlocks.values() + block_names = Flux2VaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning." + + +class Flux2AutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [Flux2VaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +class Flux2AutoTextEncoderStep(AutoPipelineBlocks): + block_classes = [Flux2RemoteTextEncoderStep, Flux2TextEncoderStep] + block_names = ["remote", "local"] + block_trigger_inputs = ["remote_text_encoder", None] + + @property + def description(self): + return ( + "Text encoder step that generates text embeddings to guide the image generation.\n" + "This is an auto pipeline block that selects between local and remote text encoding.\n" + " - `Flux2RemoteTextEncoderStep` is used when `remote_text_encoder=True`.\n" + " - `Flux2TextEncoderStep` is used otherwise (default)." + ) + + +Flux2BeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ] +) + + +class Flux2BeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2BeforeDenoiseBlocks.values() + block_names = Flux2BeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation." + + +class Flux2AutoBeforeDenoiseStep(AutoPipelineBlocks): + model_name = "flux2" + block_classes = [Flux2BeforeDenoiseStep] + block_names = ["before_denoise"] + block_trigger_inputs = [None] + + @property + def description(self): + return ( + "Before denoise step that prepares the inputs for the denoise step.\n" + "This is an auto pipeline block for Flux2.\n" + " - `Flux2BeforeDenoiseStep` is used for both text-to-image and image-conditioned generation." + ) + + +class Flux2AutoDenoiseStep(AutoPipelineBlocks): + block_classes = [Flux2DenoiseStep] + block_names = ["denoise"] + block_trigger_inputs = [None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents. " + "This is an auto pipeline block that works for Flux2 text-to-image and image-conditioned tasks." + " - `Flux2DenoiseStep` (denoise) for text-to-image and image-conditioned tasks." + ) + + +class Flux2AutoDecodeStep(AutoPipelineBlocks): + block_classes = [Flux2DecodeStep] + block_names = ["decode"] + block_trigger_inputs = [None] + + @property + def description(self): + return "Decode step that decodes the denoised latents into image outputs.\n - `Flux2DecodeStep`" + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2AutoTextEncoderStep()), + ("text_input", Flux2AutoTextInputStep()), + ("image_encoder", Flux2AutoVaeEncoderStep()), + ("image_input", Flux2AutoImageInputStep()), + ("before_denoise", Flux2AutoBeforeDenoiseStep()), + ("denoise", Flux2AutoDenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + + +class Flux2AutoBlocks(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n" + "- For text-to-image generation, all you need to provide is `prompt`.\n" + "- For image-conditioned generation, you need to provide `image` (list of PIL images)." + ) + + +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + +IMAGE_CONDITIONED_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("preprocess_images", Flux2ProcessImagesInputStep()), + ("vae_encoder", Flux2VaeEncoderStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("image_input", Flux2ImageInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + +ALL_BLOCKS = { + "text2image": TEXT2IMAGE_BLOCKS, + "image_conditioned": IMAGE_CONDITIONED_BLOCKS, + "auto": AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py new file mode 100644 index 000000000000..3e497f3b1e98 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -0,0 +1,57 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import Flux2LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index a6336de71a52..947e137500b4 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -58,6 +58,7 @@ ("wan", "WanModularPipeline"), ("flux", "FluxModularPipeline"), ("flux-kontext", "FluxKontextModularPipeline"), + ("flux2", "Flux2ModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), @@ -1585,7 +1586,6 @@ def __init__( for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default self.register_to_config(**default_configs) - self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None) @property diff --git a/tests/modular_pipelines/flux2/__init__.py b/tests/modular_pipelines/flux2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py new file mode 100644 index 000000000000..4a44dcab634a --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import tempfile +import unittest + +import numpy as np +import PIL +import torch + +from diffusers.modular_pipelines import ( + Flux2AutoBlocks, + Flux2ModularPipeline, + ModularPipeline, +) +from diffusers.modular_pipelines.flux2 import ( + Flux2AutoTextEncoderStep, + Flux2RemoteTextEncoderStep, + Flux2TextEncoderStep, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2ModularPipeline + pipeline_blocks_class = Flux2AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2ModularPipeline + pipeline_blocks_class = Flux2AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_save_from_pretrained(self): + pipes = [] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + with tempfile.TemporaryDirectory() as tmpdirname: + base_pipe.save_pretrained(tmpdirname) + + pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + + pipes.append(pipe) + + image_slices = [] + for pipe in pipes: + inputs = self.get_dummy_inputs() + image = pipe(**inputs, output="images") + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2AutoTextEncoderStep(unittest.TestCase): + def test_auto_text_encoder_block_classes(self): + auto_step = Flux2AutoTextEncoderStep() + + assert len(auto_step.block_classes) == 2 + assert Flux2RemoteTextEncoderStep in auto_step.block_classes + assert Flux2TextEncoderStep in auto_step.block_classes + + def test_auto_text_encoder_trigger_inputs(self): + auto_step = Flux2AutoTextEncoderStep() + + assert auto_step.block_trigger_inputs == ["remote_text_encoder", None] + + def test_auto_text_encoder_block_names(self): + auto_step = Flux2AutoTextEncoderStep() + + assert auto_step.block_names == ["remote", "local"]