1
+ from typing_extensions import override
2
+
1
3
import torch
2
4
5
+ from comfy_api .latest import ComfyExtension , io
6
+
7
+
3
8
# https://github.com/WeichenFan/CFG-Zero-star
4
9
def optimized_scale (positive , negative ):
5
10
positive_flat = positive .reshape (positive .shape [0 ], - 1 )
@@ -16,17 +21,20 @@ def optimized_scale(positive, negative):
16
21
17
22
return st_star .reshape ([positive .shape [0 ]] + [1 ] * (positive .ndim - 1 ))
18
23
19
- class CFGZeroStar :
24
+ class CFGZeroStar (io .ComfyNode ):
25
+ @classmethod
26
+ def define_schema (cls ) -> io .Schema :
27
+ return io .Schema (
28
+ node_id = "CFGZeroStar" ,
29
+ category = "advanced/guidance" ,
30
+ inputs = [
31
+ io .Model .Input ("model" ),
32
+ ],
33
+ outputs = [io .Model .Output (display_name = "patched_model" )],
34
+ )
35
+
20
36
@classmethod
21
- def INPUT_TYPES (s ):
22
- return {"required" : {"model" : ("MODEL" ,),
23
- }}
24
- RETURN_TYPES = ("MODEL" ,)
25
- RETURN_NAMES = ("patched_model" ,)
26
- FUNCTION = "patch"
27
- CATEGORY = "advanced/guidance"
28
-
29
- def patch (self , model ):
37
+ def execute (cls , model ) -> io .NodeOutput :
30
38
m = model .clone ()
31
39
def cfg_zero_star (args ):
32
40
guidance_scale = args ['cond_scale' ]
@@ -38,21 +46,24 @@ def cfg_zero_star(args):
38
46
39
47
return out + uncond_p * (alpha - 1.0 ) + guidance_scale * uncond_p * (1.0 - alpha )
40
48
m .set_model_sampler_post_cfg_function (cfg_zero_star )
41
- return ( m , )
49
+ return io . NodeOutput ( m )
42
50
43
- class CFGNorm :
51
+ class CFGNorm ( io . ComfyNode ) :
44
52
@classmethod
45
- def INPUT_TYPES (s ):
46
- return {"required" : {"model" : ("MODEL" ,),
47
- "strength" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 100.0 , "step" : 0.01 }),
48
- }}
49
- RETURN_TYPES = ("MODEL" ,)
50
- RETURN_NAMES = ("patched_model" ,)
51
- FUNCTION = "patch"
52
- CATEGORY = "advanced/guidance"
53
- EXPERIMENTAL = True
54
-
55
- def patch (self , model , strength ):
53
+ def define_schema (cls ) -> io .Schema :
54
+ return io .Schema (
55
+ node_id = "CFGNorm" ,
56
+ category = "advanced/guidance" ,
57
+ inputs = [
58
+ io .Model .Input ("model" ),
59
+ io .Float .Input ("strength" , default = 1.0 , min = 0.0 , max = 100.0 , step = 0.01 ),
60
+ ],
61
+ outputs = [io .Model .Output (display_name = "patched_model" )],
62
+ is_experimental = True ,
63
+ )
64
+
65
+ @classmethod
66
+ def execute (cls , model , strength ) -> io .NodeOutput :
56
67
m = model .clone ()
57
68
def cfg_norm (args ):
58
69
cond_p = args ['cond_denoised' ]
@@ -64,9 +75,17 @@ def cfg_norm(args):
64
75
return pred_text_ * scale * strength
65
76
66
77
m .set_model_sampler_post_cfg_function (cfg_norm )
67
- return (m , )
78
+ return io .NodeOutput (m )
79
+
80
+
81
+ class CfgExtension (ComfyExtension ):
82
+ @override
83
+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
84
+ return [
85
+ CFGZeroStar ,
86
+ CFGNorm ,
87
+ ]
88
+
68
89
69
- NODE_CLASS_MAPPINGS = {
70
- "CFGZeroStar" : CFGZeroStar ,
71
- "CFGNorm" : CFGNorm ,
72
- }
90
+ async def comfy_entrypoint () -> CfgExtension :
91
+ return CfgExtension ()
0 commit comments