From b58db6934c1c4bc1be1021c4c88a6b94ceea740c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 27 Aug 2025 14:18:18 -0700 Subject: [PATCH 01/36] Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options --- comfy/ldm/modules/attention.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 043df28dfdc8..0ab5cf16df55 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -7,6 +7,7 @@ from einops import rearrange, repeat from typing import Optional import logging +import functools from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -91,6 +92,17 @@ def forward(self, x): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) +def wrap_attn(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + transformer_options = kwargs.pop("transformer_options", None) + if transformer_options is not None: + if "optimized_attention_override" in transformer_options: + return transformer_options["optimized_attention_override"](*args, **kwargs) + return func(*args, **kwargs) + return wrapper + +@wrap_attn def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): attn_precision = get_attn_precision(attn_precision, q.dtype) @@ -159,7 +171,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out - +@wrap_attn def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): attn_precision = get_attn_precision(attn_precision, query.dtype) @@ -230,6 +242,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states +@wrap_attn def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): attn_precision = get_attn_precision(attn_precision, q.dtype) @@ -359,6 +372,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape except: pass +@wrap_attn def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): b = q.shape[0] dim_head = q.shape[-1] @@ -427,7 +441,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh #TODO: other GPUs ? 
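To make the new decorator's behavior concrete, here is a minimal standalone sketch of the dispatch it performs; `toy_attention` and `my_override` are illustrative names, not ComfyUI code. If `transformer_options` carries an `optimized_attention_override` callable, that callable replaces the decorated kernel; otherwise the wrapped function runs unchanged.

```python
# Minimal sketch of the dispatch pattern introduced by @wrap_attn in this patch.
# toy_attention and my_override are illustrative stand-ins, not ComfyUI functions.
import functools
import torch
import torch.nn.functional as F

def wrap_attn(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        transformer_options = kwargs.pop("transformer_options", None)
        if transformer_options is not None:
            if "optimized_attention_override" in transformer_options:
                return transformer_options["optimized_attention_override"](*args, **kwargs)
        return func(*args, **kwargs)
    return wrapper

@wrap_attn
def toy_attention(q, k, v, heads, mask=None):
    # stand-in for attention_basic / attention_pytorch / etc.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def my_override(q, k, v, heads, mask=None):
    print("override used instead of the default kernel")
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = k = v = torch.randn(1, 4, 16, 8)  # (batch, heads, tokens, dim_head)
toy_attention(q, k, v, heads=4)  # default path
toy_attention(q, k, v, heads=4,
              transformer_options={"optimized_attention_override": my_override})  # override path
```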
SDP_BATCH_LIMIT = 2**31 - +@wrap_attn def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape @@ -470,7 +484,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out - +@wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape @@ -534,7 +548,7 @@ def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" - +@wrap_attn def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape @@ -629,7 +643,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, value=None, mask=None): + def forward(self, x, context=None, value=None, mask=None, transformer_options={}): q = self.to_q(x) context = default(context, x) k = self.to_k(context) @@ -640,9 +654,9 @@ def forward(self, x, context=None, value=None, mask=None): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) From 68b00e9c6004c5121e31806546ef2b33761c8122 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 27 Aug 2025 17:13:33 -0700 Subject: [PATCH 02/36] Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added --- comfy/ldm/modules/attention.py | 85 ++++++++++++++++++++++++++++++++-- comfy/samplers.py | 6 +++ 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 0ab5cf16df55..6a8ffe10b378 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,5 +1,8 @@ import math import sys +import json +import os +from datetime import datetime import torch import torch.nn.functional as F @@ -92,13 +95,89 @@ def forward(self, x): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) +import inspect +LOG_ATTN_CALLS = False +LOG_CONTENTS = {} + +def save_log_contents(): + import folder_paths + output_dir = folder_paths.get_output_directory() + + # Create attn_logs directory if it doesn't exist + attn_logs_dir = os.path.join(output_dir, "attn_logs") + os.makedirs(attn_logs_dir, exist_ok=True) + + # Generate timestamp filename (down to second) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{timestamp}.json" + filepath = os.path.join(attn_logs_dir, filename) + + # Save LOG_CONTENTS as JSON file + try: + with open(filepath, 'w', 
encoding='utf-8') as f: + json.dump(list(LOG_CONTENTS.values()), f, indent=2, ensure_ascii=False) + logging.info(f"Saved attention log contents to {filepath}") + except Exception as e: + logging.error(f"Failed to save attention log contents: {e}") + +def get_class_from_frame(frame): + # Check for 'self' (instance method) or 'cls' (classmethod) + if 'self' in frame.f_locals: + return frame.f_locals['self'].__class__.__name__ + elif 'cls' in frame.f_locals: + return frame.f_locals['cls'].__name__ + return None + +def has_transformer_options_passed(frame): + if 'transformer_options' in frame.f_locals.keys(): + if frame.f_locals['transformer_options']: + return True + return False + def wrap_attn(func): @functools.wraps(func) def wrapper(*args, **kwargs): + if LOG_ATTN_CALLS: + continue_to_add = True + to_add = 1000 + logged_stack = [] + logged_stack_to_index = -1 + for frame_info in inspect.stack()[1:]: + if not continue_to_add: + break + if to_add == 0: + break + if frame_info.function == "_calc_cond_batch_outer": + break + if 'venv' in frame_info.filename: + continue + elif 'ComfyUI' not in frame_info.filename: + continue + elif 'execution.py' in frame_info.filename: + continue + elif 'patcher_extension.py' in frame_info.filename: + continue + to_add -= 1 + cls_name = get_class_from_frame(frame_info.frame) + log_string = f"{frame_info.filename}:{frame_info.lineno}" + if cls_name: + log_string += f":{cls_name}.{frame_info.function}" + else: + log_string += f":{frame_info.function}" + if has_transformer_options_passed(frame_info.frame): + log_string += ":✅" + if logged_stack_to_index == -1: + logged_stack_to_index = len(logged_stack) + else: + log_string += ":❌" + logged_stack.append(log_string) + # logging.info(f"Attn call stack: {logged_stack}") + # logging.info(f"Logged stack to index: {logged_stack[:logged_stack_to_index+1]}") + LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack) transformer_options = kwargs.pop("transformer_options", None) if transformer_options is not None: if "optimized_attention_override" in transformer_options: - return transformer_options["optimized_attention_override"](*args, **kwargs) + return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs) return func(*args, **kwargs) return wrapper @@ -760,7 +839,7 @@ def forward(self, x, context=None, transformer_options={}): n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) n = self.attn1.to_out(n) else: - n = self.attn1(n, context=context_attn1, value=value_attn1) + n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options) if "attn1_output_patch" in transformer_patches: patch = transformer_patches["attn1_output_patch"] @@ -800,7 +879,7 @@ def forward(self, x, context=None, transformer_options={}): n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) n = self.attn2.to_out(n) else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] diff --git a/comfy/samplers.py b/comfy/samplers.py index c7dfef4ea8a6..acf86c03d654 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1019,6 +1019,7 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) 
preprocess_conds_hooks(self.conds) + import comfy.ldm.modules.attention #TODO: Remove this $$$$$ try: orig_model_options = self.model_options self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options) @@ -1033,12 +1034,17 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) + comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ + comfy.ldm.modules.attention.LOG_CONTENTS = {} output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() + comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$ + comfy.ldm.modules.attention.save_log_contents() + comfy.ldm.modules.attention.LOG_CONTENTS = {} del self.conds return output From 29b7990dc274ee0221f1be28f1e7fa30528e808f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 27 Aug 2025 17:55:35 -0700 Subject: [PATCH 03/36] Fix memory usage issue with inspect --- comfy/ldm/modules/attention.py | 84 +++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 6a8ffe10b378..86cb3d384abb 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -142,38 +142,58 @@ def wrapper(*args, **kwargs): to_add = 1000 logged_stack = [] logged_stack_to_index = -1 - for frame_info in inspect.stack()[1:]: - if not continue_to_add: - break - if to_add == 0: - break - if frame_info.function == "_calc_cond_batch_outer": - break - if 'venv' in frame_info.filename: - continue - elif 'ComfyUI' not in frame_info.filename: - continue - elif 'execution.py' in frame_info.filename: - continue - elif 'patcher_extension.py' in frame_info.filename: - continue - to_add -= 1 - cls_name = get_class_from_frame(frame_info.frame) - log_string = f"{frame_info.filename}:{frame_info.lineno}" - if cls_name: - log_string += f":{cls_name}.{frame_info.function}" - else: - log_string += f":{frame_info.function}" - if has_transformer_options_passed(frame_info.frame): - log_string += ":✅" - if logged_stack_to_index == -1: - logged_stack_to_index = len(logged_stack) - else: - log_string += ":❌" - logged_stack.append(log_string) - # logging.info(f"Attn call stack: {logged_stack}") - # logging.info(f"Logged stack to index: {logged_stack[:logged_stack_to_index+1]}") - LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack) + + frame = inspect.currentframe() + try: + # skip wrapper, start at actual wrapped function + frame = frame.f_back + + while frame and continue_to_add and to_add > 0: + code = frame.f_code + filename = code.co_filename + function = code.co_name + lineno = frame.f_lineno + + if function == "_calc_cond_batch_outer": + break + if 'venv' in filename: + frame = frame.f_back + continue + elif 'ComfyUI' not in filename: + frame = frame.f_back + continue + elif 'execution.py' in filename: + frame = frame.f_back + continue + elif 'patcher_extension.py' in filename: + frame = frame.f_back + continue + + to_add -= 1 + cls_name = get_class_from_frame(frame) + log_string = f"{filename}:{lineno}" + if cls_name: + log_string += 
f":{cls_name}.{function}" + else: + log_string += f":{function}" + + if has_transformer_options_passed(frame): + log_string += ":✅" + if logged_stack_to_index == -1: + logged_stack_to_index = len(logged_stack) + else: + log_string += ":❌" + + logged_stack.append(log_string) + + # move up the stack + frame = frame.f_back + + LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack) + + finally: + # Important: break ref cycles so tensors aren't pinned + del frame transformer_options = kwargs.pop("transformer_options", None) if transformer_options is not None: if "optimized_attention_override" in transformer_options: From dd21b4aa51a346396f442970728b4c6067900b03 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 27 Aug 2025 17:56:21 -0700 Subject: [PATCH 04/36] Made WAN attention receive transformer_options, test node added to wan to test out attention override later --- comfy/ldm/wan/model.py | 14 ++++++++------ comfy_extras/nodes_wan.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index dedfb47e279d..7627da6430c2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -52,7 +52,7 @@ def __init__(self, self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, freqs): + def forward(self, x, freqs, transformer_options={}): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -75,6 +75,7 @@ def qkv_fn(x): k.view(b, s, n * d), v, heads=self.num_heads, + transformer_options=transformer_options, ) x = self.o(x) @@ -83,7 +84,7 @@ def qkv_fn(x): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context, **kwargs): + def forward(self, x, context, transformer_options={}, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -95,7 +96,7 @@ def forward(self, x, context, **kwargs): v = self.v(context) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) x = self.o(x) return x @@ -206,6 +207,7 @@ def forward( freqs, context, context_img_len=257, + transformer_options={}, ): r""" Args: @@ -224,12 +226,12 @@ def forward( # self-attention y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), - freqs) + freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -564,7 +566,7 @@ def block_wrap(args): out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, 
transformer_options=transformer_options) # head x = self.head(x, e) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 312260f00225..4fedbba7d440 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1003,6 +1003,42 @@ def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return io.NodeOutput(out_latent) +import comfy.patcher_extension +import comfy.ldm.modules.attention +class AttentionOverrideTest(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="AttentionOverrideTest", + category="devtools", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @staticmethod + def attention_override(func, transformer_options, *args, **kwargs): + new_attention = comfy.ldm.modules.attention.attention_basic + return new_attention.__wrapped__(*args, **kwargs) + + @staticmethod + def sampler_sampler_wrapper(executor, *args, **kwargs): + try: + # extra_args = args[2] + return executor(*args, **kwargs) + finally: + pass + + @classmethod + def execute(cls, model: io.Model.Type) -> io.NodeOutput: + model = model.clone() + + model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper) + return io.NodeOutput(model) class WanExtension(ComfyExtension): @override @@ -1020,6 +1056,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: WanPhantomSubjectToVideo, WanSoundImageToVideo, Wan22ImageToVideoLatent, + AttentionOverrideTest, ] async def comfy_entrypoint() -> WanExtension: From 669b9ef8e699b016a5dcf2666631bb8c8f1d8f1c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 13:14:41 -0700 Subject: [PATCH 05/36] Added **kwargs to all attention functions so transformer_options could potentially be passed through --- comfy/ldm/modules/attention.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 86cb3d384abb..b3bb71734a01 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -194,7 +194,7 @@ def wrapper(*args, **kwargs): finally: # Important: break ref cycles so tensors aren't pinned del frame - transformer_options = kwargs.pop("transformer_options", None) + transformer_options = kwargs.get("transformer_options", None) if transformer_options is not None: if "optimized_attention_override" in transformer_options: return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs) @@ -202,7 +202,7 @@ def wrapper(*args, **kwargs): return wrapper @wrap_attn -def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -271,7 +271,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape return out @wrap_attn -def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, 
skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, query.dtype) if skip_reshape: @@ -342,7 +342,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, return hidden_states @wrap_attn -def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -472,7 +472,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape pass @wrap_attn -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -487,7 +487,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh disabled_xformers = True if disabled_xformers: - return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) + return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs) if skip_reshape: # b h k d -> b k h d @@ -541,7 +541,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh SDP_BATCH_LIMIT = 2**31 @wrap_attn -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -584,7 +584,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha return out @wrap_attn -def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -614,7 +614,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.transpose(1, 2), (q, k, v), ) - return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs) if tensor_layout == "HND": if not skip_output_reshape: @@ -648,7 +648,7 @@ def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" @wrap_attn -def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: From 51a30c2ad7e2588485dc2f543e0a650b19cb3d0d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 18:53:20 -0700 Subject: [PATCH 06/36] Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention --- comfy/ldm/modules/attention.py | 176 
++++++++++++++++++++------------- 1 file changed, 110 insertions(+), 66 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b3bb71734a01..f6013672b6a2 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional +from typing import Optional, Any, Callable, Union import logging import functools @@ -21,23 +21,58 @@ import xformers import xformers.ops -if model_management.sage_attention_enabled(): - try: - from sageattention import sageattn - except ModuleNotFoundError as e: +SAGE_ATTENTION_IS_AVAILABLE = False +try: + from sageattention import sageattn + SAGE_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError as e: + if model_management.sage_attention_enabled(): if e.name == "sageattention": logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") else: raise e exit(-1) -if model_management.flash_attention_enabled(): - try: - from flash_attn import flash_attn_func - except ModuleNotFoundError: +FLASH_ATTENTION_IS_AVAILABLE = False +try: + from flash_attn import flash_attn_func + FLASH_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError: + if model_management.flash_attention_enabled(): logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +REGISTERED_ATTENTION_FUNCTIONS = {} +def register_attention_function(name: str, func: Callable): + # avoid replacing existing functions + if name not in REGISTERED_ATTENTION_FUNCTIONS: + REGISTERED_ATTENTION_FUNCTIONS[name] = func + else: + logging.warning(f"Attention function {name} already registered, skipping registration.") + +def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]: + if name not in REGISTERED_ATTENTION_FUNCTIONS: + if default is ...: + raise KeyError(f"Attention function {name} not found.") + else: + return default + return REGISTERED_ATTENTION_FUNCTIONS[name] + +def _register_core_attention_functions(): + """ + Register attention functions exposed by core ComfyUI. 
+ """ + # NOTE: attention_basic is purposely not registered, as it is not used in code + if SAGE_ATTENTION_IS_AVAILABLE: + register_attention_function("sage", attention_sage) + if FLASH_ATTENTION_IS_AVAILABLE: + register_attention_function("flash", attention_flash) + if model_management.xformers_enabled(): + register_attention_function("xformers", attention_xformers) + register_attention_function("pytorch", attention_pytorch) + register_attention_function("sub_quad", attention_sub_quad) + register_attention_function("split", attention_split) + from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -137,68 +172,76 @@ def has_transformer_options_passed(frame): def wrap_attn(func): @functools.wraps(func) def wrapper(*args, **kwargs): - if LOG_ATTN_CALLS: - continue_to_add = True - to_add = 1000 - logged_stack = [] - logged_stack_to_index = -1 - - frame = inspect.currentframe() - try: - # skip wrapper, start at actual wrapped function - frame = frame.f_back - - while frame and continue_to_add and to_add > 0: - code = frame.f_code - filename = code.co_filename - function = code.co_name - lineno = frame.f_lineno - - if function == "_calc_cond_batch_outer": - break - if 'venv' in filename: - frame = frame.f_back - continue - elif 'ComfyUI' not in filename: - frame = frame.f_back - continue - elif 'execution.py' in filename: - frame = frame.f_back - continue - elif 'patcher_extension.py' in filename: - frame = frame.f_back - continue - - to_add -= 1 - cls_name = get_class_from_frame(frame) - log_string = f"{filename}:{lineno}" - if cls_name: - log_string += f":{cls_name}.{function}" - else: - log_string += f":{function}" + remove_attn_wrapper_key = False + try: + if LOG_ATTN_CALLS: + continue_to_add = True + to_add = 1000 + logged_stack = [] + logged_stack_to_index = -1 + + frame = inspect.currentframe() + try: + # skip wrapper, start at actual wrapped function + frame = frame.f_back - if has_transformer_options_passed(frame): - log_string += ":✅" - if logged_stack_to_index == -1: - logged_stack_to_index = len(logged_stack) - else: - log_string += ":❌" + while frame and continue_to_add and to_add > 0: + code = frame.f_code + filename = code.co_filename + function = code.co_name + lineno = frame.f_lineno + + if function == "_calc_cond_batch_outer": + break + if 'venv' in filename: + frame = frame.f_back + continue + elif 'ComfyUI' not in filename: + frame = frame.f_back + continue + elif 'execution.py' in filename: + frame = frame.f_back + continue + elif 'patcher_extension.py' in filename: + frame = frame.f_back + continue + + to_add -= 1 + cls_name = get_class_from_frame(frame) + log_string = f"{filename}:{lineno}" + if cls_name: + log_string += f":{cls_name}.{function}" + else: + log_string += f":{function}" - logged_stack.append(log_string) + if has_transformer_options_passed(frame): + log_string += ":✅" + if logged_stack_to_index == -1: + logged_stack_to_index = len(logged_stack) + else: + log_string += ":❌" - # move up the stack - frame = frame.f_back + logged_stack.append(log_string) - LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack) + # move up the stack + frame = frame.f_back - finally: - # Important: break ref cycles so tensors aren't pinned - del frame - transformer_options = kwargs.get("transformer_options", None) - if transformer_options is not None: - if "optimized_attention_override" in transformer_options: - return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs) - return 
func(*args, **kwargs) + LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack) + + finally: + # Important: break ref cycles so tensors aren't pinned + del frame + if "_inside_attn_wrapper" not in kwargs: + transformer_options = kwargs.get("transformer_options", None) + remove_attn_wrapper_key = True + kwargs["_inside_attn_wrapper"] = True + if transformer_options is not None: + if "optimized_attention_override" in transformer_options: + return transformer_options["optimized_attention_override"](func, *args, **kwargs) + return func(*args, **kwargs) + finally: + if remove_attn_wrapper_key: + del kwargs["_inside_attn_wrapper"] return wrapper @wrap_attn @@ -707,6 +750,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape else: logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad +_register_core_attention_functions() optimized_attention_masked = optimized_attention From 1f499f0794d30dfbed7bf9de3e33bc268eee0249 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 18:54:22 -0700 Subject: [PATCH 07/36] Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only) --- comfy/samplers.py | 5 +++-- comfy_extras/nodes_wan.py | 21 +++++++++++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index acf86c03d654..81847dfa614e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1034,7 +1034,7 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ + # comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ comfy.ldm.modules.attention.LOG_CONTENTS = {} output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: @@ -1042,8 +1042,9 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() + if comfy.ldm.modules.attention.LOG_ATTN_CALLS: + comfy.ldm.modules.attention.save_log_contents() comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$ - comfy.ldm.modules.attention.save_log_contents() comfy.ldm.modules.attention.LOG_CONTENTS = {} del self.conds diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 4fedbba7d440..6f4f7f676876 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1005,14 +1005,18 @@ def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io import comfy.patcher_extension import comfy.ldm.modules.attention +import logging + class AttentionOverrideTest(io.ComfyNode): @classmethod def define_schema(cls): + attention_function_names = list(comfy.ldm.modules.attention.REGISTERED_ATTENTION_FUNCTIONS.keys()) return io.Schema( node_id="AttentionOverrideTest", category="devtools", inputs=[ io.Model.Input("model"), + io.Combo.Input("attention", options=attention_function_names), ], outputs=[ io.Model.Output(), @@ -1020,9 +1024,10 @@ def define_schema(cls): ) @staticmethod - def attention_override(func, 
transformer_options, *args, **kwargs): - new_attention = comfy.ldm.modules.attention.attention_basic - return new_attention.__wrapped__(*args, **kwargs) + def attention_override_factory(attention_func): + def attention_override(func, *args, **kwargs): + return attention_func(*args, **kwargs) + return attention_override @staticmethod def sampler_sampler_wrapper(executor, *args, **kwargs): @@ -1033,10 +1038,14 @@ def sampler_sampler_wrapper(executor, *args, **kwargs): pass @classmethod - def execute(cls, model: io.Model.Type) -> io.NodeOutput: - model = model.clone() + def execute(cls, model: io.Model.Type, attention: str) -> io.NodeOutput: + attention_func = comfy.ldm.modules.attention.get_attention_function(attention, None) + if attention_func is None: + logging.info(f"Attention type '{attention}' not found, using default optimized attention for your hardware.") + return model - model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override + model = model.clone() + model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override_factory(attention_func) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper) return io.NodeOutput(model) From a7d70e42a00069c9bb1d6a4e87b6c5e28264984c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 19:33:02 -0700 Subject: [PATCH 08/36] Make flux work with optimized_attention_override --- comfy/ldm/flux/layers.py | 10 +++++----- comfy/ldm/flux/math.py | 4 ++-- comfy/ldm/flux/model.py | 17 +++++++++++------ 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 113eb20962a2..ef21b416b250 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -159,7 +159,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -182,7 +182,7 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N attn = attention(torch.cat((img_q, txt_q), dim=2), torch.cat((img_k, txt_k), dim=2), torch.cat((img_v, txt_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: @@ -190,7 +190,7 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] @@ -244,7 +244,7 @@ def __init__( self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> 
Tensor: mod, _ = self.modulation(vec) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -252,7 +252,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 3e09781768a7..4d743cda264d 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -6,7 +6,7 @@ import comfy.model_management -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: q_shape = q.shape k_shape = k.shape @@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] - x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) + x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 0a77fa0970f5..4e56a1b858bf 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -135,14 +135,16 @@ def block_wrap(args): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -151,7 +153,8 @@ def block_wrap(args): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -172,17 +175,19 @@ def block_wrap(args): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") From 48ed71caf80bc924fd50e14d25ff17e3cd8ef553 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 19:43:39 -0700 Subject: [PATCH 09/36] Add logs to verify optimized_attention_override is passed all the way into attention function --- comfy/ldm/modules/attention.py | 9 ++++++++- comfy/samplers.py | 7 ++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 
f6013672b6a2..080139b77cba 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -226,7 +226,14 @@ def wrapper(*args, **kwargs): # move up the stack frame = frame.f_back - LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack) + # check if we get what we want from transformer_options + t_check = "❌❌❌" + transformer_options = kwargs.get("transformer_options", None) + if transformer_options is not None: + if "optimized_attention_override" in transformer_options: + t_check = "✅✅✅" + + LOG_CONTENTS["|".join(logged_stack)] = (t_check, logged_stack_to_index, logged_stack) finally: # Important: break ref cycles so tensors aren't pinned diff --git a/comfy/samplers.py b/comfy/samplers.py index 81847dfa614e..d76d686ee904 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1034,8 +1034,13 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - # comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ + comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ comfy.ldm.modules.attention.LOG_CONTENTS = {} + if "optimized_attention_override" not in self.model_options["transformer_options"]: + def optimized_attention_override(func, *args, **kwargs): + return func(*args, **kwargs) + self.model_options["transformer_options"]["optimized_attention_override"] = optimized_attention_override + output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) From f752715aac62f0b29c91b29340a08c5dc7577200 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 19:52:52 -0700 Subject: [PATCH 10/36] Make Qwen work with optimized_attention_override --- comfy/ldm/qwen_image/model.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 57a458210032..d5fe003c5cd0 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -132,6 +132,7 @@ def forward( encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: seq_txt = encoder_hidden_states.shape[1] @@ -159,7 +160,7 @@ def forward( joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -226,6 +227,7 @@ def forward( encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) txt_mod_params = self.txt_mod(temb) @@ -242,6 +244,7 @@ def forward( encoder_hidden_states=txt_modulated, encoder_hidden_states_mask=encoder_hidden_states_mask, image_rotary_emb=image_rotary_emb, + 
transformer_options=transformer_options, ) hidden_states = hidden_states + img_gate1 * img_attn_output @@ -434,9 +437,9 @@ def _forward( if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) hidden_states = out["img"] encoder_hidden_states = out["txt"] else: @@ -446,11 +449,12 @@ def block_wrap(args): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, ) if "double_block" in patches: for p in patches["double_block"]: - out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options}) hidden_states = out["img"] encoder_hidden_states = out["txt"] From 4cafd58f714fbaceb9f326a968db6b5749e768bd Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:10:50 -0700 Subject: [PATCH 11/36] Made hidream work with optimized_attention_override --- comfy/ldm/hidream/model.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index ae49cf94586b..28d81c79e55b 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -72,8 +72,8 @@ def forward(self, timesteps, wdtype): return t_emb -def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): - return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options) class HiDreamAttnProcessor_flashattn: @@ -86,6 +86,7 @@ def __call__( image_tokens_masks: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, *args, **kwargs, ) -> torch.FloatTensor: @@ -133,7 +134,7 @@ def __call__( query = torch.cat([query_1, query_2], dim=-1) key = torch.cat([key_1, key_2], dim=-1) - hidden_states = attention(query, key, value) + hidden_states = attention(query, key, value, transformer_options=transformer_options) if not attn.single: hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) 
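The Qwen and HiDream changes above, like the other backends touched later in this series, all follow the same plumbing pattern: each module between the model's top-level forward() and the attention call gains a `transformer_options={}` parameter and forwards it, so the `@wrap_attn` dispatch can see the override. A condensed, hypothetical illustration of that pattern (`ToyBlock` is not a ComfyUI class):

```python
# Hedged sketch of the recurring change: thread transformer_options from a block's
# forward() down into optimized_attention so an override can take effect.
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention

class ToyBlock(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, transformer_options={}):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # transformer_options rides along so @wrap_attn can dispatch to an override
        out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
        return self.proj(out)
```

A caller would invoke `ToyBlock(dim, heads)(x, transformer_options=transformer_options)`, mirroring how the real blocks receive the dict from the model's forward pass.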
@@ -199,6 +200,7 @@ def forward( image_tokens_masks: torch.FloatTensor = None, norm_text_tokens: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.Tensor: return self.processor( self, @@ -206,6 +208,7 @@ def forward( image_tokens_masks = image_tokens_masks, text_tokens = norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) @@ -406,7 +409,7 @@ def forward( text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, - + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ @@ -419,6 +422,7 @@ def forward( norm_image_tokens, image_tokens_masks, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -483,6 +487,7 @@ def forward( text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ @@ -500,6 +505,7 @@ def forward( image_tokens_masks, norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -550,6 +556,7 @@ def forward( text_tokens: Optional[torch.FloatTensor] = None, adaln_input: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: return self.block( image_tokens, @@ -557,6 +564,7 @@ def forward( text_tokens, adaln_input, rope, + transformer_options=transformer_options, ) @@ -786,6 +794,7 @@ def _forward( text_tokens = cur_encoder_hidden_states, adaln_input = adaln_input, rope = rope, + transformer_options=transformer_options, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] block_id += 1 @@ -809,6 +818,7 @@ def _forward( text_tokens=None, adaln_input=adaln_input, rope=rope, + transformer_options=transformer_options, ) hidden_states = hidden_states[:, :hidden_states_seq_len] block_id += 1 From 1ddfb5bb14fa0a4d25296e14d39f5376eec1d843 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:13:51 -0700 Subject: [PATCH 12/36] Made wan patches_replace work with optimized_attention_override --- comfy/ldm/wan/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 7627da6430c2..e53a49769c14 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -561,9 +561,9 @@ def forward_orig( if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) 
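Taken together with the registry from patch 06 and the test node from patch 07, the override hook can be exercised from ordinary node code: clone the model, look up a registered kernel, and install it as `optimized_attention_override` in `transformer_options`. The following is a hedged usage sketch; `make_override` and `apply_attention_override` are illustrative helpers, and "sage" is only registered if `sageattention` imported successfully. Under the post-patch-06 convention the override receives the wrapped default function as its first argument, so it can fall back to it.

```python
# Usage sketch modeled on the AttentionOverrideTest node; helper names are assumptions.
import logging
import comfy.ldm.modules.attention as attention

def make_override(attention_func):
    def optimized_attention_override(func, *args, **kwargs):
        # `func` is the default @wrap_attn-decorated kernel; delegate, or call it to fall back
        return attention_func(*args, **kwargs)
    return optimized_attention_override

def apply_attention_override(model, name="sage"):
    attention_func = attention.get_attention_function(name, None)
    if attention_func is None:
        logging.info(f"Attention type '{name}' not registered, keeping the default kernel.")
        return model
    model = model.clone()
    model.model_options["transformer_options"]["optimized_attention_override"] = make_override(attention_func)
    return model
```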
From 0ac5c6344f0f07dcadfc962141b46d899268c726 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:21:14 -0700 Subject: [PATCH 13/36] Made SD3 work with optimized_attention_override --- comfy/ldm/modules/diffusionmodules/mmdit.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 4d6beba2d762..42f406f1aaa3 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs): return _block_mixing(*args, **kwargs) -def _block_mixing(context, x, context_block, x_block, c): +def _block_mixing(context, x, context_block, x_block, c, transformer_options={}): context_qkv, context_intermediates = context_block.pre_attention(context, c) if x_block.x_block_self_attn: @@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn = optimized_attention( qkv[0], qkv[1], qkv[2], heads=x_block.attn.num_heads, + transformer_options=transformer_options, ) context_attn, x_attn = ( attn[:, : context_qkv[0].shape[1]], @@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn2 = optimized_attention( x_qkv2[0], x_qkv2[1], x_qkv2[2], heads=x_block.attn2.num_heads, + transformer_options=transformer_options, ) x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) else: @@ -958,10 +960,10 @@ def forward_core_with_concat( if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) + out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap}) context = out["txt"] x = out["img"] else: @@ -970,6 +972,7 @@ def block_wrap(args): x, c=c_mod, use_checkpoint=self.use_checkpoint, + transformer_options=transformer_options, ) if control is not None: control_o = control.get("output") From ef894cdf08d1a1d555e8f2c451088d041575ae66 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:26:53 -0700 Subject: [PATCH 14/36] Made HunyuanVideo work with optimized_attention_override --- comfy/ldm/hunyuan_video/model.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index da1011596b5c..17089c701e7e 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -78,13 +78,13 @@ def __init__( operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn.qkv(norm_x) q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4) - attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True) + attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options) x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1) x = x + 
self.mlp(self.norm2(x)) * mod2.unsqueeze(1) @@ -115,14 +115,14 @@ def __init__( ] ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): m = None if mask is not None: m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1) m = m + m.transpose(2, 3) for block in self.blocks: - x = block(x, c, m) + x = block(x, c, m, transformer_options=transformer_options) return x @@ -150,6 +150,7 @@ def forward( x, timesteps, mask, + transformer_options={}, ): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) @@ -158,7 +159,7 @@ def forward( c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) + x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options) return x class HunyuanVideo(nn.Module): @@ -267,7 +268,7 @@ def forward_orig( if txt_mask is not None and not torch.is_floating_point(txt_mask): txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max - txt = self.txt_in(txt, timesteps, txt_mask) + txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -285,14 +286,14 @@ def forward_orig( if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -307,13 +308,13 @@ def block_wrap(args): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) + out = 
blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") From 61b5c5fc75eec4de5ec782b741880ef09410d2f9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:34:06 -0700 Subject: [PATCH 15/36] Made Mochi work with optimized_attention_override --- comfy/ldm/genmo/joint_model/asymm_models_joint.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index 366a8b7133cb..5c1bb4d42330 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -109,6 +109,7 @@ def forward( scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. crop_y, + transformer_options={}, **rope_rotation, ) -> Tuple[torch.Tensor, torch.Tensor]: rope_cos = rope_rotation.get("rope_cos") @@ -143,7 +144,7 @@ def forward( xy = optimized_attention(q, k, - v, self.num_heads, skip_reshape=True) + v, self.num_heads, skip_reshape=True, transformer_options=transformer_options) x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1) x = self.proj_x(x) @@ -224,6 +225,7 @@ def forward( x: torch.Tensor, c: torch.Tensor, y: torch.Tensor, + transformer_options={}, **attn_kwargs, ): """Forward pass of a block. @@ -256,6 +258,7 @@ def forward( y, scale_x=scale_msa_x, scale_y=scale_msa_y, + transformer_options=transformer_options, **attn_kwargs, ) @@ -524,10 +527,11 @@ def block_wrap(args): args["txt"], rope_cos=args["rope_cos"], rope_sin=args["rope_sin"], - crop_y=args["num_tokens"] + crop_y=args["num_tokens"], + transformer_options=args["transformer_options"] ) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap}) y_feat = out["txt"] x = out["img"] else: @@ -538,6 +542,7 @@ def block_wrap(args): rope_cos=rope_cos, rope_sin=rope_sin, crop_y=num_tokens, + transformer_options=transformer_options, ) # (B, M, D), (B, L, D) del y_feat # Final layers don't use dense text features. From 2cda45d1b4d4a3ca5f3e94a16ef98a64a7c5f0ff Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:42:22 -0700 Subject: [PATCH 16/36] Made LTX work with optimized_attention_override --- comfy/ldm/lightricks/model.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index aa2ea62b16a5..def365ba70b8 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -271,7 +271,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. 
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, mask=None, pe=None): + def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -285,9 +285,9 @@ def forward(self, x, context=None, mask=None, pe=None): k = apply_rotary_emb(k, pe) if mask is None: - out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -303,12 +303,12 @@ def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa - x += self.attn2(x, context=context, mask=attention_mask) + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp x += self.ff(y) * gate_mlp @@ -479,10 +479,10 @@ def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transfor if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -490,7 +490,8 @@ def block_wrap(args): context=context, attention_mask=attention_mask, timestep=timestep, - pe=pe + pe=pe, + transformer_options=transformer_options, ) # 3. 
Output From 9461f3038783dec19cf8c12bbdfffc171c32f792 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:56:56 -0700 Subject: [PATCH 17/36] Made StableAudio work with optimized_attention_override --- comfy/ldm/audio/dit.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 179c5b67eac4..5ca0e1e52610 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -298,7 +298,8 @@ def forward( mask = None, context_mask = None, rotary_pos_emb = None, - causal = None + causal = None, + transformer_options={}, ): h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None @@ -363,7 +364,7 @@ def forward( heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = self.to_out(out) if mask is not None: @@ -488,7 +489,8 @@ def forward( global_cond=None, mask = None, context_mask = None, - rotary_pos_emb = None + rotary_pos_emb = None, + transformer_options={} ): if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: @@ -498,12 +500,12 @@ def forward( residual = x x = self.pre_norm(x) x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) x = x * torch.sigmoid(1 - gate_self) x = x + residual if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -517,10 +519,10 @@ def forward( x = x + residual else: - x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -606,7 +608,8 @@ def forward( return_info = False, **kwargs ): - patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {}) + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) batch, seq, device = *x.shape[:2], x.device context = kwargs["context"] @@ -645,13 +648,13 @@ def forward( if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) + out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": 
rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: From 27ebd312aefa8dd0e7c00487d118cc5060ed835b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:03:28 -0700 Subject: [PATCH 18/36] Made optimized_attention_override work with ACE Step --- comfy/ldm/ace/attention.py | 9 ++++++++- comfy/ldm/ace/model.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py index f20a01669145..670eb9783385 100644 --- a/comfy/ldm/ace/attention.py +++ b/comfy/ldm/ace/attention.py @@ -133,6 +133,7 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + transformer_options={}, **cross_attention_kwargs, ) -> torch.Tensor: return self.processor( @@ -140,6 +141,7 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + transformer_options=transformer_options, **cross_attention_kwargs, ) @@ -366,6 +368,7 @@ def __call__( encoder_attention_mask: Optional[torch.FloatTensor] = None, rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, + transformer_options={}, *args, **kwargs, ) -> torch.Tensor: @@ -433,7 +436,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = optimized_attention( - query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, + query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options, ).to(query.dtype) # linear proj @@ -697,6 +700,7 @@ def forward( rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, temb: torch.FloatTensor = None, + transformer_options={}, ): N = hidden_states.shape[0] @@ -720,6 +724,7 @@ def forward( encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) else: attn_output, _ = self.attn( @@ -729,6 +734,7 @@ def forward( encoder_attention_mask=None, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=None, + transformer_options=transformer_options, ) if self.use_adaln_single: @@ -743,6 +749,7 @@ def forward( encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) hidden_states = attn_output + hidden_states diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py index 41d85eeb571b..3993298539be 100644 --- a/comfy/ldm/ace/model.py +++ b/comfy/ldm/ace/model.py @@ -314,6 +314,7 @@ def decode( output_length: int = 0, block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, + transformer_options={}, ): embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype)) temb = self.t_block(embedded_timestep) @@ 
-339,6 +340,7 @@ def decode( rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=encoder_rotary_freqs_cis, temb=temb, + transformer_options=transformer_options, ) output = self.final_layer(hidden_states, embedded_timestep, output_length) @@ -393,6 +395,7 @@ def _forward( output_length = hidden_states.shape[-1] + transformer_options = kwargs.get("transformer_options", {}) output = self.decode( hidden_states=hidden_states, attention_mask=attention_mask, @@ -402,6 +405,7 @@ def _forward( output_length=output_length, block_controlnet_hidden_states=block_controlnet_hidden_states, controlnet_scale=controlnet_scale, + transformer_options=transformer_options, ) return output From 8b9b4bbb62a5cbef5a71bb15e759f0ca83704d0f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:06:44 -0700 Subject: [PATCH 19/36] Made Hunyuan3D work with optimized_attention_override --- comfy/ldm/hunyuan3d/model.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 0fa5e78c1594..4991b164531e 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -99,14 +99,16 @@ def block_wrap(args): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -115,7 +117,8 @@ def block_wrap(args): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) img = torch.cat((txt, img), 1) @@ -126,17 +129,19 @@ def block_wrap(args): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = img[:, txt.shape[1]:, ...] img = self.final_layer(img, vec) From 4a44ed4a76c784c729edaf501e885f09f4012649 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:18:34 -0700 Subject: [PATCH 20/36] Make CosmosPredict2 work with optimized_attention_override --- comfy/ldm/cosmos/predict2.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index fcc83ba76d0d..07a4fc79ff96 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -44,7 +44,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: """Computes multi-head attention using PyTorch's native implementation. 
This function provides a PyTorch backend alternative to Transformer Engine's attention operation. @@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options) class Attention(nn.Module): @@ -180,8 +180,8 @@ def apply_norm_and_rotary_pos_emb( return q, k, v - def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - result = self.attn_op(q, k, v) # [B, S, H, D] + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D] return self.output_dropout(self.output_proj(result)) def forward( @@ -189,6 +189,7 @@ def forward( x: torch.Tensor, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Args: @@ -196,7 +197,7 @@ def forward( context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) - return self.compute_attention(q, k, v) + return self.compute_attention(q, k, v, transformer_options=transformer_options) class Timesteps(nn.Module): @@ -459,6 +460,7 @@ def forward( rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -512,6 +514,7 @@ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -525,6 +528,7 @@ def _x_fn( layer_norm_cross_attn: Callable, _scale_cross_attn_B_T_1_1_D: torch.Tensor, _shift_cross_attn_B_T_1_1_D: torch.Tensor, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: _normalized_x_B_T_H_W_D = _fn( _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D @@ -534,6 +538,7 @@ def _x_fn( rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -547,6 +552,7 @@ def _x_fn( self.layer_norm_cross_attn, scale_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D, + transformer_options=transformer_options, ) x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D @@ -865,6 +871,7 @@ def _forward( "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), "adaln_lora_B_T_3D": adaln_lora_B_T_3D, "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + "transformer_options": kwargs.get("transformer_options", {}), } for block in 
self.blocks: x_B_T_H_W_D = block( From 8fe2dea29729a52d0d7c86342a0b25e216efea5c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:23:03 -0700 Subject: [PATCH 21/36] Made CosmosVideo work with optimized_attention_override --- comfy/ldm/cosmos/blocks.py | 10 +++++++++- comfy/ldm/cosmos/model.py | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 5c4356a3ff9d..afb43d469dae 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -176,6 +176,7 @@ def forward( context=None, mask=None, rope_emb=None, + transformer_options={}, **kwargs, ): """ @@ -184,7 +185,7 @@ def forward( context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options) del q, k, v out = rearrange(out, " b n s c -> s b (n c)") return self.to_out(out) @@ -546,6 +547,7 @@ def forward( context: Optional[torch.Tensor] = None, crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for video attention. @@ -571,6 +573,7 @@ def forward( context_M_B_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) return x_T_H_W_B_D @@ -665,6 +668,7 @@ def forward( crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for dynamically configured blocks with adaptive normalization. 
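The model patches in this series all apply the same mechanical change, visible in the hunks above and below: each forward() grows a transformer_options={} keyword and threads it, unmodified, either into the next block down or into the optimized_attention call itself. A minimal sketch of that convention, assuming a toy module — ToyAttnBlock is illustrative only and not part of any of these commits:

import torch
import torch.nn as nn

from comfy.ldm.modules.attention import optimized_attention


class ToyAttnBlock(nn.Module):
    """Illustrative only: shows the transformer_options pass-through convention."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, transformer_options={}) -> torch.Tensor:
        # x: [B, N, dim]; q/k/v each come out as [B, N, dim]
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Forwarding transformer_options is what lets an
        # 'optimized_attention_override' (if one is present) intercept this call.
        out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
        return self.proj(out)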
@@ -702,6 +706,7 @@ def forward( adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), context=None, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) elif self.block_type in ["cross_attn", "ca"]: x = x + gate_1_1_1_B_D * self.block( @@ -709,6 +714,7 @@ def forward( context=crossattn_emb, crossattn_mask=crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) else: raise ValueError(f"Unknown block type: {self.block_type}") @@ -784,6 +790,7 @@ def forward( crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: for block in self.blocks: x = block( @@ -793,5 +800,6 @@ def forward( crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) return x diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 53698b758b53..52ef7ef4340a 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -520,6 +520,7 @@ def _forward( x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + transformer_options = kwargs.get("transformer_options", {}) for _, block in self.blocks.items(): assert ( self.blocks["block0"].x_format == block.x_format @@ -534,6 +535,7 @@ def _forward( crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") From 09c84b31a25e321749dca04f2c089e39cdf0a522 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:30:18 -0700 Subject: [PATCH 22/36] Made Omnigen 2 work with optimized_attention_override --- comfy/ldm/omnigen/omnigen2.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index 4884449f85ff..82edc92da71e 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -120,7 +120,7 @@ def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps nn.Dropout(0.0) ) - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape query = self.to_q(hidden_states) @@ -146,7 +146,7 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens key = key.repeat_interleave(self.heads // self.kv_heads, dim=1) value = value.repeat_interleave(self.heads // self.kv_heads, dim=1) - hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True) + hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options) hidden_states = self.to_out[0](hidden_states) return hidden_states @@ -182,16 +182,16 @@ def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multip self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, 
device=device) self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: if self.modulation: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) @@ -390,7 +390,7 @@ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): ref_img_sizes, img_sizes, ) - def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb): + def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}): batch_size = len(hidden_states) hidden_states = self.x_embedder(hidden_states) @@ -405,17 +405,17 @@ def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, pad shift += ref_img_len for layer in self.noise_refiner: - hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options) if ref_image_hidden_states is not None: for layer in self.ref_image_refiner: - ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb) + ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options) hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1) return hidden_states - def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs): + def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs): B, C, H, W = x.shape hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) _, _, H_padded, W_padded = hidden_states.shape @@ -444,7 +444,7 @@ def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention ) for layer in self.context_refiner: - text_hidden_states = layer(text_hidden_states, 
text_attention_mask, context_rotary_emb) + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options) img_len = hidden_states.shape[1] combined_img_hidden_states = self.img_patch_embed_and_refine( @@ -453,13 +453,14 @@ def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, + transformer_options=transformer_options, ) hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) attention_mask = None for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options) hidden_states = self.norm_out(hidden_states, temb) From 034d6c12e6b8c08d353908d6704edc2e4d12b4d9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:42:08 -0700 Subject: [PATCH 23/36] Made StableCascade work with optimized_attention_override --- comfy/ldm/cascade/common.py | 12 ++++++------ comfy/ldm/cascade/stage_b.py | 14 +++++++------- comfy/ldm/cascade/stage_c.py | 14 +++++++------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 3eaa0c821ccc..42ef98c7a78c 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -32,12 +32,12 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) - def forward(self, q, k, v): + def forward(self, q, k, v, transformer_options={}): q = self.to_q(q) k = self.to_k(k) v = self.to_v(v) - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options) return self.out_proj(out) @@ -47,13 +47,13 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) - def forward(self, x, kv, self_attn=False): + def forward(self, x, kv, self_attn=False, transformer_options={}): orig_shape = x.shape x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 if self_attn: kv = torch.cat([x, kv], dim=1) # x = self.attn(x, kv, kv, need_weights=False)[0] - x = self.attn(x, kv, kv) + x = self.attn(x, kv, kv, transformer_options=transformer_options) x = x.permute(0, 2, 1).view(*orig_shape) return x @@ -114,9 +114,9 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, de operations.Linear(c_cond, c, dtype=dtype, device=device) ) - def forward(self, x, kv): + def forward(self, x, kv, transformer_options={}): kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options) return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py index 77383095681d..428c67fdf621 100644 --- a/comfy/ldm/cascade/stage_b.py +++ b/comfy/ldm/cascade/stage_b.py @@ -173,7 +173,7 @@ def gen_c_embeddings(self, clip): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip): + def _down_encode(self, x, r_embed, clip, transformer_options={}): 
level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -187,7 +187,7 @@ def _down_encode(self, x, r_embed, clip): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -199,7 +199,7 @@ def _down_encode(self, x, r_embed, clip): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip): + def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -216,7 +216,7 @@ def _up_decode(self, level_outputs, r_embed, clip): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -228,7 +228,7 @@ def _up_decode(self, level_outputs, r_embed, clip): x = upscaler(x) return x - def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs): if pixels is None: pixels = x.new_zeros(x.size(0), 3, 8, 8) @@ -245,8 +245,8 @@ def forward(self, x, r, effnet, clip, pixels=None, **kwargs): nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', align_corners=True) - level_outputs = self._down_encode(x, r_embed, clip) - x = self._up_decode(level_outputs, r_embed, clip) + level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index b952d0349057..ebc4434e2bf1 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -182,7 +182,7 @@ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip, cnet=None): + def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -201,7 +201,7 @@ def _down_encode(self, x, r_embed, clip, cnet=None): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -213,7 +213,7 @@ def _down_encode(self, x, r_embed, clip, cnet=None): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, 
level_outputs, r_embed, clip, cnet=None): + def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -235,7 +235,7 @@ def _up_decode(self, level_outputs, r_embed, clip, cnet=None): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -247,7 +247,7 @@ def _up_decode(self, level_outputs, r_embed, clip, cnet=None): x = upscaler(x) return x - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs): # Process the conditioning embeddings r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) for c in self.t_conds: @@ -262,8 +262,8 @@ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **k # Model Blocks x = self.embedding(x) - level_outputs = self._down_encode(x, r_embed, clip, cnet) - x = self._up_decode(level_outputs, r_embed, clip, cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): From 17090c56be0f555c1aa80c7be27fbdca2f9fe969 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:46:56 -0700 Subject: [PATCH 24/36] Made AuraFlow work with optimized_attention_override --- comfy/ldm/aura/mmdit.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index d7f32b5e82f7..66d9613b61de 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -85,7 +85,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera ) #@torch.compile() - def forward(self, c): + def forward(self, c, transformer_options={}): bsz, seqlen1, _ = c.shape @@ -95,7 +95,7 @@ def forward(self, c): v = v.view(bsz, seqlen1, self.n_heads, self.head_dim) q, k = self.q_norm1(q), self.k_norm1(k) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c = self.w1o(output) return c @@ -144,7 +144,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera #@torch.compile() - def forward(self, c, x): + def forward(self, c, x, transformer_options={}): bsz, seqlen1, _ = c.shape bsz, seqlen2, _ = x.shape @@ -168,7 +168,7 @@ def forward(self, c, x): torch.cat([cv, xv], dim=1), ) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c, x = output.split([seqlen1, seqlen2], dim=1) c = self.w1o(c) 
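The other half of the mechanism is supplying the override itself. The sampler-side hunks later in the series (patches 32–33) show where it lives — model_options["transformer_options"]["optimized_attention_override"] — and the pass-through default defined there suggests the override receives the selected attention function followed by the original call arguments. A sketch of registering one; register_attention_override is an assumed helper name for illustration, not an existing API:

def register_attention_override(model_options: dict, override) -> dict:
    # Store the callable under the key the attention wrapper checks
    # (see the samplers.py hunks in patches 32-33).
    model_options.setdefault("transformer_options", {})["optimized_attention_override"] = override
    return model_options


def passthrough_override(func, *args, **kwargs):
    # Same shape as the temporary default in this series: take the attention
    # implementation that was about to run and simply delegate to it.
    return func(*args, **kwargs)


# Usage sketch on an empty options dict:
opts = register_attention_override({}, passthrough_override)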
@@ -207,7 +207,7 @@ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, self.is_last = is_last #@torch.compile() - def forward(self, c, x, global_cond, **kwargs): + def forward(self, c, x, global_cond, transformer_options={}, **kwargs): cres, xres = c, x @@ -225,7 +225,7 @@ def forward(self, c, x, global_cond, **kwargs): x = modulate(self.normX1(x), xshift_msa, xscale_msa) # attention - c, x = self.attn(c, x) + c, x = self.attn(c, x, transformer_options=transformer_options) c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) @@ -255,13 +255,13 @@ def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, o self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations) #@torch.compile() - def forward(self, cx, global_cond, **kwargs): + def forward(self, cx, global_cond, transformer_options={}, **kwargs): cxres = cx shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX( global_cond ).chunk(6, dim=1) cx = modulate(self.norm1(cx), shift_msa, scale_msa) - cx = self.attn(cx) + cx = self.attn(cx, transformer_options=transformer_options) cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx) mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) cx = gate_mlp.unsqueeze(1) * mlpout @@ -473,13 +473,14 @@ def block_wrap(args): out = {} out["txt"], out["img"] = layer(args["txt"], args["img"], - args["vec"]) + args["vec"], + transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) c = out["txt"] x = out["img"] else: - c, x = layer(c, x, global_cond, **kwargs) + c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs) if len(self.single_layers) > 0: c_len = c.size(1) @@ -488,13 +489,13 @@ def block_wrap(args): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], args["vec"]) + out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) cx = out["img"] else: - cx = layer(cx, global_cond, **kwargs) + cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs) x = cx[:, c_len:] From d644aba6bca0f46cb168ea38164ebaea860de399 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 22:00:44 -0700 Subject: [PATCH 25/36] Made Lumina work with optimized_attention_override --- comfy/ldm/lumina/model.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e08ed817de39..f87d98ac0e35 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -104,6 +104,7 @@ def forward( x: torch.Tensor, x_mask: torch.Tensor, freqs_cis: torch.Tensor, + transformer_options={}, ) -> torch.Tensor: """ @@ -140,7 +141,7 @@ def forward( if n_rep >= 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), 
xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True) + output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options) return self.out(output) @@ -268,6 +269,7 @@ def forward( x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + transformer_options={}, ): """ Perform a forward pass through the TransformerBlock. @@ -290,6 +292,7 @@ def forward( modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( @@ -304,6 +307,7 @@ def forward( self.attention_norm1(x), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + self.ffn_norm2( @@ -494,7 +498,7 @@ def unpatchify( return imgs def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens + self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: bsz = len(x) pH = pW = self.patch_size @@ -554,7 +558,7 @@ def patchify_and_embed( # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) # refine image flat_x = [] @@ -573,7 +577,7 @@ def patchify_and_embed( padded_img_embed = self.x_embedder(padded_img_embed) padded_img_mask = padded_img_mask.unsqueeze(1) for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) if cap_mask is not None: mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) @@ -616,12 +620,13 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwa cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. 
redundant compute + transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(x.device) for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) + x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) x = self.final_layer(x, adaln_input) x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] From 8be3edb606143158ee9706919b221c2ea66e0776 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 22:45:31 -0700 Subject: [PATCH 26/36] Made Chroma work with optimized_attention_override --- comfy/ldm/chroma/layers.py | 8 ++++---- comfy/ldm/chroma/model.py | 17 +++++++++++------ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2a0dec606239..fc7110cce50e 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -76,7 +76,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None): + def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}): (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention @@ -95,7 +95,7 @@ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=N attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] @@ -148,7 +148,7 @@ def __init__( self.mlp_act = nn.GELU(approximate="tanh") - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor: mod = vec x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -157,7 +157,7 @@ def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) x.addcmul_(mod.gate, output) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 5cff44dc8a6f..4f709f87d7c9 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -193,14 +193,16 @@ def block_wrap(args): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": double_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = 
out["img"] @@ -209,7 +211,8 @@ def block_wrap(args): txt=txt, vec=double_mod, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -229,17 +232,19 @@ def block_wrap(args): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": single_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") From 2d13bf1c7a7da9e60613949cf503c057af85cb19 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 22:45:45 -0700 Subject: [PATCH 27/36] Made SVD work with optimized_attention_override --- comfy/ldm/modules/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 080139b77cba..361ad10fe55a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1181,7 +1181,7 @@ def forward( B, S, C = x_mix.shape x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) - x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options) x_mix = rearrange( x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps ) From 1ae6fe14a75dfd95aad25cb3feaa02e7d5a76463 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 29 Aug 2025 02:31:16 -0700 Subject: [PATCH 28/36] Fix WanI2VCrossAttention so that it expects to receive transformer_options --- comfy/ldm/wan/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index e53a49769c14..0759027526a8 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -117,7 +117,7 @@ def __init__(self, # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context, context_img_len): + def forward(self, x, context, context_img_len, transformer_options={}): r""" Args: x(Tensor): Shape [B, L1, C] @@ -132,9 +132,9 @@ def forward(self, x, context, context_img_len): v = self.v(context) k_img = self.norm_k_img(self.k_img(context_img)) v_img = self.v_img(context_img) - img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads) + img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) # output x = x + img_x From af288b9946ea29d76ad2a442467d0754a93c3d97 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 29 Aug 2025 13:06:37 -0700 Subject: [PATCH 29/36] Fixed Wan2.1 Fun Camera transformer_options passthrough --- comfy/ldm/wan/model.py | 6 
+++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 0759027526a8..5a9c0b1f1856 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -843,12 +843,12 @@ def forward_orig( if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e) From d553073a1ee17d10de0cbf1b06689c4f2ddc2e4f Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 29 Aug 2025 13:20:43 -0700 Subject: [PATCH 30/36] Fixed WAN 2.1 VACE transformer_options passthrough --- comfy/ldm/wan/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 5a9c0b1f1856..970ae4e5c018 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -744,17 +744,17 @@ def forward_orig( if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) ii = self.vace_layers_mapping.get(i, None) if ii is not None: for iii in range(len(c)): - c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) x += c_skip * vace_strength[iii] del c_skip # head From cb959f9669c9a91e4122f3f3692547d33c25d7ad Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 29 Aug 2025 21:48:36 -0700 Subject: [PATCH 31/36] Add optimized to get_attention_function --- comfy/ldm/modules/attention.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 361ad10fe55a..3c9df9a00312 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -51,7 +51,9 @@ def register_attention_function(name: 
str, func: Callable): logging.warning(f"Attention function {name} already registered, skipping registration.") def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]: - if name not in REGISTERED_ATTENTION_FUNCTIONS: + if name == "optimized": + return optimized_attention + elif name not in REGISTERED_ATTENTION_FUNCTIONS: if default is ...: raise KeyError(f"Attention function {name} not found.") else: @@ -62,7 +64,7 @@ def _register_core_attention_functions(): """ Register attention functions exposed by core ComfyUI. """ - # NOTE: attention_basic is purposely not registered, as it is not used in code + # NOTE: attention_basic is purposely not registered, as it should not be used if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) if FLASH_ATTENTION_IS_AVAILABLE: From 720d0a88e689b82b4d1bb6711cb8e280a52e0c71 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 30 Aug 2025 01:11:34 -0700 Subject: [PATCH 32/36] Disable attention logs for now --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 79d26db35da8..7775bcada116 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1034,7 +1034,7 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ + comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$ comfy.ldm.modules.attention.LOG_CONTENTS = {} if "optimized_attention_override" not in self.model_options["transformer_options"]: def optimized_attention_override(func, *args, **kwargs): From eaa9433ff8f5b932e0e018fbff791517311e6b0b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 30 Aug 2025 14:45:12 -0700 Subject: [PATCH 33/36] Remove attention logging code --- comfy/ldm/modules/attention.py | 102 --------------------------------- comfy/samplers.py | 12 ---- 2 files changed, 114 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 3c9df9a00312..804fd0df9b55 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -132,114 +132,12 @@ def forward(self, x): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -import inspect -LOG_ATTN_CALLS = False -LOG_CONTENTS = {} - -def save_log_contents(): - import folder_paths - output_dir = folder_paths.get_output_directory() - - # Create attn_logs directory if it doesn't exist - attn_logs_dir = os.path.join(output_dir, "attn_logs") - os.makedirs(attn_logs_dir, exist_ok=True) - - # Generate timestamp filename (down to second) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"{timestamp}.json" - filepath = os.path.join(attn_logs_dir, filename) - - # Save LOG_CONTENTS as JSON file - try: - with open(filepath, 'w', encoding='utf-8') as f: - json.dump(list(LOG_CONTENTS.values()), f, indent=2, ensure_ascii=False) - logging.info(f"Saved attention log contents to {filepath}") - except Exception as e: - logging.error(f"Failed to save attention log contents: {e}") - -def get_class_from_frame(frame): - # Check for 'self' (instance method) or 'cls' (classmethod) - if 'self' in frame.f_locals: - return frame.f_locals['self'].__class__.__name__ - elif 
'cls' in frame.f_locals: - return frame.f_locals['cls'].__name__ - return None - -def has_transformer_options_passed(frame): - if 'transformer_options' in frame.f_locals.keys(): - if frame.f_locals['transformer_options']: - return True - return False def wrap_attn(func): @functools.wraps(func) def wrapper(*args, **kwargs): remove_attn_wrapper_key = False try: - if LOG_ATTN_CALLS: - continue_to_add = True - to_add = 1000 - logged_stack = [] - logged_stack_to_index = -1 - - frame = inspect.currentframe() - try: - # skip wrapper, start at actual wrapped function - frame = frame.f_back - - while frame and continue_to_add and to_add > 0: - code = frame.f_code - filename = code.co_filename - function = code.co_name - lineno = frame.f_lineno - - if function == "_calc_cond_batch_outer": - break - if 'venv' in filename: - frame = frame.f_back - continue - elif 'ComfyUI' not in filename: - frame = frame.f_back - continue - elif 'execution.py' in filename: - frame = frame.f_back - continue - elif 'patcher_extension.py' in filename: - frame = frame.f_back - continue - - to_add -= 1 - cls_name = get_class_from_frame(frame) - log_string = f"{filename}:{lineno}" - if cls_name: - log_string += f":{cls_name}.{function}" - else: - log_string += f":{function}" - - if has_transformer_options_passed(frame): - log_string += ":✅" - if logged_stack_to_index == -1: - logged_stack_to_index = len(logged_stack) - else: - log_string += ":❌" - - logged_stack.append(log_string) - - # move up the stack - frame = frame.f_back - - # check if we get what we want from transformer_options - t_check = "❌❌❌" - transformer_options = kwargs.get("transformer_options", None) - if transformer_options is not None: - if "optimized_attention_override" in transformer_options: - t_check = "✅✅✅" - - LOG_CONTENTS["|".join(logged_stack)] = (t_check, logged_stack_to_index, logged_stack) - - finally: - # Important: break ref cycles so tensors aren't pinned - del frame if "_inside_attn_wrapper" not in kwargs: transformer_options = kwargs.get("transformer_options", None) remove_attn_wrapper_key = True diff --git a/comfy/samplers.py b/comfy/samplers.py index 7775bcada116..b3202cec6f2c 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1019,7 +1019,6 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) preprocess_conds_hooks(self.conds) - import comfy.ldm.modules.attention #TODO: Remove this $$$$$ try: orig_model_options = self.model_options self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options) @@ -1034,23 +1033,12 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$ - comfy.ldm.modules.attention.LOG_CONTENTS = {} - if "optimized_attention_override" not in self.model_options["transformer_options"]: - def optimized_attention_override(func, *args, **kwargs): - return func(*args, **kwargs) - self.model_options["transformer_options"]["optimized_attention_override"] = optimized_attention_override - output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = 
orig_hook_mode self.model_patcher.restore_hook_patches() - if comfy.ldm.modules.attention.LOG_ATTN_CALLS: - comfy.ldm.modules.attention.save_log_contents() - comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$ - comfy.ldm.modules.attention.LOG_CONTENTS = {} del self.conds return output From c092b8a4acde135f6e312a81bdfbf31ea8f8e14b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 30 Aug 2025 14:49:04 -0700 Subject: [PATCH 34/36] Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case --- comfy/ldm/modules/attention.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 804fd0df9b55..e26f66bb3b8f 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -60,21 +60,6 @@ def get_attention_function(name: str, default: Any=...) -> Union[Callable, None] return default return REGISTERED_ATTENTION_FUNCTIONS[name] -def _register_core_attention_functions(): - """ - Register attention functions exposed by core ComfyUI. - """ - # NOTE: attention_basic is purposely not registered, as it should not be used - if SAGE_ATTENTION_IS_AVAILABLE: - register_attention_function("sage", attention_sage) - if FLASH_ATTENTION_IS_AVAILABLE: - register_attention_function("flash", attention_flash) - if model_management.xformers_enabled(): - register_attention_function("xformers", attention_xformers) - register_attention_function("pytorch", attention_pytorch) - register_attention_function("sub_quad", attention_sub_quad) - register_attention_function("split", attention_split) - from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -657,10 +642,22 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape else: logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -_register_core_attention_functions() optimized_attention_masked = optimized_attention + +# register core-supported attention functions +if SAGE_ATTENTION_IS_AVAILABLE: + register_attention_function("sage", attention_sage) +if FLASH_ATTENTION_IS_AVAILABLE: + register_attention_function("flash", attention_flash) +if model_management.xformers_enabled(): + register_attention_function("xformers", attention_xformers) +register_attention_function("pytorch", attention_pytorch) +register_attention_function("sub_quad", attention_sub_quad) +register_attention_function("split", attention_split) + + def optimized_attention_for_device(device, mask=False, small_input=False): if small_input: if model_management.pytorch_attention_enabled(): From dd0a5093f6ac499ae4ae475e9e2e74785ca654eb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 30 Aug 2025 14:58:30 -0700 Subject: [PATCH 35/36] Satisfy ruff --- comfy/ldm/modules/attention.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index e26f66bb3b8f..bf2553c37f54 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,8 +1,5 @@ import math import sys -import json -import os -from datetime import datetime import torch import torch.nn.functional as F From 66c4eb006bcc068b202891151174ecb8d6d0bf57 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 30 Aug 2025 15:19:36 -0700 Subject: [PATCH 36/36] Remove AttentionOverrideTest 
node, that's something to cook up for later --- comfy_extras/nodes_wan.py | 46 --------------------------------------- 1 file changed, 46 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 66806e34d0a1..4f73369f5b03 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1058,51 +1058,6 @@ def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return io.NodeOutput(out_latent) -import comfy.patcher_extension -import comfy.ldm.modules.attention -import logging - -class AttentionOverrideTest(io.ComfyNode): - @classmethod - def define_schema(cls): - attention_function_names = list(comfy.ldm.modules.attention.REGISTERED_ATTENTION_FUNCTIONS.keys()) - return io.Schema( - node_id="AttentionOverrideTest", - category="devtools", - inputs=[ - io.Model.Input("model"), - io.Combo.Input("attention", options=attention_function_names), - ], - outputs=[ - io.Model.Output(), - ], - ) - - @staticmethod - def attention_override_factory(attention_func): - def attention_override(func, *args, **kwargs): - return attention_func(*args, **kwargs) - return attention_override - - @staticmethod - def sampler_sampler_wrapper(executor, *args, **kwargs): - try: - # extra_args = args[2] - return executor(*args, **kwargs) - finally: - pass - - @classmethod - def execute(cls, model: io.Model.Type, attention: str) -> io.NodeOutput: - attention_func = comfy.ldm.modules.attention.get_attention_function(attention, None) - if attention_func is None: - logging.info(f"Attention type '{attention}' not found, using default optimized attention for your hardware.") - return model - - model = model.clone() - model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override_factory(attention_func) - model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper) - return io.NodeOutput(model) class WanExtension(ComfyExtension): @override @@ -1121,7 +1076,6 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: WanSoundImageToVideo, WanSoundImageToVideoExtend, Wan22ImageToVideoLatent, - AttentionOverrideTest, ] async def comfy_entrypoint() -> WanExtension:
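
Note on usage: the override hook that these patches thread through transformer_options can still be exercised without the removed AttentionOverrideTest node. The sketch below is not part of the patch series; it is a minimal example that follows the same pattern the removed node used (get_attention_function, model.clone(), and the "optimized_attention_override" entry). The helper names make_attention_override and apply_attention_override are illustrative only, and any attention name you pass (e.g. "sage") may or may not be registered on a given install.

import logging
import comfy.ldm.modules.attention as attention

def make_attention_override(attention_func):
    # The wrapped attention functions invoke the override as
    # override(func, *args, **kwargs), where func is the attention
    # implementation that would otherwise run (see the default override
    # removed from samplers.py and the removed node above); here we ignore
    # func and dispatch to the chosen registered function instead.
    def attention_override(func, *args, **kwargs):
        return attention_func(*args, **kwargs)
    return attention_override

def apply_attention_override(model, name: str):
    # Look up a registered attention function by name ("pytorch", "split",
    # "sub_quad", or "optimized" for the current default); if the name is
    # unknown, leave the model untouched so the default attention is used.
    attention_func = attention.get_attention_function(name, None)
    if attention_func is None:
        logging.info(f"Attention type '{name}' not found, using default optimized attention.")
        return model
    model = model.clone()
    model.model_options["transformer_options"]["optimized_attention_override"] = make_attention_override(attention_func)
    return model

As in the removed node, the override is installed on a cloned model's model_options so it only affects sampling runs that use that model instance.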