Skip to content

Commit dcb8834

Browse files
authored
convert Cosmos nodes to V3 schema (#9721)
1 parent f9d2e4b commit dcb8834

File tree

1 file changed

+72
-57
lines changed

1 file changed

+72
-57
lines changed

comfy_extras/nodes_cosmos.py

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
1+
from typing_extensions import override
12
import nodes
23
import torch
34
import comfy.model_management
45
import comfy.utils
56
import comfy.latent_formats
67

8+
from comfy_api.latest import ComfyExtension, io
79

8-
class EmptyCosmosLatentVideo(io.ComfyNode):
    """Produce an all-zero Cosmos video latent of the requested size.

    Cosmos latents use 16 channels with a temporal stride of 8 and a
    spatial stride of 8, so a video of ``length`` frames at
    ``width`` x ``height`` pixels maps to a
    ``[batch, 16, (length-1)//8 + 1, height//8, width//8]`` tensor.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        res_limit = nodes.MAX_RESOLUTION
        return io.Schema(
            node_id="EmptyCosmosLatentVideo",
            category="latent/video",
            inputs=[
                io.Int.Input("width", default=1280, min=16, max=res_limit, step=16),
                io.Int.Input("height", default=704, min=16, max=res_limit, step=16),
                io.Int.Input("length", default=121, min=1, max=res_limit, step=8),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
        # (length - 1) // 8 + 1 frames: the first pixel frame always gets its
        # own latent frame, each further latent frame covers 8 pixel frames.
        temporal_frames = ((length - 1) // 8) + 1
        target_device = comfy.model_management.intermediate_device()
        empty = torch.zeros(
            [batch_size, 16, temporal_frames, height // 8, width // 8],
            device=target_device,
        )
        return io.NodeOutput({"samples": empty})
2330

2431

2532
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
@@ -33,31 +40,31 @@ def vae_encode_with_padding(vae, image, width, height, length, padding=0):
3340
return latent_temp[:, :, :latent_len]
3441

3542

36-
class CosmosImageToVideoLatent:
43+
class CosmosImageToVideoLatent(io.ComfyNode):
3744
@classmethod
def define_schema(cls) -> io.Schema:
    """Declare the V3 schema for the Cosmos image-to-video latent node.

    Required inputs are the VAE plus output geometry; ``start_image`` and
    ``end_image`` are optional anchors for the first/last frames.
    """
    res_limit = nodes.MAX_RESOLUTION
    schema_inputs = [
        io.Vae.Input("vae"),
        io.Int.Input("width", default=1280, min=16, max=res_limit, step=16),
        io.Int.Input("height", default=704, min=16, max=res_limit, step=16),
        io.Int.Input("length", default=121, min=1, max=res_limit, step=8),
        io.Int.Input("batch_size", default=1, min=1, max=4096),
        io.Image.Input("start_image", optional=True),
        io.Image.Input("end_image", optional=True),
    ]
    return io.Schema(
        node_id="CosmosImageToVideoLatent",
        category="conditioning/inpaint",
        inputs=schema_inputs,
        outputs=[io.Latent.Output()],
    )
4960

50-
RETURN_TYPES = ("LATENT",)
51-
FUNCTION = "encode"
52-
53-
CATEGORY = "conditioning/inpaint"
54-
55-
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
61+
@classmethod
62+
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
5663
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
5764
if start_image is None and end_image is None:
5865
out_latent = {}
5966
out_latent["samples"] = latent
60-
return (out_latent,)
67+
return io.NodeOutput(out_latent)
6168

6269
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
6370

@@ -74,33 +81,33 @@ def encode(self, vae, width, height, length, batch_size, start_image=None, end_i
7481
out_latent = {}
7582
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
7683
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
77-
return (out_latent,)
84+
return io.NodeOutput(out_latent)
7885

79-
class CosmosPredict2ImageToVideoLatent:
86+
class CosmosPredict2ImageToVideoLatent(io.ComfyNode):
8087
@classmethod
def define_schema(cls) -> io.Schema:
    """Declare the V3 schema for the Cosmos Predict2 image-to-video node.

    Same shape as CosmosImageToVideoLatent, but with Predict2 defaults
    (848x480, 93 frames) and a temporal step of 4 instead of 8.
    """
    res_limit = nodes.MAX_RESOLUTION
    schema_inputs = [
        io.Vae.Input("vae"),
        io.Int.Input("width", default=848, min=16, max=res_limit, step=16),
        io.Int.Input("height", default=480, min=16, max=res_limit, step=16),
        io.Int.Input("length", default=93, min=1, max=res_limit, step=4),
        io.Int.Input("batch_size", default=1, min=1, max=4096),
        io.Image.Input("start_image", optional=True),
        io.Image.Input("end_image", optional=True),
    ]
    return io.Schema(
        node_id="CosmosPredict2ImageToVideoLatent",
        category="conditioning/inpaint",
        inputs=schema_inputs,
        outputs=[io.Latent.Output()],
    )
91103

92-
93-
RETURN_TYPES = ("LATENT",)
94-
FUNCTION = "encode"
95-
96-
CATEGORY = "conditioning/inpaint"
97-
98-
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
104+
@classmethod
105+
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
99106
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
100107
if start_image is None and end_image is None:
101108
out_latent = {}
102109
out_latent["samples"] = latent
103-
return (out_latent,)
110+
return io.NodeOutput(out_latent)
104111

105112
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
106113

@@ -119,10 +126,18 @@ def encode(self, vae, width, height, length, batch_size, start_image=None, end_i
119126
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
120127
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
121128
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
122-
return (out_latent,)
129+
return io.NodeOutput(out_latent)
130+
131+
132+
class CosmosExtension(ComfyExtension):
    """Extension object exposing the Cosmos latent nodes via the V3 API."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every node class this extension registers."""
        node_classes: list[type[io.ComfyNode]] = [
            EmptyCosmosLatentVideo,
            CosmosImageToVideoLatent,
            CosmosPredict2ImageToVideoLatent,
        ]
        return node_classes
140+
123141

124-
NODE_CLASS_MAPPINGS = {
125-
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
126-
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
127-
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
128-
}
142+
async def comfy_entrypoint() -> CosmosExtension:
    """Module entry point: give ComfyUI this file's extension instance."""
    extension = CosmosExtension()
    return extension

0 commit comments

Comments (0)