251 changes: 112 additions & 139 deletions comfy_extras/nodes_post_processing.py
@@ -1,3 +1,4 @@
from typing_extensions import override
import numpy as np
import torch
import torch.nn.functional as F
@@ -7,46 +8,41 @@
import comfy.utils
import comfy.model_management
import node_helpers
from comfy_api.latest import ComfyExtension, io

class Blend:
def __init__(self):
pass
class Blend(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageBlend",
category="image/postprocessing",
inputs=[
io.Image.Input("image1"),
io.Image.Input("image2"),
io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01),
io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]),
],
outputs=[
io.Image.Output(),
],
)

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "blend_images"

CATEGORY = "image/postprocessing"

def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput:
image1, image2 = node_helpers.image_alpha_fix(image1, image2)
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1)

blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = cls.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1)
return (blended_image,)
return io.NodeOutput(blended_image)

def blend_mode(self, img1, img2, mode):
@classmethod
def blend_mode(cls, img1, img2, mode):
if mode == "normal":
return img2
elif mode == "multiply":
@@ -56,13 +52,13 @@ def blend_mode(self, img1, img2, mode):
elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1))
elif mode == "difference":
return img1 - img2
else:
raise ValueError(f"Unsupported blend mode: {mode}")
raise ValueError(f"Unsupported blend mode: {mode}")

def g(self, x):
@classmethod
def g(cls, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
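
A quick aside on the math: the soft_light branch matches the familiar W3C soft-light compositing curve, with g(x) as its piecewise helper. A minimal standalone sketch of that arithmetic, assuming only torch (blend_factor fixed at 0.5 for illustration):

import torch

def g(x):
    # piecewise helper, as in Blend.g above
    return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

img1 = torch.rand(1, 4, 4, 3)  # backdrop, BHWC in [0, 1]
img2 = torch.rand(1, 4, 4, 3)  # source
soft = torch.where(
    img2 <= 0.5,
    img1 - (1 - 2 * img2) * img1 * (1 - img1),
    img1 + (2 * img2 - 1) * (g(img1) - img1),
)
blended = (img1 * 0.5 + soft * 0.5).clamp(0, 1)  # linear mix, then clamp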

def gaussian_kernel(kernel_size: int, sigma: float, device=None):
@@ -71,38 +67,26 @@ def gaussian_kernel(kernel_size: int, sigma: float, device=None):
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()

class Blur:
def __init__(self):
pass
class Blur(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageBlur",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
],
outputs=[
io.Image.Output(),
],
)

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"blur_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "blur"

CATEGORY = "image/postprocessing"

def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float) -> io.NodeOutput:
if blur_radius == 0:
return (image,)
return io.NodeOutput(image)

image = image.to(comfy.model_management.get_torch_device())
batch_size, height, width, channels = image.shape
@@ -115,31 +99,24 @@ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
blurred = blurred.permute(0, 2, 3, 1)

return (blurred.to(comfy.model_management.intermediate_device()),)
return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device()))
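
For readers following the Blur port, a self-contained sketch of the depthwise-convolution pattern execute() relies on, with a local stand-in for the (collapsed) gaussian_kernel helper and plain zero padding where the node reflect-pads and crops:

import torch
import torch.nn.functional as F

def gaussian_kernel_2d(kernel_size: int, sigma: float) -> torch.Tensor:
    ax = torch.linspace(-1, 1, kernel_size)
    xx, yy = torch.meshgrid(ax, ax, indexing="ij")
    k = torch.exp(-(xx * xx + yy * yy) / (2.0 * sigma * sigma))
    return k / k.sum()  # normalize so overall brightness is preserved

image = torch.rand(1, 64, 64, 3)  # BHWC, the layout ComfyUI images use
radius, channels = 3, 3
size = radius * 2 + 1
kernel = gaussian_kernel_2d(size, 1.0).repeat(channels, 1, 1).unsqueeze(1)  # (C, 1, k, k)

x = image.permute(0, 3, 1, 2)  # BHWC -> BCHW for conv2d
blurred = F.conv2d(x, kernel, padding=size // 2, groups=channels)  # one kernel per channel
blurred = blurred.permute(0, 2, 3, 1)  # back to BHWC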

class Quantize:
def __init__(self):
pass

class Quantize(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"colors": ("INT", {
"default": 256,
"min": 1,
"max": 256,
"step": 1
}),
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "quantize"

CATEGORY = "image/postprocessing"
def define_schema(cls):
return io.Schema(
node_id="ImageQuantize",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Int.Input("colors", default=256, min=1, max=256, step=1),
io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]),
],
outputs=[
io.Image.Output(),
],
)

@staticmethod
def bayer(im, pal_im, order):
@@ -167,7 +144,8 @@ def normalized_bayer_matrix(n):
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
return im

def quantize(self, image: torch.Tensor, colors: int, dither: str):
@classmethod
def execute(cls, image: torch.Tensor, colors: int, dither: str) -> io.NodeOutput:
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)

@@ -187,46 +165,29 @@ def quantize(self, image: torch.Tensor, colors: int, dither: str):
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array

return (result,)
return io.NodeOutput(result)
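
A hedged sketch of the per-image round trip execute() performs with Pillow (Floyd-Steinberg shown; quantize_one is an illustrative helper, not part of the node):

import numpy as np
import torch
from PIL import Image

def quantize_one(img: torch.Tensor, colors: int) -> torch.Tensor:
    # img: (H, W, 3) float tensor in [0, 1]
    pil = Image.fromarray((img.numpy() * 255).astype(np.uint8))
    q = pil.quantize(colors=colors, dither=Image.Dither.FLOYDSTEINBERG)
    return torch.from_numpy(np.array(q.convert("RGB"))).float() / 255

out = quantize_one(torch.rand(32, 32, 3), colors=16)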

class Sharpen:
def __init__(self):
pass
class Sharpen(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageSharpen",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01),
io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01),
],
outputs=[
io.Image.Output(),
],
)

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"sharpen_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.01
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "sharpen"

CATEGORY = "image/postprocessing"

def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float) -> io.NodeOutput:
if sharpen_radius == 0:
return (image,)
return io.NodeOutput(image)

batch_size, height, width, channels = image.shape
image = image.to(comfy.model_management.get_torch_device())
@@ -244,23 +205,29 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:

result = torch.clamp(sharpened, 0, 1)

return (result.to(comfy.model_management.intermediate_device()),)
return io.NodeOutput(result.to(comfy.model_management.intermediate_device()))
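
Sharpen is unsharp masking expressed as a convolution; the identity it builds on, sketched directly on tensors (a box blur stands in for the Gaussian here):

import torch
import torch.nn.functional as F

image = torch.rand(1, 3, 64, 64)  # BCHW in [0, 1]
blurred = F.avg_pool2d(image, kernel_size=3, stride=1, padding=1)
alpha = 1.0
sharpened = image + alpha * (image - blurred)  # re-amplify high-frequency detail
result = sharpened.clamp(0, 1)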

class ImageScaleToTotalPixels:
class ImageScaleToTotalPixels(io.ComfyNode):
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]

@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
"megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
def define_schema(cls):
return io.Schema(
node_id="ImageScaleToTotalPixels",
category="image/upscaling",
inputs=[
io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
],
outputs=[
io.Image.Output(),
],
)

CATEGORY = "image/upscaling"

def upscale(self, image, upscale_method, megapixels):
@classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024)

@@ -270,12 +237,18 @@ def upscale(self, image, upscale_method, megapixels):

s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1,-1)
return (s,)

NODE_CLASS_MAPPINGS = {
"ImageBlend": Blend,
"ImageBlur": Blur,
"ImageQuantize": Quantize,
"ImageSharpen": Sharpen,
"ImageScaleToTotalPixels": ImageScaleToTotalPixels,
}
return io.NodeOutput(s)
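
The collapsed lines above derive width and height from the pixel budget; the usual arithmetic, assuming rounding to the nearest pixel:

import math

def target_size(w: int, h: int, megapixels: float) -> tuple[int, int]:
    total = int(megapixels * 1024 * 1024)  # target pixel count
    scale = math.sqrt(total / (w * h))     # uniform scale preserves aspect ratio
    return round(w * scale), round(h * scale)

print(target_size(512, 512, 1.0))  # (1024, 1024): 4x the pixels, 2x each side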

class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Blend,
Blur,
Quantize,
Sharpen,
ImageScaleToTotalPixels,
]

async def comfy_entrypoint() -> PostProcessingExtension:
return PostProcessingExtension()
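
The ComfyExtension subclass plus the async comfy_entrypoint replace the old NODE_CLASS_MAPPINGS dict as the registration mechanism. The same pattern for a hypothetical third-party pack (MyNode and MyExtension are illustrative names), assuming only the comfy_api.latest interfaces used above:

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io

class MyNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="MyNode",
            category="image/postprocessing",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        return io.NodeOutput(image)  # pass-through for illustration

class MyExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [MyNode]

async def comfy_entrypoint() -> MyExtension:
    return MyExtension()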