diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e123f4c19398..f71be7c8ecc0 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -26,6 +26,7 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_kernels_available,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -35,7 +36,7 @@
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
 
 
 _REQUIRED_FLASH_VERSION = "2.6.3"
@@ -67,6 +68,17 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
+if DIFFUSERS_ENABLE_HUB_KERNELS:
+    if not is_kernels_available():
+        raise ImportError(
+            "To use the FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install it with `pip install kernels`."
+        )
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+else:
+    flash_attn_3_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -153,6 +165,8 @@ class AttentionBackendName(str, Enum):
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
+    _FLASH_3_HUB = "_flash_3_hub"
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
 
     # PyTorch native
     FLEX = "flex"
@@ -351,6 +365,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
 
+    # TODO: add support for the Hub variant of FA3 varlen later.
+    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+            )
+        if not is_kernels_available():
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
@@ -657,6 +682,44 @@ def _flash_attention_3(
     return (out, lse) if return_attn_probs else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out = flash_attn_3_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+        return_attn_probs=return_attn_probs,
+    )
+    # When `return_attn_probs` is True, the above returns a tuple of
+    # actual outputs and lse.
+    return (out[0], out[1]) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
 )
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index d9867fb8758d..8b4d76f3cbe3 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -46,6 +46,7 @@
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
new file mode 100644
index 000000000000..26d6e3972fb7
--- /dev/null
+++ b/src/diffusers/utils/kernels_utils.py
@@ -0,0 +1,23 @@
+from ..utils import get_logger
+from .import_utils import is_kernels_available
+
+
+logger = get_logger(__name__)
+
+
+_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+
+
+def _get_fa3_from_hub():
+    if not is_kernels_available():
+        return None
+    else:
+        from kernels import get_kernel
+
+        try:
+            # TODO: temporary revision for now. Remove when merged upstream into `main`.
+            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
+            return flash_attn_3_hub
+        except Exception as e:
+            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            raise
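For anyone trying this branch out, here is a minimal usage sketch (not part of the diff). It assumes the `kernels` package is installed, `DIFFUSERS_ENABLE_HUB_KERNELS` is exported before `diffusers` is imported, FA3-capable hardware (e.g. Hopper), and uses the existing `set_attention_backend` helper on diffusers models plus a Flux checkpoint purely as an example; adjust for your pipeline.

```python
# Sketch, not part of this PR. Assumed setup:
#   pip install kernels
#   export DIFFUSERS_ENABLE_HUB_KERNELS=yes   # must be set before importing diffusers
import torch
from diffusers import FluxPipeline  # example pipeline; any attention-dispatch model should work

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# "_flash_3_hub" is the backend name registered in this diff; `set_attention_backend`
# is the existing helper for switching attention implementations on a diffusers model.
pipe.transformer.set_attention_backend("_flash_3_hub")

image = pipe("a photo of a dog", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("flash_attn3_hub.png")
```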