Merged
39 commits
b58db69
Looking into a @wrap_attn decorator to look for 'optimized_attention_…
Kosinkadink Aug 27, 2025
68b00e9
Created logging code for this branch so that it can be used to track …
Kosinkadink Aug 28, 2025
29b7990
Fix memory usage issue with inspect
Kosinkadink Aug 28, 2025
dd21b4a
Made WAN attention receive transformer_options, test node added to wa…
Kosinkadink Aug 28, 2025
669b9ef
Added **kwargs to all attention functions so transformer_options coul…
Kosinkadink Aug 28, 2025
51a30c2
Make sure wrap_attn doesn't make itself recurse infinitely, attempt t…
Kosinkadink Aug 29, 2025
1f499f0
Turn off attention logging for now, make AttentionOverrideTestNode ha…
Kosinkadink Aug 29, 2025
a7d70e4
Make flux work with optimized_attention_override
Kosinkadink Aug 29, 2025
48ed71c
Add logs to verify optimized_attention_override is passed all the way…
Kosinkadink Aug 29, 2025
f752715
Make Qwen work with optimized_attention_override
Kosinkadink Aug 29, 2025
4cafd58
Made hidream work with optimized_attention_override
Kosinkadink Aug 29, 2025
1ddfb5b
Made wan patches_replace work with optimized_attention_override
Kosinkadink Aug 29, 2025
0ac5c63
Made SD3 work with optimized_attention_override
Kosinkadink Aug 29, 2025
ef894cd
Made HunyuanVideo work with optimized_attention_override
Kosinkadink Aug 29, 2025
61b5c5f
Made Mochi work with optimized_attention_override
Kosinkadink Aug 29, 2025
2cda45d
Made LTX work with optimized_attention_override
Kosinkadink Aug 29, 2025
9461f30
Made StableAudio work with optimized_attention_override
Kosinkadink Aug 29, 2025
27ebd31
Made optimized_attention_override work with ACE Step
Kosinkadink Aug 29, 2025
8b9b4bb
Made Hunyuan3D work with optimized_attention_override
Kosinkadink Aug 29, 2025
4a44ed4
Make CosmosPredict2 work with optimized_attention_override
Kosinkadink Aug 29, 2025
8fe2dea
Made CosmosVideo work with optimized_attention_override
Kosinkadink Aug 29, 2025
09c84b3
Made Omnigen 2 work with optimized_attention_override
Kosinkadink Aug 29, 2025
034d6c1
Made StableCascade work with optimized_attention_override
Kosinkadink Aug 29, 2025
17090c5
Made AuraFlow work with optimized_attention_override
Kosinkadink Aug 29, 2025
d644aba
Made Lumina work with optimized_attention_override
Kosinkadink Aug 29, 2025
8be3edb
Made Chroma work with optimized_attention_override
Kosinkadink Aug 29, 2025
2d13bf1
Made SVD work with optimized_attention_override
Kosinkadink Aug 29, 2025
1ae6fe1
Fix WanI2VCrossAttention so that it expects to receive transformer_op…
Kosinkadink Aug 29, 2025
af288b9
Fixed Wan2.1 Fun Camera transformer_options passthrough
Kosinkadink Aug 29, 2025
d553073
Fixed WAN 2.1 VACE transformer_options passthrough
Kosinkadink Aug 29, 2025
cb959f9
Add optimized to get_attention_function
Kosinkadink Aug 30, 2025
d9bb453
Merge branch 'master' into attention-select
Kosinkadink Aug 30, 2025
720d0a8
Disable attention logs for now
Kosinkadink Aug 30, 2025
eaa9433
Remove attention logging code
Kosinkadink Aug 30, 2025
c092b8a
Remove _register_core_attention_functions, as we wouldn't want someon…
Kosinkadink Aug 30, 2025
dd0a509
Satisfy ruff
Kosinkadink Aug 30, 2025
66c4eb0
Remove AttentionOverrideTest node, that's something to cook up for later
Kosinkadink Aug 30, 2025
f71feac
Merge branch 'master' into attention-select
Kosinkadink Sep 12, 2025
0a86b5b
Merge branch 'master' into attention-select
Kosinkadink Sep 12, 2025
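
Taken together, the commits describe one pattern: every attention call site now forwards transformer_options, and the shared optimized_attention entry point consults it for an "optimized_attention_override" callable. The sketch below illustrates that pattern in isolation. It is a rough approximation written for this summary, not the PR's actual wrap_attn code; basic_attention and logging_override are hypothetical names, and the override signature is an assumption.

import functools
from typing import Callable, Optional

import torch


def wrap_attn(attn_fn: Callable) -> Callable:
    """Wrap an attention backend so a per-call override can replace it.

    If transformer_options carries an "optimized_attention_override" callable,
    it is invoked instead of the wrapped backend and receives that backend as
    its first argument, so it can delegate back without re-entering the wrapper.
    """
    @functools.wraps(attn_fn)
    def wrapped(q, k, v, heads, *, transformer_options: Optional[dict] = None, **kwargs):
        transformer_options = transformer_options or {}
        override = transformer_options.get("optimized_attention_override")
        if override is not None:
            return override(attn_fn, q, k, v, heads,
                            transformer_options=transformer_options, **kwargs)
        return attn_fn(q, k, v, heads, **kwargs)
    return wrapped


@wrap_attn
def basic_attention(q, k, v, heads, mask=None, skip_reshape=True, **kwargs):
    # Minimal stand-in for an optimized_attention backend; assumes q/k/v are
    # already shaped (batch, heads, seq_len, head_dim), i.e. the skip_reshape path.
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    b, h, s, d = out.shape
    return out.transpose(1, 2).reshape(b, s, h * d)


def logging_override(original_attn, q, k, v, heads, transformer_options=None, **kwargs):
    # Example override: inspect the call, then fall back to the original backend.
    print(f"attention call: q={tuple(q.shape)}, heads={heads}")
    return original_attn(q, k, v, heads, **kwargs)


q = k = v = torch.randn(1, 8, 16, 64)
opts = {"optimized_attention_override": logging_override}
out = basic_attention(q, k, v, 8, transformer_options=opts)   # routed through the override
out = basic_attention(q, k, v, 8, transformer_options={})     # plain path, no override set

The per-file diffs that follow are the plumbing that makes this possible: each model's forward passes transformer_options down to its attention blocks, and each optimized_attention call gains a transformer_options= argument.
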
9 changes: 8 additions & 1 deletion comfy/ldm/ace/attention.py
@@ -133,13 +133,15 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
transformer_options=transformer_options,
**cross_attention_kwargs,
)

@@ -366,6 +368,7 @@ def __call__(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@@ -433,7 +436,7 @@ def __call__(

# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)

# linear proj
@@ -697,6 +700,7 @@ def forward(
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
transformer_options={},
):

N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ def forward(
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
transformer_options=transformer_options,
)

if self.use_adaln_single:
@@ -743,6 +749,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states

4 changes: 4 additions & 0 deletions comfy/ldm/ace/model.py
@@ -314,6 +314,7 @@ def decode(
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@@ -339,6 +340,7 @@ def decode(
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
transformer_options=transformer_options,
)

output = self.final_layer(hidden_states, embedded_timestep, output_length)
@@ -393,6 +395,7 @@ def _forward(

output_length = hidden_states.shape[-1]

transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
Expand All @@ -402,6 +405,7 @@ def _forward(
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
transformer_options=transformer_options,
)

return output
25 changes: 14 additions & 11 deletions comfy/ldm/audio/dit.py
@@ -298,7 +298,8 @@ def forward(
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None
causal = None,
transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

@@ -363,7 +364,7 @@ def forward(
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

out = optimized_attention(q, k, v, h, skip_reshape=True)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)

if mask is not None:
@@ -488,7 +489,8 @@ def forward(
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None
rotary_pos_emb = None,
transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

@@ -498,12 +500,12 @@ def forward(
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual

if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

if self.conformer is not None:
x = x + self.conformer(x)
@@ -517,10 +519,10 @@ def forward(
x = x + residual

else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)

if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

if self.conformer is not None:
x = x + self.conformer(x)
@@ -606,7 +608,8 @@ def forward(
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
transformer_options = kwargs.get("transformer_options", {})
patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]

@@ -645,13 +648,13 @@ def forward(
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out

out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

if return_info:
29 changes: 15 additions & 14 deletions comfy/ldm/aura/mmdit.py
@@ -85,7 +85,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera
)

#@torch.compile()
def forward(self, c):
def forward(self, c, transformer_options={}):

bsz, seqlen1, _ = c.shape

@@ -95,7 +95,7 @@ def forward(self, c):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)

output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c

@@ -144,7 +144,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera


#@torch.compile()
def forward(self, c, x):
def forward(self, c, x, transformer_options={}):

bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ def forward(self, c, x):
torch.cat([cv, xv], dim=1),
)

output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)

c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@@ -207,7 +207,7 @@ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None,
self.is_last = is_last

#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):

cres, xres = c, x

@@ -225,7 +225,7 @@ def forward(self, c, x, global_cond, **kwargs):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)

# attention
c, x = self.attn(c, x)
c, x = self.attn(c, x, transformer_options=transformer_options)


c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, o
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@@ -473,13 +473,14 @@ def block_wrap(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
args["vec"],
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)

if len(self.single_layers) > 0:
c_len = c.size(1)
@@ -488,13 +489,13 @@ def block_wrap(args):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out

out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)

x = cx[:, c_len:]

12 changes: 6 additions & 6 deletions comfy/ldm/cascade/common.py
@@ -32,12 +32,12 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No

self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

def forward(self, q, k, v):
def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)

out = optimized_attention(q, k, v, self.heads)
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)

return self.out_proj(out)

@@ -47,13 +47,13 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

def forward(self, x, kv, self_attn=False):
def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x

@@ -114,9 +114,9 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, de
operations.Linear(c_cond, c, dtype=dtype, device=device)
)

def forward(self, x, kv):
def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x


14 changes: 7 additions & 7 deletions comfy/ldm/cascade/stage_b.py
@@ -173,7 +173,7 @@ def gen_c_embeddings(self, clip):
clip = self.clip_norm(clip)
return clip

def _down_encode(self, x, r_embed, clip):
def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ def _down_encode(self, x, r_embed, clip):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -199,7 +199,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
level_outputs.insert(0, x)
return level_outputs

def _up_decode(self, level_outputs, r_embed, clip):
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -228,7 +228,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
x = upscaler(x)
return x

def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)

@@ -245,8 +245,8 @@ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)

def update_weights_ema(self, src_model, beta=0.999):
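
On the caller side, all of this per-file plumbing amounts to a single contract: whatever dict reaches a model's transformer_options also reaches every optimized_attention call, so an override only needs to be registered once. A hypothetical convenience helper is sketched below; the patcher object and its model_options["transformer_options"] dict are assumptions about the surrounding API, not something shown in this diff.

def set_attention_override(patcher, override_fn):
    # Hypothetical helper: place an override where the attention wrapper will find it.
    # `patcher.model_options` is assumed here and may differ from the real object this PR targets.
    transformer_options = patcher.model_options.setdefault("transformer_options", {})
    transformer_options["optimized_attention_override"] = override_fn
    return patcher
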