From 13b042fa208e632a7277c7462fa10f6440baa612 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Sep 2025 18:06:49 +0530 Subject: [PATCH 1/2] up --- .../modular_pipelines/flux/__init__.py | 12 +- .../modular_pipelines/flux/before_denoise.py | 220 +++++++++++++++++- .../modular_pipelines/flux/denoise.py | 116 +++++++++ .../modular_pipelines/flux/modular_blocks.py | 89 ++++++- .../flux/modular_pipeline.py | 15 ++ 5 files changed, 441 insertions(+), 11 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/__init__.py b/src/diffusers/modular_pipelines/flux/__init__.py index 2891edf79041..30affef56a79 100644 --- a/src/diffusers/modular_pipelines/flux/__init__.py +++ b/src/diffusers/modular_pipelines/flux/__init__.py @@ -24,15 +24,19 @@ _import_structure["encoders"] = ["FluxTextEncoderStep"] _import_structure["modular_blocks"] = [ "ALL_BLOCKS", + "ALL_BLOCKS_KONTEXT", "AUTO_BLOCKS", + "AUTO_BLOCKS_KONTEXT", "TEXT2IMAGE_BLOCKS", "FluxAutoBeforeDenoiseStep", "FluxAutoBlocks", - "FluxAutoBlocks", "FluxAutoDecodeStep", "FluxAutoDenoiseStep", + "FluxKontextAutoBeforeDenoiseStep", + "FluxKontextAutoBlocks", + "FluxKontextAutoDenoiseStep", ] - _import_structure["modular_pipeline"] = ["FluxModularPipeline"] + _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -44,12 +48,16 @@ from .encoders import FluxTextEncoderStep from .modular_blocks import ( ALL_BLOCKS, + ALL_BLOCKS_KONTEXT, AUTO_BLOCKS, + AUTO_BLOCKS_KONTEXT, TEXT2IMAGE_BLOCKS, FluxAutoBeforeDenoiseStep, FluxAutoBlocks, FluxAutoDecodeStep, FluxAutoDenoiseStep, + FluxKontextAutoBeforeDenoiseStep, + FluxKontextAutoDenoiseStep, ) from .modular_pipeline import FluxModularPipeline else: diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index 507acce1ebf6..afc9f3f90382 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -18,6 +18,8 @@ import numpy as np import torch +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging @@ -182,15 +184,15 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) -# Cannot use "# Copied from" because it introduces weird indentation errors. 
-def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator): +def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator, sample_mode: str = "sample"): if isinstance(generator, list): image_latents = [ - retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(vae.encode(image), generator=generator) + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor @@ -687,3 +689,213 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip self.set_block_state(state, block_state) return components, state + + +class FluxKontextPrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux_kontext" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the image-to-image generation process with Flux Kontext" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("max_area", type_hint=int, default=1024**2), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam( + "image_latents", type_hint=torch.Tensor, description="Latents computed from the input image(s)." + ), + OutputParam( + "latent_ids", + type_hint=torch.Tensor, + description="IDs computed from the latent sequence needed for RoPE", + ), + OutputParam( + "image_ids", + type_hint=torch.Tensor, + description="IDs computed from the image sequence needed for RoPE", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." 
+ ) + + @staticmethod + def preprocess_image( + image, image_processor: VaeImageProcessor, vae_scale_factor: int, latent_channels: int, _auto_resize=True + ): + from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS + + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): + multiple_of = vae_scale_factor * 2 + img = image[0] if isinstance(image, list) else image + image_height, image_width = image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = image_processor.resize(image, image_height, image_width) + image = image_processor.preprocess(image, image_height, image_width) + return image + + @staticmethod + def prepare_latents( + comp, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over + # the packing methods here. So, for example, `comp._pack_latents()` won't work if we were + # to go with the "# Copied from ..." approach. Or maybe there's a way? + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + + image_latents = image_ids = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != num_channels_latents: + image_latents = _encode_vae_image(image=image, generator=generator, sample_mode="argmax") + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[2:] + image_latents = _pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + image_ids = _prepare_latent_image_ids( + batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + + latent_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latent_ids, image_ids + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this? + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + + # Adjust height and width if needed. + max_area = block_state.max_area + original_height, original_width = block_state.height, block_state.width + aspect_ratio = original_width / original_height + width = round((max_area * aspect_ratio) ** 0.5) + height = round((max_area / aspect_ratio) ** 0.5) + + multiple_of = components.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if height != original_height or width != original_width: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." + ) + block_state.height = height + block_state.width = width + + # Process input image(s). + # `_auto_resize` is currently forced to True. Since it's private anyway, I thought of not adding it. 
+        image = block_state.image
+        block_state.image = self.preprocess_image(
+            image=image,
+            image_processor=components.image_processor,
+            vae_scale_factor=components.vae_scale_factor,
+            latent_channels=components.num_channels_latents,
+        )
+
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        block_state.latents, block_state.image_latents, block_state.latent_ids, block_state.image_ids = (
+            self.prepare_latents(
+                components,
+                batch_size,
+                block_state.num_channels_latents,
+                block_state.height,
+                block_state.width,
+                block_state.dtype,
+                block_state.device,
+                block_state.generator,
+                block_state.latents,
+            )
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py
index ffb436abd450..b3fe6810bc89 100644
--- a/src/diffusers/modular_pipelines/flux/denoise.py
+++ b/src/diffusers/modular_pipelines/flux/denoise.py
@@ -110,6 +110,106 @@ def __call__(
     return components, block_state
 
 
+class FluxKontextLoopDenoiser(ModularPipelineBlocks):
+    model_name = "flux_kontext"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoises the latents for Flux Kontext. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `FluxDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("joint_attention_kwargs"),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+            ),
+            InputParam(
+                "image_latents",
+                type_hint=torch.Tensor,
+                description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
+            ),
+            InputParam(
+                "guidance",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Guidance scale as a tensor",
+            ),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Prompt embeddings",
+            ),
+            InputParam(
+                "pooled_prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Pooled prompt embeddings",
+            ),
+            InputParam(
+                "text_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from text sequence needed for RoPE",
+            ),
+            InputParam(
+                "latent_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from latent sequence needed for RoPE",
+            ),
+            InputParam(
+                "image_ids",
+                type_hint=torch.Tensor,
+                description="IDs computed from image sequence needed for RoPE",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(
+        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+    ) -> PipelineState:
+        latent_ids = block_state.latent_ids
+        image_ids = block_state.image_ids
+        if image_ids is not None:
+            latent_ids = torch.cat([latent_ids, image_ids], dim=0)  # dim 0 is sequence dimension
+
+        latents = block_state.latents
+        latent_model_input = latents
+        image_latents = block_state.image_latents
+        if image_latents is not None:
+            latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+        noise_pred = components.transformer(
+            hidden_states=latent_model_input,
+            timestep=timestep / 1000,
+            guidance=block_state.guidance,
+            encoder_hidden_states=block_state.prompt_embeds,
+            pooled_projections=block_state.pooled_prompt_embeds,
+            joint_attention_kwargs=block_state.joint_attention_kwargs,
+            txt_ids=block_state.text_ids,
+            img_ids=latent_ids,
+            return_dict=False,
+        )[0]
+        # Only the first `latents.size(1)` tokens are the prediction for the generated image;
+        # the trailing image-conditioning tokens are dropped.
+        noise_pred = noise_pred[:, : latents.size(1)]
+        block_state.noise_pred = noise_pred
+
+        return components, block_state
+
+
 class FluxLoopAfterDenoiser(ModularPipelineBlocks):
     model_name = "flux"
 
@@ -225,3 +325,19 @@ def description(self) -> str:
         " - `FluxLoopAfterDenoiser`\n"
         "This block supports both text2image and img2img tasks."
     )
+
+
+class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
+    block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents. \n"
+            "Its loop logic is defined in the `FluxDenoiseLoopWrapper.__call__` method. \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `FluxKontextLoopDenoiser`\n"
+            " - `FluxLoopAfterDenoiser`\n"
+            "This block supports both text2image and img2img tasks."
+ ) diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py index 04b439f026a4..a90e7b6e9eb4 100644 --- a/src/diffusers/modular_pipelines/flux/modular_blocks.py +++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py @@ -19,11 +19,12 @@ FluxImg2ImgPrepareLatentsStep, FluxImg2ImgSetTimestepsStep, FluxInputStep, + FluxKontextPrepareLatentsStep, FluxPrepareLatentsStep, FluxSetTimestepsStep, ) from .decoders import FluxDecodeStep -from .denoise import FluxDenoiseStep +from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep @@ -46,7 +47,7 @@ def description(self): ) -# before_denoise: text2img, img2img +# before_denoise: text2img class FluxBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [ FluxInputStep, @@ -66,6 +67,26 @@ def description(self): ) +# before_denoise: text2img, img2img (for Kontext) +class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + FluxInputStep, + FluxKontextPrepareLatentsStep, + FluxSetTimestepsStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step in Flux Kontext.\n" + + "This is a sequential pipeline blocks:\n" + + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n" + + " - `FluxKontextPrepareLatentsStep` is used to prepare the latents\n" + + " - `FluxSetTimestepsStep` is used to set the timesteps\n" + ) + + # before_denoise: img2img class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep] @@ -98,7 +119,23 @@ def description(self): ) -# denoise: text2image +# flux kontext (both text2img and img2img) +class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [FluxKontextBeforeDenoiseStep] + block_names = ["text2image", "img2img"] + block_trigger_inputs = [None, "image_latents"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2image.\n" + + " - `FluxKontextBeforeDenoiseStep` (text2img) is used when only `image_latents` is None.\n" + + " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + ) + + +# denoise: text2image, img2img class FluxAutoDenoiseStep(AutoPipelineBlocks): block_classes = [FluxDenoiseStep] block_names = ["denoise"] @@ -113,7 +150,21 @@ def description(self) -> str: ) -# decode: all task (text2img, img2img, inpainting) +class FluxKontextAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [FluxKontextDenoiseStep] + block_names = ["denoise"] + block_trigger_inputs = [None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents for Flux Kontext. " + "This is a auto pipeline block that works for text2image and img2img tasks." + " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks." 
+ ) + + +# decode: all task (text2img, img2img) class FluxAutoDecodeStep(AutoPipelineBlocks): block_classes = [FluxDecodeStep] block_names = ["non-inpaint"] @@ -124,7 +175,7 @@ def description(self): return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`" -# text2image +# text2image, img2img class FluxAutoBlocks(SequentialPipelineBlocks): block_classes = [ FluxTextEncoderStep, @@ -144,6 +195,25 @@ def description(self): ) +# text2image, img2img +class FluxKontextAutoBlocks(SequentialPipelineBlocks): + block_classes = [ + FluxTextEncoderStep, + FluxKontextAutoBeforeDenoiseStep, + FluxKontextAutoDenoiseStep, + FluxAutoDecodeStep, + ] + block_names = ["text_encoder", "before_denoise", "denoise", "decoder"] + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image and image-to-image using Flux Kontext.\n" + + "- for text-to-image generation, all you need to provide is `prompt`\n" + + "- for image-to-image generation, you need to provide either `image` or `image_latents`" + ) + + TEXT2IMAGE_BLOCKS = InsertableDict( [ ("text_encoder", FluxTextEncoderStep), @@ -176,6 +246,15 @@ def description(self): ("decode", FluxAutoDecodeStep), ] ) +AUTO_BLOCKS_KONTEXT = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep), + ("before_denoise", FluxKontextAutoBeforeDenoiseStep), + ("denoise", FluxKontextAutoDenoiseStep), + ("decode", FluxAutoDecodeStep), + ] +) ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS} +ALL_BLOCKS_KONTEXT = {"auto": AUTO_BLOCKS_KONTEXT} diff --git a/src/diffusers/modular_pipelines/flux/modular_pipeline.py b/src/diffusers/modular_pipelines/flux/modular_pipeline.py index e97445d411e4..f319d9432ba0 100644 --- a/src/diffusers/modular_pipelines/flux/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux/modular_pipeline.py @@ -57,3 +57,18 @@ def num_channels_latents(self): if getattr(self, "transformer", None): num_channels_latents = self.transformer.config.in_channels // 4 return num_channels_latents + + +class FluxKontextModularPipeline(FluxModularPipeline): + """ + A ModularPipeline for Flux Kontext. + + + + This is an experimental feature and is likely to change in the future. 
+ + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) From e2e2297316d1af36ff39ecf00bc4d1b7d399868d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Sep 2025 19:34:01 +0530 Subject: [PATCH 2/2] up --- src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/flux/__init__.py | 2 +- .../modular_pipelines/flux/before_denoise.py | 186 +++++++++++++----- .../modular_pipelines/flux/denoise.py | 4 +- .../modular_pipelines/flux/modular_blocks.py | 14 +- .../modular_pipelines/modular_pipeline.py | 1 + 6 files changed, 157 insertions(+), 54 deletions(-) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 68d707f9e047..97c0869dd8cd 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -46,7 +46,7 @@ ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] - _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"] + _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline", "FluxKontextModularPipeline"] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -57,7 +57,7 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: from .components_manager import ComponentsManager - from .flux import FluxAutoBlocks, FluxModularPipeline + from .flux import FluxAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux/__init__.py b/src/diffusers/modular_pipelines/flux/__init__.py index 30affef56a79..ca10a633bc0d 100644 --- a/src/diffusers/modular_pipelines/flux/__init__.py +++ b/src/diffusers/modular_pipelines/flux/__init__.py @@ -59,7 +59,7 @@ FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep, ) - from .modular_pipeline import FluxModularPipeline + from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index afc9f3f90382..228160b479b6 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -17,6 +17,7 @@ import numpy as np import torch +from PIL import Image from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor @@ -26,7 +27,7 @@ from ...utils.torch_utils import randn_tensor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import FluxModularPipeline +from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -319,6 +320,141 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip return components, state +class FluxKontextInputStep(ModularPipelineBlocks): + model_name = "flux_kontext" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. 
Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt.\n" + " 3. Processes the input `image`." + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "image", + required=False, + type_hint=Union[Image.Image, torch.Tensor], + description="Input image/image latents to perform denoising.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "image", + type_hint=torch.Tensor, + description="Processed image/image latents.", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None: + if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]: + raise ValueError( + "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`" + f" {block_state.pooled_prompt_embeds.shape}." 
+ ) + + @staticmethod + def preprocess_image( + image, image_processor: VaeImageProcessor, vae_scale_factor: int, latent_channels: int, _auto_resize=True + ) -> torch.Tensor: + from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS + + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): + multiple_of = vae_scale_factor * 2 + img = image[0] if isinstance(image, list) else image + image_height, image_width = image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = image_processor.resize(image, image_height, image_width) + image = image_processor.preprocess(image, image_height, image_width) + return image + + @torch.no_grad() + def __call__(self, components: FluxKontextModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + # TODO: `_auto_resize` is currently forced to True. Since it's private anyway, I thought of not adding it. + block_state.image = self.preprocess_image( + image=block_state.image, + image_processor=components.image_processor, + vae_scale_factor=components.vae_scale_factor, + latent_channels=components.num_channels_latents, + ) + self.set_block_state(state, block_state) + + return components, state + + class FluxSetTimestepsStep(ModularPipelineBlocks): model_name = "flux" @@ -696,25 +832,18 @@ class FluxKontextPrepareLatentsStep(ModularPipelineBlocks): @property def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 16}), - default_creation_method="from_config", - ), - ] + return [ComponentSpec("vae", AutoencoderKL)] @property def description(self) -> str: - return "Prepare latents step that prepares the latents for the image-to-image generation process with Flux Kontext" + return "Step that prepares the latents for the image-to-image generation process with Flux Kontext." @property def inputs(self) -> List[InputParam]: return [ InputParam("height", type_hint=int), InputParam("width", type_hint=int), + InputParam("image", type_hint=Union[Image.Image, torch.Tensor], required=False), InputParam("max_area", type_hint=int, default=1024**2), InputParam("latents", type_hint=Optional[torch.Tensor]), InputParam("num_images_per_prompt", type_hint=int, default=1), @@ -758,28 +887,6 @@ def check_inputs(components, block_state): f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." 
) - @staticmethod - def preprocess_image( - image, image_processor: VaeImageProcessor, vae_scale_factor: int, latent_channels: int, _auto_resize=True - ): - from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS - - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): - multiple_of = vae_scale_factor * 2 - img = image[0] if isinstance(image, list) else image - image_height, image_width = image_processor.get_default_height_width(img) - aspect_ratio = image_width / image_height - if _auto_resize: - # Kontext is trained on specific resolutions, using one of them is recommended - _, image_width, image_height = min( - (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS - ) - image_width = image_width // multiple_of * multiple_of - image_height = image_height // multiple_of * multiple_of - image = image_processor.resize(image, image_height, image_width) - image = image_processor.preprocess(image, image_height, image_width) - return image - @staticmethod def prepare_latents( comp, @@ -807,7 +914,7 @@ def prepare_latents( if image is not None: image = image.to(device=device, dtype=dtype) if image.shape[1] != num_channels_latents: - image_latents = _encode_vae_image(image=image, generator=generator, sample_mode="argmax") + image_latents = _encode_vae_image(vae=comp.vae, image=image, generator=generator, sample_mode="argmax") else: image_latents = image if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: @@ -871,20 +978,11 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state.height = height block_state.width = width - # Process input image(s). - # `_auto_resize` is currently forced to True. Since it's private anyway, I thought of not adding it. 
- image = block_state.image - block_state.image = self.preprocess_image( - image=image, - image_processor=components.image_processor, - vae_scale_factor=components.vae_scale_factor, - latent_channels=components.num_channels_latents, - ) - batch_size = block_state.batch_size * block_state.num_images_per_prompt block_state.latents, block_state.image_latents, block_state.latent_ids, block_state.image_ids = ( self.prepare_latents( components, + block_state.image, batch_size, block_state.num_channels_latents, block_state.height, diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index b3fe6810bc89..4eac9ab89575 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -26,7 +26,7 @@ PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import FluxModularPipeline +from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -179,7 +179,7 @@ def inputs(self) -> List[Tuple[str, Any]]: @torch.no_grad() def __call__( - self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + self, components: FluxKontextModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: latent_ids = block_state.latent_ids image_ids = block_state.image_ids diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py index a90e7b6e9eb4..97f19d5bf518 100644 --- a/src/diffusers/modular_pipelines/flux/modular_blocks.py +++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py @@ -19,6 +19,7 @@ FluxImg2ImgPrepareLatentsStep, FluxImg2ImgSetTimestepsStep, FluxInputStep, + FluxKontextInputStep, FluxKontextPrepareLatentsStep, FluxPrepareLatentsStep, FluxSetTimestepsStep, @@ -121,7 +122,8 @@ def description(self): # flux kontext (both text2img and img2img) class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [FluxKontextBeforeDenoiseStep] + # Kontext should follow `FluxBeforeDenoiseStep` when T2I mode is on. 
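+    # `AutoPipelineBlocks` dispatches the first block whose trigger input is present; a
+    # `None` trigger marks the default branch, taken here when `image_latents` is absent.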
+    block_classes = [FluxBeforeDenoiseStep, FluxKontextBeforeDenoiseStep]
     block_names = ["text2image", "img2img"]
     block_trigger_inputs = [None, "image_latents"]
 
@@ -130,7 +132,7 @@ def description(self):
         return (
             "Before denoise step that prepares the inputs for the denoise step.\n"
             + "This is an auto pipeline block that works for text2image and img2img tasks.\n"
-            + " - `FluxKontextBeforeDenoiseStep` (text2img) is used when `image_latents` is not provided.\n"
+            + " - `FluxBeforeDenoiseStep` (text2img) is used when `image_latents` is not provided.\n"
             + " - `FluxKontextBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
         )
 
@@ -249,9 +251,11 @@
 AUTO_BLOCKS_KONTEXT = InsertableDict(
     [
         ("text_encoder", FluxTextEncoderStep),
-        ("before_denoise", FluxKontextAutoBeforeDenoiseStep),
-        ("denoise", FluxKontextAutoDenoiseStep),
-        ("decode", FluxAutoDecodeStep),
+        ("input", FluxKontextInputStep),
+        ("prepare_latents", FluxKontextPrepareLatentsStep),
+        ("set_timesteps", FluxSetTimestepsStep),
+        ("denoise", FluxKontextDenoiseStep),
+        ("decode", FluxDecodeStep),
     ]
 )
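A rough usage sketch of the blocks added above. Assumptions: the generic modular-pipeline entry points (`init_pipeline`, `load_default_components`, and the `output="images"` call convention) behave here as they do for the existing Flux modular blocks, and the repo id below is a placeholder for a modular Kontext checkpoint.

    import torch
    from diffusers.modular_pipelines.flux import FluxKontextAutoBlocks

    blocks = FluxKontextAutoBlocks()
    pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-Kontext-dev")  # placeholder repo id
    pipe.load_default_components(torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Text-to-image: per the block descriptions, only `prompt` is required.
    image = pipe(prompt="a tiny robot reading a book", output="images")[0]

    # Image editing: per the block descriptions, providing `image` (or precomputed
    # `image_latents`) routes through the Kontext img2img path.
    edited = pipe(prompt="give the robot a red hat", image=image, output="images")[0]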