[core] use kernels to support _flash_3_hub attention backend #12236
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Let's not do it this way, i.e. changing the behavior when FA3 is not installed. Instead, let's add a new backend _flash_3_hf (or something similar) so that the user has to explicitly set it to download the kernel. As an end user, running any remote code should require some form of explicit consent IMO, and this approach is better than defaulting to downloading from the Hub.
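For illustration, a minimal sketch of what that explicit opt-in could look like from the user side, assuming the backend ends up registered under the name in the PR title (`_flash_3_hub`) and that `set_attention_backend` accepts it:

```python
# Illustrative sketch, not the PR's code. Assumes the backend is registered as
# "_flash_3_hub" and that set_attention_backend accepts that name.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Nothing is fetched from the Hub kernel repo unless the user names this backend:
transformer.set_attention_backend("_flash_3_hub")
```

The point being: a user who never selects the Hub-backed backend never triggers a kernel download, which matches the explicit-consent concern above.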
thanks, looking much better!
Getting the following error when trying to compile:
File "/fsx/sayak/diffusers/check_fa3_backend.py", line 18, in <module>
image = pipe(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/fsx/sayak/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 919, in __call__
noise_pred = self.transformer(
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in compile_wrapper
raise e.with_traceback(None) from e.__cause__ # User compiler error
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor
Explanation: torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output
Developer debug context: example_value type: bool; op: call_function; target: <function compiled_with_cxx11_abi at 0x7f67768c0160>
from user code:
File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
encoder_hidden_states, hidden_states = block(
File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
attention_outputs = self.attn(
File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
hidden_states = dispatch_attention_fn(
File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 293, in dispatch_attention_fn
return backend_fn(**kwargs)
File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 717, in _flash_attention_3_hub
out, lse, *_ = flash_attn_3_hub_func(
File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 226, in flash_attn_3_hub_func
return _load_fa3_hub().flash_attn_func(*args, **kwargs)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 193, in getattr_and_trace
return fn(*args[2:], **kwargs)
File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 217, in _load_fa3_hub
fa3_hub = _get_fa3_from_hub() # won't re-download if already present
File "/fsx/sayak/diffusers/src/diffusers/utils/kernels_utils.py", line 18, in _get_fa3_from_hub
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 234, in get_kernel
package_name, package_path = install_kernel(repo_id, revision=revision)
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 117, in install_kernel
variant = build_variant()
File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 64, in build_variant
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Maybe we can investigate this in a future PR. Code:
Cc: @anijain2305
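For context on the traceback above: Dynamo fails because the kernel-loading path runs inside the compiled forward, and build_variant calls torch.compiled_with_cxx11_abi(), which returns a plain bool that cannot be placed in the FX graph. One possible workaround (a sketch under assumptions, not the fix that eventually landed) is to resolve the Hub kernel eagerly, outside any compiled region, so only the attention call itself is traced:

```python
# Sketch, assuming the kernels API visible in the traceback: get_kernel() plus a
# kernel module that exposes flash_attn_func.
from kernels import get_kernel

# Resolved once at import time; won't re-download if the kernel is already present.
_fa3_hub = get_kernel("kernels-community/vllm-flash-attn3")

def flash_attn_3_hub_func(*args, **kwargs):
    # Only this call is visible to torch.compile; the loader above is never traced.
    return _fa3_hub.flash_attn_func(*args, **kwargs)
```

As the next comment notes, the actual resolution came from a fix on the kernels side rather than this kind of restructuring.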
The recompilation issues are gone thanks to a recent fix from @danieldk. I will button up this PR and let you know once it’s ready for another review. @a-r-r-o-w
@DN6 I have followed the env var approach to let users load pre-built FA3 kernels from the Hub. Thanks to @danieldk for pushing the changes for op registrations. Could you please review? @a-r-r-o-w too.
@@ -67,6 +67,17 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
Made it a constant in constants.py as I think it will be shared across modules.
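A rough sketch of how such a constant is typically read at import time; the variable name matches the diff above, but the accepted truthy values here are an assumption:

```python
# Sketch: read the opt-in flag once in constants.py so other modules can import it.
import os

DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ("1", "TRUE", "YES")
```

A user would then opt in with something like DIFFUSERS_ENABLE_HUB_KERNELS=yes set before importing diffusers.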
    deterministic: bool = False,
    return_attn_probs: bool = False,
) -> torch.Tensor:
    out = flash_attn_3_func_hub(
Follow this (internal) link.
Thanks, looks good to me, just one comment
Thanks 👍🏽
What does this PR do?
Building FA3 from source is a time-consuming process, and we could use the kernels lib for this. This PR attempts a dirty implementation of using kernels to set the FA3 backend when requested by the user. It's an extremely early implementation, so apologies in advance.

We could let users specify something like set_attention_backend("kernels-community/vllm-flash-attn3", interface="flash_attn_func") when they don't have the FA3 build locally available. But that's a matter for our discussion.

Additionally, this also helps keep diffusers Hub-first, as kernels provides a great way to leverage the platform, IMO.

Minimal code to test:
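The original test snippet is not included in this extract; below is a hypothetical reconstruction of how the backend might be exercised, assuming the DIFFUSERS_ENABLE_HUB_KERNELS gate and the _flash_3_hub backend name discussed above (model id, prompt, and settings are placeholders):

```python
# Hypothetical sketch, not the author's snippet.
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"  # opt in before importing diffusers

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer.set_attention_backend("_flash_3_hub")  # fetches the FA3 kernel on first use

image = pipe("a photo of a cat holding a sign that says hello", num_inference_steps=28).images[0]
image.save("fa3_hub.png")
```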
Some comments are inline.