
Conversation

@sayakpaul (Member) commented Aug 25, 2025

What does this PR do?

Building FA3 from source is a time-consuming process; we could use the kernels lib to fetch a pre-built kernel from the Hub instead.

This PR attempts a rough first implementation that uses kernels to provide the FA3 backend when the user requests it. It's an extremely early implementation, so apologies in advance.

We could let users specify something like set_attention_backend("kernels-community/vllm-flash-attn3", interface="flash_attn_func") when they don't have the FA3 build available locally, but that's open for discussion.

Additionally, this helps keep diffusers Hub-first, as kernels provides a great way to leverage the platform, IMO.

Minimal code to test:

import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("_flash_3_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]

image.save("output.png")
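
For context, under the hood the kernels library resolves the Hub repo to a pre-built binary matching the local torch build. A minimal sketch of that flow, assuming the kernels-community/vllm-flash-attn3 repo mentioned above (the tensor shapes below are purely illustrative):

import torch
from kernels import get_kernel

# Downloads the pre-built FA3 kernel from the Hub on first use (cached afterwards);
# the build variant matching the local torch/CUDA setup is selected automatically.
fa3 = get_kernel("kernels-community/vllm-flash-attn3")

# Illustrative call; (batch, seq_len, num_heads, head_dim) shapes are assumptions.
q = k = v = torch.randn(1, 128, 24, 128, dtype=torch.bfloat16, device="cuda")
out, lse, *_ = fa3.flash_attn_func(q, k, v)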

Some comments are inline.

@sayakpaul requested review from DN6 and a-r-r-o-w August 25, 2025 16:59
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w (Contributor) left a comment


Let's not do it this way, i.e., changing the behavior when FA3 is not installed. Instead, let's add a new backend _flash_3_hf (or something similar) so that the user has to explicitly set it to download the kernel. As an end user, running any remote code should require some form of explicit consent imo, and this approach is better than defaulting to downloading from the Hub.

@sayakpaul changed the title from "[wip][core] use kernels for FA3 when the build is not locally available" to "[core] use kernels for FA3 when the build is not locally available" Aug 26, 2025
@sayakpaul changed the title from "[core] use kernels for FA3 when the build is not locally available" to "[core] use kernels to support _flash_3_hub attention backend" Aug 26, 2025
@sayakpaul marked this pull request as ready for review August 26, 2025 13:01
@sayakpaul requested a review from a-r-r-o-w August 26, 2025 13:01
@a-r-r-o-w (Contributor) left a comment


thanks, looking much better!

@sayakpaul (Member, Author)

Getting the following error when trying to compile:

File "/fsx/sayak/diffusers/check_fa3_backend.py", line 18, in <module>
    image = pipe(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 919, in __call__
    noise_pred = self.transformer(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor
  Explanation: torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output


  Developer debug context: example_value type: bool; op: call_function; target: <function compiled_with_cxx11_abi at 0x7f67768c0160>


from user code:
   File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
    encoder_hidden_states, hidden_states = block(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
    attention_outputs = self.attn(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
    hidden_states = dispatch_attention_fn(
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 293, in dispatch_attention_fn
    return backend_fn(**kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 717, in _flash_attention_3_hub
    out, lse, *_ = flash_attn_3_hub_func(
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 226, in flash_attn_3_hub_func
    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 193, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 217, in _load_fa3_hub
    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
  File "/fsx/sayak/diffusers/src/diffusers/utils/kernels_utils.py", line 18, in _get_fa3_from_hub
    flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 234, in get_kernel
    package_name, package_path = install_kernel(repo_id, revision=revision)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 117, in install_kernel
    variant = build_variant()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 64, in build_variant
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Judging from the traceback, the failure comes from kernels' build_variant() calling torch.compiled_with_cxx11_abi() inside the compiled region, which returns a plain bool that Dynamo cannot trace into the FX graph. Maybe we can investigate this in a future PR.

Code:

import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("_flash_3_hub")
pipe.transformer.compile(fullgraph=True)

prompt = "A cat holding a sign that says 'hello world'"

with torch._dynamo.config.patch(error_on_recompile=True):
    image = pipe(
        prompt, num_inference_steps=28, guidance_scale=4.0, generator=torch.manual_seed(0)
    ).images[0]
    image.save("output.png")

Cc: @anijain2305

@sayakpaul requested a review from a-r-r-o-w August 26, 2025 15:07
@sayakpaul (Member, Author)

The recompilation issues are gone thanks to a recent fix from @danieldk. I will button up this PR and let you know once it’s ready for another review. @a-r-r-o-w

@sayakpaul (Member, Author)

@DN6 I have followed the env var approach to let users opt in to loading pre-built FA3 kernels from the Hub. Thanks to @danieldk for pushing the changes for the op registrations -- torch.compile with full-graph tracing now works perfectly fine.

Could you please review? @a-r-r-o-w too.
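
For reviewers, end-to-end usage with this PR then looks roughly as follows. This is a minimal sketch: the exact truthy value accepted for the env var is an assumption, and it has to be set before diffusers is imported so the constant picks it up:

import os

# Opt in to downloading the pre-built FA3 kernel from the Hub (assumed truthy value;
# set before importing diffusers so the flag is read at import time).
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("_flash_3_hub")
pipe.transformer.compile(fullgraph=True)

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")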

@@ -67,6 +67,17 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
@sayakpaul (Member, Author):

Made it a constant in constants.py as I think it will be shared across modules.
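
A minimal sketch of what such a constant could look like in constants.py; the set of accepted truthy strings and the default value are assumptions, not necessarily the code in this PR:

import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed truthy strings

# Opt-in flag for fetching pre-built kernels (e.g. FA3) from the Hub.
DIFFUSERS_ENABLE_HUB_KERNELS = (
    os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "NO").upper() in ENV_VARS_TRUE_VALUES
)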

deterministic: bool = False,
return_attn_probs: bool = False,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
@sayakpaul (Member, Author):

Follow (internal) this link

@a-r-r-o-w (Contributor) left a comment


Thanks, looks good to me, just one comment

@DN6 (Collaborator) left a comment


Thanks 👍🏽

@sayakpaul merged commit 130fd8d into main Sep 3, 2025
14 checks passed
@sayakpaul deleted the fa3-from-kernels branch September 3, 2025 03:18