@@ -29,20 +29,26 @@ def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput
29
29
30
30
31
31
class ChromaRadianceStubVAE :
32
- @classmethod
33
- def encode (cls , pixels : torch .Tensor , * _args , ** _kwargs ) -> torch .Tensor :
34
- device = comfy .model_management .intermediate_device ()
35
- if pixels .ndim == 3 :
36
- pixels = pixels .unsqueeze (0 )
37
- elif pixels .ndim != 4 :
38
- raise ValueError ("Unexpected input image shape" )
32
+ @staticmethod
33
+ def vae_encode_crop_pixels (pixels : torch .Tensor ) -> torch .Tensor :
39
34
dims = pixels .shape [1 :- 1 ]
40
35
for d in range (len (dims )):
41
36
d_adj = (dims [d ] // 16 ) * 16
42
37
if d_adj == d :
43
38
continue
44
39
d_offset = (dims [d ] % 16 ) // 2
45
40
pixels = pixels .narrow (d + 1 , d_offset , d_adj )
41
+ return pixels
42
+
43
+ @classmethod
44
+ def encode (cls , pixels : torch .Tensor , * _args , ** _kwargs ) -> torch .Tensor :
45
+ device = comfy .model_management .intermediate_device ()
46
+ if pixels .ndim == 3 :
47
+ pixels = pixels .unsqueeze (0 )
48
+ elif pixels .ndim != 4 :
49
+ raise ValueError ("Unexpected input image shape" )
50
+ # Ensure the image has spatial dimensions that are multiples of 16.
51
+ pixels = cls .vae_encode_crop_pixels (pixels )
46
52
h , w , c = pixels .shape [1 :]
47
53
if h < 16 or w < 16 :
48
54
raise ValueError ("Chroma Radiance image inputs must have height/width of at least 16 pixels." )
@@ -51,6 +57,7 @@ def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
51
57
pixels = pixels .expand (- 1 , - 1 , - 1 , 3 )
52
58
elif c != 3 :
53
59
raise ValueError ("Unexpected number of channels in input image" )
60
+ # Rescale to -1..1 and move the channel dimension to position 1.
54
61
latent = pixels .to (device = device , dtype = torch .float32 , copy = True )
55
62
latent = latent .clamp_ (0 , 1 ).movedim (- 1 , 1 ).contiguous ()
56
63
latent -= 0.5
@@ -60,6 +67,7 @@ def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
60
67
@classmethod
61
68
def decode (cls , samples : torch .Tensor , * _args , ** _kwargs ) -> torch .Tensor :
62
69
device = comfy .model_management .intermediate_device ()
70
+ # Rescale to 0..1 and move the channel dimension to the end.
63
71
img = samples .to (device = device , dtype = torch .float32 , copy = True )
64
72
img = img .clamp_ (- 1 , 1 ).movedim (1 , - 1 ).contiguous ()
65
73
img += 1.0
@@ -71,6 +79,7 @@ def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
71
79
72
80
@classmethod
73
81
def spacial_compression_decode (cls ) -> int :
82
+ # This just exists so the tiled VAE nodes don't crash.
74
83
return 1
75
84
76
85
spacial_compression_encode = spacial_compression_decode
@@ -115,7 +124,7 @@ def define_schema(cls) -> io.Schema:
115
124
return io .Schema (
116
125
node_id = "ChromaRadianceStubVAE" ,
117
126
category = "vae/chroma_radiance" ,
118
- description = "For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary." ,
127
+ description = "For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary." ,
119
128
outputs = [io .Vae .Output ()],
120
129
)
121
130
@@ -129,37 +138,39 @@ def define_schema(cls) -> io.Schema:
129
138
return io .Schema (
130
139
node_id = "ChromaRadianceOptions" ,
131
140
category = "model_patches/chroma_radiance" ,
132
- description = "Allows setting some advanced options for the Chroma Radiance model." ,
141
+ description = "Allows setting advanced options for the Chroma Radiance model." ,
133
142
inputs = [
134
143
io .Model .Input (id = "model" ),
135
144
io .Boolean .Input (
136
145
id = "preserve_wrapper" ,
137
146
default = True ,
138
- tooltip = "When enabled preserves an existing model wrapper if it exists. Generally should be left enabled." ,
147
+ tooltip = "When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled." ,
139
148
),
140
149
io .Float .Input (
141
150
id = "start_sigma" ,
142
151
default = 1.0 ,
143
152
min = 0.0 ,
144
153
max = 1.0 ,
154
+ tooltip = "First sigma that these options will be in effect." ,
145
155
),
146
156
io .Float .Input (
147
157
id = "end_sigma" ,
148
158
default = 0.0 ,
149
159
min = 0.0 ,
150
160
max = 1.0 ,
161
+ tooltip = "Last sigma that these options will be in effect." ,
151
162
),
152
163
io .Int .Input (
153
164
id = "nerf_tile_size" ,
154
165
default = - 1 ,
155
166
min = - 1 ,
156
- tooltip = "Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM)." ,
167
+ tooltip = "Allows overriding the default NeRF tile size. -1 means use the default (32) . 0 means use non-tiling mode (may require a lot of VRAM)." ,
157
168
),
158
169
io .Combo .Input (
159
170
id = "nerf_embedder_dtype" ,
160
171
default = "default" ,
161
172
options = ["default" , "model_dtype" , "float32" , "float64" , "float16" , "bfloat16" ],
162
- tooltip = "Allows overriding the dtype the NeRF embedder uses." ,
173
+ tooltip = "Allows overriding the dtype the NeRF embedder uses. The default is float32. " ,
163
174
),
164
175
],
165
176
outputs = [io .Model .Output ()],
0 commit comments