
Conversation

dimitribarbot (Contributor)

What does this PR do?

Add ControlNet (InstantX/Qwen-Image-ControlNet-Union) support for Qwen-Image-Edit.

This pipeline allows two latent images to be used as inputs: one for Qwen-Image-Edit and another for Qwen-Image-ControlNet-Union, giving finer control over the result.

Inference

import torch
from diffusers import QwenImageControlNetModel, QwenImageEditControlNetPipeline
from diffusers.utils import load_image

base_model = "Qwen/Qwen-Image-Edit"
controlnet_model = "InstantX/Qwen-Image-ControlNet-Union"

controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)

pipe = QwenImageEditControlNetPipeline.from_pretrained(
    base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/living_room.png"
).convert("RGB")
control_image = load_image(
    "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/depth.png"
)
prompt = (
    "Anime style of a swanky, minimalist living room with a huge floor-to-ceiling window letting in loads of natural light. "
    "A beige couch with white and beige cushions sits on a wooden floor, with a matching coffee table in front. "
    "The walls are a soft, warm beige, decorated with two framed botanical prints. A potted plant chills in the corner near the window. "
    "Sunlight pours through the leaves outside, casting cool shadows on the floor."
)
output_image = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=" ",
    control_image=control_image,
    controlnet_conditioning_scale=1.5,
    width=control_image.size[0],
    height=control_image.size[1],
    num_inference_steps=30,
    true_cfg_scale=2.5,
).images[0]
output_image.save("qwenimage_edit_controlnet.png")

N.B.1. If this PR and image location are accepted, I will upload the living_room.png file to the documentation-images repository.
N.B.2. To achieve the desired result, set controlnet_conditioning_scale to a value greater than 1. A good starting point is 1.5.
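
For instance, a quick way to compare conditioning strengths is a sweep over controlnet_conditioning_scale (a minimal sketch reusing the pipe, image, control_image, and prompt objects from the snippet above; the scale values are just starting points):

for scale in (1.0, 1.5, 2.0):
    result = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=" ",
        control_image=control_image,
        controlnet_conditioning_scale=scale,
        width=control_image.size[0],
        height=control_image.size[1],
        num_inference_steps=30,
        true_cfg_scale=2.5,
    ).images[0]
    result.save(f"qwenimage_edit_controlnet_scale_{scale}.png")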

Examples

Depth

Input image:

living_room

Control image:

depth

prompt = (
    "Anime style of a swanky, minimalist living room with a huge floor-to-ceiling window letting in loads of natural light. "
    "A beige couch with white and beige cushions sits on a wooden floor, with a matching coffee table in front. "
    "The walls are a soft, warm beige, decorated with two framed botanical prints. A potted plant chills in the corner near the window. "
    "Sunlight pours through the leaves outside, casting cool shadows on the floor."
)

Result:

living_room_edited

Pose

Input image:

(input image)

Control image:

(pose control image)

prompt = (
    "Make this man sit on a concrete ledge in front of a large circular window, with a cityscape reflected in the glass. "
    "The wall is cream-colored, and the sky is clear blue. His shadow is cast on the wall."
)

Result:

pose_with_controlnet

For comparison, here is the result without ControlNet:

pose_without_controlnet

N.B. All examples were created using the Nunchaku version of the transformer.


Who can review?

@yiyixuxu
@asomoza

HuggingFaceDocBuilderDev (bot):

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

yiyixuxu (Collaborator) left a comment:

thanks! i left some comments
didn't know the instant x controlnet is compatible with qwen edit :)

Comment on lines +442 to +453
def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
    """
    depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
    deprecate(
        "enable_vae_slicing",
        "0.40.0",
        depr_message,
    )
    self.vae.enable_slicing()

Suggested change: remove these lines.

Comment on lines +455 to +466
def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
    computing decoding in one step.
    """
    depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
    deprecate(
        "disable_vae_slicing",
        "0.40.0",
        depr_message,
    )
    self.vae.disable_slicing()

Suggested change: remove these lines.

Comment on lines +468 to +480
def enable_vae_tiling(self):
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.
    """
    depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
    deprecate(
        "enable_vae_tiling",
        "0.40.0",
        depr_message,
    )
    self.vae.enable_tiling()

Suggested change: remove these lines.

Comment on lines +482 to +493
def disable_vae_tiling(self):
    r"""
    Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
    computing decoding in one step.
    """
    depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
    deprecate(
        "disable_vae_tiling",
        "0.40.0",
        depr_message,
    )
    self.vae.disable_tiling()

Suggested change: remove these lines.

Comment on lines +442 to +452
def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
    """
    depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
    deprecate(
        "enable_vae_slicing",
        "0.40.0",
        depr_message,
    )

Suggested change: remove these lines.
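
For reference, the replacement that the deprecation messages point to is calling the VAE helpers directly (a minimal sketch, assuming pipe is the QwenImageEditControlNetPipeline loaded earlier):

# Call the VAE methods directly instead of the deprecated pipeline wrappers.
pipe.vae.enable_slicing()   # split batched decoding into per-image slices to save memory
pipe.vae.enable_tiling()    # decode/encode in tiles so larger images fit in memory

# And the corresponding calls to turn them off again:
pipe.vae.disable_slicing()
pipe.vae.disable_tiling()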


    return prompt_embeds, encoder_attention_mask

def encode_prompt(

missing a #Copied from ?
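
For context, the convention referred to is diffusers' # Copied from annotation, which the make fix-copies tooling uses to keep duplicated methods in sync; the source path and signature below are illustrative, not necessarily the exact ones for this pipeline:

# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
def encode_prompt(self, prompt, device=None):
    ...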

    raise AttributeError("Could not access latents of provided encoder_output")


def calculate_dimensions(target_area, ratio):

Copied from?


    return latents

def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):

copied from?

@@ -639,7 +639,9 @@ def forward(
     if controlnet_block_samples is not None:
         interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
         interval_control = int(np.ceil(interval_control))
-        hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+        sample = controlnet_block_samples[index_block // interval_control]
can we adjust inputs in pipeline instead?
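
One possible pipeline-side alternative (a hedged sketch, not the PR's actual code: it assumes the ControlNet residuals only cover the leading image-latent portion of the packed sequence, so zero-padding them to the full sequence length reproduces the partial update without modifying the transformer's forward):

import torch.nn.functional as F

def pad_controlnet_block_samples(controlnet_block_samples, target_seq_len):
    # Zero-pad each (batch, seq, dim) residual along the sequence dimension.
    # The padded tail adds nothing to hidden_states, so the unmodified
    # "hidden_states = hidden_states + sample" in forward() behaves like a
    # partial update over the image-latent tokens only.
    return [
        F.pad(sample, (0, 0, 0, target_seq_len - sample.shape[1]))
        for sample in controlnet_block_samples
    ]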

dimitribarbot (Contributor, Author)

Thank you for your code review.

In fact, this PR is mostly a merge of the existing pipeline_qwenimage_edit.py and pipeline_qwenimage_controlnet.py pipelines. The only things I added were:

  • the sample code at the beginning of the pipeline,
  • the documentation for the new arguments specific to controlnet,
  • a partial update of the hidden_states in the forward of the Qwen-Image transformer.

Your feedback is all valid, and I can address it in this PR. However, it also applies to the other Qwen-Image pipelines. Given that there is already a complete refactoring PR #12322, which would you prefer:

  • address your comments only in the new pipeline added here,
  • address them in all the affected pipelines, or
  • leave them for now and fold them into the refactoring PR?
