Skip to content

Commit 2aac0a8

Browse files
committed
Move Chroma Radiance to its own directory in ldm
Minor code cleanups and tooltip improvements
1 parent a87506a commit 2aac0a8

File tree

5 files changed: 76 additions (+76) and 35 deletions (-35).

comfy/ldm/chroma/layers_dct.py renamed to comfy/ldm/chroma_radiance/layers.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,15 @@ class NerfEmbedder(nn.Module):
1616
patch size, and enriches it with positional information before projecting
1717
it to a new hidden size.
1818
"""
19-
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
19+
def __init__(
20+
self,
21+
in_channels: int,
22+
hidden_size_input: int,
23+
max_freqs: int,
24+
dtype=None,
25+
device=None,
26+
operations=None,
27+
):
2028
"""
2129
Initializes the NerfEmbedder.
2230
@@ -38,7 +46,7 @@ def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device
3846
)
3947

4048
@lru_cache(maxsize=4)
41-
def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor:
49+
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
4250
"""
4351
Generates and caches 2D DCT-like positional embeddings for a given patch size.
4452
@@ -179,14 +187,14 @@ def __init__(self, hidden_size, out_channels, dtype=None, device=None, operation
179187
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
180188
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
181189

182-
def forward(self, x):
190+
def forward(self, x: torch.Tensor) -> torch.Tensor:
183191
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
184192
# So we temporarily move the channel dimension to the end for the norm operation.
185193
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
186194

187195

188196
class NerfFinalLayerConv(nn.Module):
189-
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
197+
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
190198
super().__init__()
191199
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
192200
self.conv = operations.Conv2d(
@@ -198,7 +206,7 @@ def __init__(self, hidden_size, out_channels, dtype=None, device=None, operation
198206
device=device,
199207
)
200208

201-
def forward(self, x):
209+
def forward(self, x: torch.Tensor) -> torch.Tensor:
202210
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
203211
# So we temporarily move the channel dimension to the end for the norm operation.
204212
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))

comfy/ldm/chroma/model_dct.py renamed to comfy/ldm/chroma_radiance/model.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,43 @@
1212

1313
from comfy.ldm.flux.layers import EmbedND
1414

15-
from .layers import (
15+
from comfy.ldm.chroma.model import Chroma, ChromaParams
16+
from comfy.ldm.chroma.layers import (
1617
DoubleStreamBlock,
1718
SingleStreamBlock,
1819
Approximator,
1920
)
20-
from .layers_dct import (
21+
from .layers import (
2122
NerfEmbedder,
2223
NerfGLUBlock,
2324
NerfFinalLayer,
2425
NerfFinalLayerConv,
2526
)
2627

27-
from . import model as chroma_model
28-
2928

3029
@dataclass
31-
class ChromaRadianceParams(chroma_model.ChromaParams):
30+
class ChromaRadianceParams(ChromaParams):
3231
patch_size: int
3332
nerf_hidden_size: int
3433
nerf_mlp_ratio: int
3534
nerf_depth: int
3635
nerf_max_freqs: int
37-
# nerf_tile_size of 0 means unlimited.
36+
# Setting nerf_tile_size to 0 disables tiling.
3837
nerf_tile_size: int
3938
# Currently one of linear (legacy) or conv.
4039
nerf_final_head_type: str
4140
# None means use the same dtype as the model.
4241
nerf_embedder_dtype: Optional[torch.dtype]
4342

4443

45-
class ChromaRadiance(chroma_model.Chroma):
44+
class ChromaRadiance(Chroma):
4645
"""
4746
Transformer model for flow matching on sequences.
4847
"""
4948

5049
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
50+
if operations is None:
51+
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
5152
nn.Module.__init__(self)
5253
self.dtype = dtype
5354
params = ChromaRadianceParams(**kwargs)
@@ -188,7 +189,9 @@ def forward_nerf(
188189
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
189190
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
190191

191-
if params.nerf_tile_size > 0:
192+
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
193+
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
194+
# the tile size.
192195
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
193196
else:
194197
# Reshape for per-patch processing
@@ -219,8 +222,8 @@ def forward_tiled_nerf(
219222
self,
220223
nerf_hidden: Tensor,
221224
nerf_pixels: Tensor,
222-
B: int,
223-
C: int,
225+
batch: int,
226+
channels: int,
224227
num_patches: int,
225228
patch_size: int,
226229
params: ChromaRadianceParams,
@@ -246,9 +249,9 @@ def forward_tiled_nerf(
246249

247250
# Reshape the tile for per-patch processing
248251
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
249-
nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, params.hidden_size)
252+
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
250253
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
251-
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
254+
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
252255

253256
# get DCT-encoded pixel embeddings [pixel-dct]
254257
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
@@ -284,7 +287,16 @@ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
284287
params_dict |= overrides
285288
return params.__class__(**params_dict)
286289

287-
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
290+
def _forward(
291+
self,
292+
x: Tensor,
293+
timestep: Tensor,
294+
context: Tensor,
295+
guidance: Optional[Tensor],
296+
control: Optional[dict]=None,
297+
transformer_options: dict={},
298+
**kwargs: dict,
299+
) -> Tensor:
288300
bs, c, h, w = x.shape
289301
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
290302

@@ -303,5 +315,15 @@ def _forward(self, x, timestep, context, guidance, control=None, transformer_opt
303315
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
304316
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
305317

306-
img_out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
318+
img_out = self.forward_orig(
319+
img,
320+
img_ids,
321+
context,
322+
txt_ids,
323+
timestep,
324+
guidance,
325+
control,
326+
transformer_options,
327+
attn_mask=kwargs.get("attention_mask", None),
328+
)
307329
return self.forward_nerf(img, img_out, params)

comfy/model_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import comfy.ldm.hunyuan3d.model
4343
import comfy.ldm.hidream.model
4444
import comfy.ldm.chroma.model
45-
import comfy.ldm.chroma.model_dct
45+
import comfy.ldm.chroma_radiance.model
4646
import comfy.ldm.ace.model
4747
import comfy.ldm.omnigen.omnigen2
4848
import comfy.ldm.qwen_image.model
@@ -1334,7 +1334,7 @@ def extra_conds(self, **kwargs):
13341334

13351335
class ChromaRadiance(Chroma):
13361336
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
1337-
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model_dct.ChromaRadiance)
1337+
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
13381338

13391339
class ACEStep(BaseModel):
13401340
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

comfy/model_detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
166166
dit_config["guidance_embed"] = len(guidance_keys) > 0
167167
return dit_config
168168

169-
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
169+
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
170170
dit_config = {}
171171
dit_config["image_model"] = "flux"
172172
dit_config["in_channels"] = 16
@@ -196,7 +196,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
196196
dit_config["out_dim"] = 3072
197197
dit_config["hidden_dim"] = 5120
198198
dit_config["n_layers"] = 5
199-
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance
199+
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
200200
dit_config["image_model"] = "chroma_radiance"
201201
dit_config["in_channels"] = 3
202202
dit_config["out_channels"] = 3

comfy_extras/nodes_chroma_radiance.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,26 @@ def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput
2929

3030

3131
class ChromaRadianceStubVAE:
32-
@classmethod
33-
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
34-
device = comfy.model_management.intermediate_device()
35-
if pixels.ndim == 3:
36-
pixels = pixels.unsqueeze(0)
37-
elif pixels.ndim != 4:
38-
raise ValueError("Unexpected input image shape")
32+
@staticmethod
33+
def vae_encode_crop_pixels(pixels: torch.Tensor) -> torch.Tensor:
3934
dims = pixels.shape[1:-1]
4035
for d in range(len(dims)):
4136
d_adj = (dims[d] // 16) * 16
4237
if d_adj == d:
4338
continue
4439
d_offset = (dims[d] % 16) // 2
4540
pixels = pixels.narrow(d + 1, d_offset, d_adj)
41+
return pixels
42+
43+
@classmethod
44+
def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
45+
device = comfy.model_management.intermediate_device()
46+
if pixels.ndim == 3:
47+
pixels = pixels.unsqueeze(0)
48+
elif pixels.ndim != 4:
49+
raise ValueError("Unexpected input image shape")
50+
# Ensure the image has spatial dimensions that are multiples of 16.
51+
pixels = cls.vae_encode_crop_pixels(pixels)
4652
h, w, c = pixels.shape[1:]
4753
if h < 16 or w < 16:
4854
raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.")
@@ -51,6 +57,7 @@ def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
5157
pixels = pixels.expand(-1, -1, -1, 3)
5258
elif c != 3:
5359
raise ValueError("Unexpected number of channels in input image")
60+
# Rescale to -1..1 and move the channel dimension to position 1.
5461
latent = pixels.to(device=device, dtype=torch.float32, copy=True)
5562
latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
5663
latent -= 0.5
@@ -60,6 +67,7 @@ def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
6067
@classmethod
6168
def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
6269
device = comfy.model_management.intermediate_device()
70+
# Rescale to 0..1 and move the channel dimension to the end.
6371
img = samples.to(device=device, dtype=torch.float32, copy=True)
6472
img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
6573
img += 1.0
@@ -71,6 +79,7 @@ def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
7179

7280
@classmethod
7381
def spacial_compression_decode(cls) -> int:
82+
# This just exists so the tiled VAE nodes don't crash.
7483
return 1
7584

7685
spacial_compression_encode = spacial_compression_decode
@@ -115,7 +124,7 @@ def define_schema(cls) -> io.Schema:
115124
return io.Schema(
116125
node_id="ChromaRadianceStubVAE",
117126
category="vae/chroma_radiance",
118-
description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
127+
description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
119128
outputs=[io.Vae.Output()],
120129
)
121130

@@ -129,37 +138,39 @@ def define_schema(cls) -> io.Schema:
129138
return io.Schema(
130139
node_id="ChromaRadianceOptions",
131140
category="model_patches/chroma_radiance",
132-
description="Allows setting some advanced options for the Chroma Radiance model.",
141+
description="Allows setting advanced options for the Chroma Radiance model.",
133142
inputs=[
134143
io.Model.Input(id="model"),
135144
io.Boolean.Input(
136145
id="preserve_wrapper",
137146
default=True,
138-
tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.",
147+
tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
139148
),
140149
io.Float.Input(
141150
id="start_sigma",
142151
default=1.0,
143152
min=0.0,
144153
max=1.0,
154+
tooltip="First sigma that these options will be in effect.",
145155
),
146156
io.Float.Input(
147157
id="end_sigma",
148158
default=0.0,
149159
min=0.0,
150160
max=1.0,
161+
tooltip="Last sigma that these options will be in effect.",
151162
),
152163
io.Int.Input(
153164
id="nerf_tile_size",
154165
default=-1,
155166
min=-1,
156-
tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).",
167+
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
157168
),
158169
io.Combo.Input(
159170
id="nerf_embedder_dtype",
160171
default="default",
161172
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
162-
tooltip="Allows overriding the dtype the NeRF embedder uses.",
173+
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
163174
),
164175
],
165176
outputs=[io.Model.Output()],

0 commit comments

Comments
 (0)