Commit d8af0a3

Fix Chroma Radiance embedder dtype overriding

1 parent: 2aac0a8

File tree

1 file changed (+5 −3 lines)

comfy/ldm/chroma_radiance/layers.py

Lines changed: 5 additions & 3 deletions
@@ -36,6 +36,7 @@ def __init__(
         The total number of positional features will be max_freqs^2.
         """
         super().__init__()
+        self.dtype = dtype
         self.max_freqs = max_freqs
         self.hidden_size_input = hidden_size_input

@@ -117,9 +118,10 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
         patch_size = int(P2 ** 0.5)

         # Possibly run the operation with a different dtype.
-        orig_dtype = inputs.dtype
-        if embedder_dtype != orig_dtype:
+        input_dtype = inputs.dtype
+        if embedder_dtype != input_dtype or self.dtype != input_dtype:
             embedder = self.embedder.to(dtype=embedder_dtype)
+            inputs = inputs.to(dtype=embedder_dtype)
         else:
             embedder = self.embedder

@@ -137,7 +139,7 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
         inputs = embedder(inputs)

         # No-op if already the same dtype.
-        return inputs.to(dtype=orig_dtype)
+        return inputs.to(dtype=input_dtype)


 class NerfGLUBlock(nn.Module):

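For context, a minimal sketch of the dtype handling this commit ends up with: the module now remembers the dtype it was constructed with (self.dtype), and in forward() it casts both the embedder weights and the incoming tensor to embedder_dtype whenever either the requested embedder dtype or the stored module dtype differs from the input dtype, then casts the result back to the original input dtype. The class name, the nn.Linear stand-in for self.embedder, and the feature sizes below are assumptions made only to keep the snippet self-contained; this is not the actual ComfyUI implementation.

# Minimal sketch of the dtype handling after this commit (assumed names/shapes).
import torch
import torch.nn as nn


class EmbedderDtypeSketch(nn.Module):
    def __init__(self, hidden_size_input: int = 64, max_freqs: int = 4,
                 dtype: torch.dtype = torch.float32):
        super().__init__()
        self.dtype = dtype  # the commit stores the construction dtype on the module
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input
        # Placeholder embedder; the real module is more involved.
        self.embedder = nn.Linear(max_freqs ** 2, hidden_size_input, dtype=dtype)

    def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
        input_dtype = inputs.dtype
        # Cast when either the requested embedder dtype or the module's own
        # dtype differs from the input dtype; the inputs are cast too, so the
        # embedder never sees mismatched dtypes.
        if embedder_dtype != input_dtype or self.dtype != input_dtype:
            embedder = self.embedder.to(dtype=embedder_dtype)
            inputs = inputs.to(dtype=embedder_dtype)
        else:
            embedder = self.embedder
        inputs = embedder(inputs)
        # No-op if already the same dtype.
        return inputs.to(dtype=input_dtype)


# Example: fp16 inputs with an fp32 embedder run the embedding in fp32
# and still return an fp16 tensor.
if __name__ == "__main__":
    m = EmbedderDtypeSketch()
    x = torch.randn(2, 16, dtype=torch.float16)
    out = m(x, embedder_dtype=torch.float32)
    print(out.dtype)  # torch.float16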
0 commit comments
