Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class SamplingInput:
scheduler: str
cfg_scale: float
total_steps: int
cache_threshold: float = 0
start_step: int = 0
seed: int = 0

Expand Down
4 changes: 2 additions & 2 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,13 +1075,13 @@ def estimate_pose(self, image: Output, resolution: int):
mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx"
return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls)

def apply_first_block_cache(self, model: Output, arch: Arch, threshold: float):
    """Patch the diffusion model with First-Block caching (ApplyFBCacheOnModel).

    `threshold` is the user-configured residual-diff threshold; a falsy value
    (0) falls back to an architecture-dependent default.
    """
    # BUG FIX: the previous expression `threshold or 0.2 if arch.is_sdxl_like else 0.12`
    # parsed as `(threshold or 0.2) if arch.is_sdxl_like else 0.12` because the
    # conditional expression binds looser than `or` — so for non-SDXL archs
    # (Flux, SD3) a user-supplied threshold was silently discarded and 0.12
    # was always used. Parenthesize so the fallback chain is what was intended.
    default_threshold = 0.2 if arch.is_sdxl_like else 0.12
    return self.add(
        "ApplyFBCacheOnModel",
        1,
        model=model,
        object_to_patch="diffusion_model",
        residual_diff_threshold=threshold or default_threshold,
        start=0.0,
        end=1.0,
        max_consecutive_cache_hits=-1,
    )
Expand Down
1 change: 1 addition & 0 deletions ai_diffusion/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ class SamplerPreset(NamedTuple):
lora: str | None = None
minimum_steps: int = 4
hidden: bool = False
cache_threshold: float = 0


class SamplerPresets:
Expand Down
19 changes: 11 additions & 8 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _sampling_from_style(style: Style, strength: float, is_live: bool):
scheduler=preset.scheduler,
cfg_scale=cfg or preset.cfg,
total_steps=max_steps,
cache_threshold=preset.cache_threshold or None,
)
if strength < 1.0:
result.total_steps, result.start_step = apply_strength(strength, max_steps, min_steps)
Expand Down Expand Up @@ -87,7 +88,9 @@ def _sampler_params(sampling: SamplingInput, strength: float | None = None):
return params


def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, models: ClientModels):
def load_checkpoint_with_lora(
w: ComfyWorkflow, checkpoint: CheckpointInput, sampling: SamplingInput, models: ClientModels
):
arch = checkpoint.version
model_info = models.checkpoints.get(checkpoint.checkpoint)
if model_info is None:
Expand Down Expand Up @@ -133,7 +136,7 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
vae = w.load_vae(models.for_arch(arch).vae)

if checkpoint.dynamic_caching and (arch in [Arch.flux, Arch.sd3] or arch.is_sdxl_like):
model = w.apply_first_block_cache(model, arch)
model = w.apply_first_block_cache(model, arch, sampling.cache_threshold)

for lora in checkpoint.loras:
model, clip = w.load_lora(model, clip, lora.name, lora.strength, lora.strength)
Expand Down Expand Up @@ -753,7 +756,7 @@ def generate(
misc: MiscParams,
models: ModelDict,
):
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all)
model = apply_ip_adapter(w, model, cond.control, models)
model_orig = copy(model)
model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
Expand Down Expand Up @@ -865,7 +868,7 @@ def inpaint(
checkpoint.dynamic_caching = False # doesn't seem to work with Flux fill model
sampling.cfg_scale = 30 # set Flux guidance to 30 (typical values don't work well)

model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all)
model = w.differential_diffusion(model)
model_orig = copy(model)

Expand Down Expand Up @@ -994,7 +997,7 @@ def refine(
misc: MiscParams,
models: ModelDict,
):
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all)
model = apply_ip_adapter(w, model, cond.control, models)
model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
Expand Down Expand Up @@ -1031,7 +1034,7 @@ def refine_region(
):
extent = ScaledExtent.from_input(images.extent)

model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all)
model = w.differential_diffusion(model)
model = apply_ip_adapter(w, model, cond.control, models)
model_orig = copy(model)
Expand Down Expand Up @@ -1182,7 +1185,7 @@ def upscale_tiled(
extent.initial, extent.desired.width, sampling.denoise_strength
)

model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all)
model = apply_ip_adapter(w, model, cond.control, models)

in_image = w.load_image(image)
Expand Down Expand Up @@ -1301,7 +1304,7 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None =
is_live = node.input("sampler_preset", "auto") == "live"
checkpoint_input = style.get_models(models.checkpoints.keys())
sampling = _sampling_from_style(style, 1.0, is_live)
model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models)
model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, sampling, models)
outputs[node.output(0)] = model
outputs[node.output(1)] = clip.model
outputs[node.output(2)] = vae
Expand Down