Skip to content

Commit 265be98

Browse files
committed
Remove Radiance dynamic nerf_embedder dtype override feature
1 parent d8af0a3 commit 265be98

File tree

3 files changed

+9
-29
lines changed

3 files changed

+9
-29
lines changed

comfy/ldm/chroma_radiance/layers.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,13 @@ def __init__(
3636
The total number of positional features will be max_freqs^2.
3737
"""
3838
super().__init__()
39-
self.dtype= dtype
39+
self.dtype = dtype
4040
self.max_freqs = max_freqs
4141
self.hidden_size_input = hidden_size_input
4242

4343
# A linear layer to project the concatenated input features and
4444
# positional encodings to the final output dimension.
45-
self.embedder = nn.Sequential(
46-
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
47-
)
45+
self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
4846

4947
@lru_cache(maxsize=4)
5048
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
@@ -101,7 +99,7 @@ def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -
10199

102100
return dct
103101

104-
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
102+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
105103
"""
106104
Forward pass for the embedder.
107105
@@ -117,16 +115,11 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Te
117115
# Infer the patch side length from the number of pixels (P^2).
118116
patch_size = int(P2 ** 0.5)
119117

120-
# Possibly run the operation with a different dtype.
121118
input_dtype = inputs.dtype
122-
if embedder_dtype != input_dtype or self.dtype != input_dtype:
123-
embedder = self.embedder.to(dtype=embedder_dtype)
124-
inputs = inputs.to(dtype=embedder_dtype)
125-
else:
126-
embedder = self.embedder
119+
inputs = inputs.to(dtype=self.dtype)
127120

128121
# Fetch the pre-computed or cached positional embeddings.
129-
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
122+
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
130123

131124
# Repeat the positional embeddings for each item in the batch.
132125
dct = dct.repeat(B, 1, 1)
@@ -136,10 +129,7 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Te
136129
inputs = torch.cat((inputs, dct), dim=-1)
137130

138131
# Project the combined tensor to the target hidden size.
139-
inputs = embedder(inputs)
140-
141-
# No-op if already the same dtype.
142-
return inputs.to(dtype=input_dtype)
132+
return self.embedder(inputs).to(dtype=input_dtype)
143133

144134

145135
class NerfGLUBlock(nn.Module):

comfy/ldm/chroma_radiance/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
120120
in_channels=params.in_channels,
121121
hidden_size_input=params.nerf_hidden_size,
122122
max_freqs=params.nerf_max_freqs,
123-
dtype=dtype,
123+
dtype=params.nerf_embedder_dtype or dtype,
124124
device=device,
125125
operations=operations,
126126
)
@@ -199,7 +199,7 @@ def forward_nerf(
199199
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
200200

201201
# Get DCT-encoded pixel embeddings [pixel-dct]
202-
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
202+
img_dct = self.nerf_image_embedder(nerf_pixels)
203203

204204
# Pass through the dynamic MLP blocks (the NeRF)
205205
for block in self.nerf_blocks:
@@ -235,7 +235,6 @@ def forward_tiled_nerf(
235235
"""
236236
tile_size = params.nerf_tile_size
237237
output_tiles = []
238-
embedder_dtype= params.nerf_embedder_dtype or nerf_pixels.dtype
239238
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
240239
for i in range(0, num_patches, tile_size):
241240
end = min(i + tile_size, num_patches)
@@ -254,7 +253,7 @@ def forward_tiled_nerf(
254253
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
255254

256255
# get DCT-encoded pixel embeddings [pixel-dct]
257-
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
256+
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
258257

259258
# pass through the dynamic MLP blocks (the NeRF)
260259
for block in self.nerf_blocks:

comfy_extras/nodes_chroma_radiance.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,6 @@ def define_schema(cls) -> io.Schema:
166166
min=-1,
167167
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
168168
),
169-
io.Combo.Input(
170-
id="nerf_embedder_dtype",
171-
default="default",
172-
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
173-
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
174-
),
175169
],
176170
outputs=[io.Model.Output()],
177171
)
@@ -185,13 +179,10 @@ def execute(
185179
start_sigma: float,
186180
end_sigma: float,
187181
nerf_tile_size: int,
188-
nerf_embedder_dtype: str,
189182
) -> io.NodeOutput:
190183
radiance_options = {}
191184
if nerf_tile_size >= 0:
192185
radiance_options["nerf_tile_size"] = nerf_tile_size
193-
if nerf_embedder_dtype != "default":
194-
radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype)
195186

196187
if not radiance_options:
197188
return io.NodeOutput(model)

0 commit comments

Comments
 (0)