Commit 2f85ef5

feat(diffusers): add support for wan2.2

Signed-off-by: Ettore Di Giacinto <[email protected]>

1 parent 83b8549
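This commit wires the WAN2.2 diffusers pipelines into LocalAI's /video endpoint (see core/http/endpoints/localai/video.go and core/schema/localai.go below). As a rough usage sketch, a client could exercise the new negative_prompt and step fields as follows; the endpoint path and field names come from this diff, while the host, port, and model name are illustrative assumptions:

# Hypothetical client sketch for the extended /video endpoint.
# Field names mirror core/schema/localai.go; the base URL and the
# model name "wan2.2-t2v" are assumptions, not part of this commit.
import requests

payload = {
    "model": "wan2.2-t2v",                     # assumed model name (from BasicModelRequest)
    "prompt": "a red fox running through snowy woods",
    "negative_prompt": "blurry, low quality",  # new in this commit
    "width": 1280,
    "height": 720,
    "num_frames": 81,
    "fps": 16,
    "seed": 42,
    "cfg_scale": 4.0,
    "step": 40,                                # new in this commit: inference steps
}

resp = requests.post("http://localhost:8080/video", json=payload, timeout=3600)
resp.raise_for_status()
print(resp.json())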

5 files changed, 161 insertions(+), 19 deletions(-)

backend/backend.proto

Lines changed: 11 additions & 9 deletions

@@ -312,15 +312,17 @@ message GenerateImageRequest {

 message GenerateVideoRequest {
   string prompt = 1;
-  string start_image = 2; // Path or base64 encoded image for the start frame
-  string end_image = 3; // Path or base64 encoded image for the end frame
-  int32 width = 4;
-  int32 height = 5;
-  int32 num_frames = 6; // Number of frames to generate
-  int32 fps = 7; // Frames per second
-  int32 seed = 8;
-  float cfg_scale = 9; // Classifier-free guidance scale
-  string dst = 10; // Output path for the generated video
+  string negative_prompt = 2; // Negative prompt for video generation
+  string start_image = 3; // Path or base64 encoded image for the start frame
+  string end_image = 4; // Path or base64 encoded image for the end frame
+  int32 width = 5;
+  int32 height = 6;
+  int32 num_frames = 7; // Number of frames to generate
+  int32 fps = 8; // Frames per second
+  int32 seed = 9;
+  float cfg_scale = 10; // Classifier-free guidance scale
+  int32 step = 11; // Number of inference steps
+  string dst = 12; // Output path for the generated video
 }

 message TTSRequest {
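Note that the pre-existing fields are renumbered to make room for negative_prompt (tag 2) and step (tag 11): start_image moves from tag 2 to 3, and so on through dst at tag 12. Protobuf identifies fields by tag on the wire, so stubs on both sides must be regenerated together; that holds here because the core and the Python backend live in the same tree. A minimal sketch of the message via the generated Python stub (the module name backend_pb2 matches what the diffusers backend below imports):

# Sketch of constructing the extended request through the generated stub.
# backend_pb2 is the module protoc generates from backend/backend.proto.
import backend_pb2

req = backend_pb2.GenerateVideoRequest(
    prompt="a drone shot over a glacier",
    negative_prompt="artifacts, watermark",  # new field, tag 2
    width=1280,
    height=720,
    num_frames=81,
    fps=16,
    seed=42,
    cfg_scale=4.0,
    step=40,                                 # new field, tag 11
    dst="/tmp/out.mp4",                      # illustrative output path
)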

backend/python/diffusers/backend.py

Lines changed: 117 additions & 1 deletion

@@ -18,7 +18,7 @@
 import grpc

 from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video

@@ -334,6 +334,32 @@ def LoadModel(self, request, context):
                 torch_dtype=torch.bfloat16)
             self.pipe.vae.to(torch.bfloat16)
             self.pipe.text_encoder.to(torch.bfloat16)
+        elif request.PipelineType == "WanPipeline":
+            # WAN2.2 pipeline requires special VAE handling
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.txt2vid = True  # WAN2.2 is a text-to-video pipeline
+        elif request.PipelineType == "WanImageToVideoPipeline":
+            # WAN2.2 image-to-video pipeline
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanImageToVideoPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.img2vid = True  # WAN2.2 image-to-video pipeline

         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip

@@ -575,6 +601,96 @@ def GenerateImage(self, request, context):

         return backend_pb2.Result(message="Media generated", success=True)

+    def GenerateVideo(self, request, context):
+        try:
+            prompt = request.prompt
+            if not prompt:
+                return backend_pb2.Result(success=False, message="No prompt provided for video generation")
+
+            # Set default values from request or use defaults
+            num_frames = request.num_frames if request.num_frames > 0 else 81
+            fps = request.fps if request.fps > 0 else 16
+            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
+            num_inference_steps = request.step if request.step > 0 else 40
+
+            # Prepare generation parameters
+            kwargs = {
+                "prompt": prompt,
+                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
+                "height": request.height if request.height > 0 else 720,
+                "width": request.width if request.width > 0 else 1280,
+                "num_frames": num_frames,
+                "guidance_scale": cfg_scale,
+                "num_inference_steps": num_inference_steps,
+            }
+
+            # Add custom options from self.options (including guidance_scale_2 if specified)
+            kwargs.update(self.options)
+
+            # Set seed if provided
+            if request.seed > 0:
+                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
+
+            # Handle start and end images for video generation
+            if request.start_image:
+                kwargs["start_image"] = load_image(request.start_image)
+            if request.end_image:
+                kwargs["end_image"] = load_image(request.end_image)
+
+            print(f"Generating video with {kwargs=}", file=sys.stderr)
+
+            # Generate video frames based on pipeline type
+            if self.PipelineType == "WanPipeline":
+                # WAN2.2 text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]  # WAN2.2 returns frames in this format
+            elif self.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video generation
+                if request.start_image:
+                    # Load and resize the input image according to WAN2.2 requirements
+                    image = load_image(request.start_image)
+                    # Use request dimensions or defaults, but respect WAN2.2 constraints
+                    request_height = request.height if request.height > 0 else 480
+                    request_width = request.width if request.width > 0 else 832
+                    max_area = request_height * request_width
+                    aspect_ratio = image.height / image.width
+                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
+                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    image = image.resize((width, height))
+                    kwargs["image"] = image
+                    kwargs["height"] = height
+                    kwargs["width"] = width
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.img2vid:
+                # Generic image-to-video generation
+                if request.start_image:
+                    image = load_image(request.start_image)
+                    image = image.resize((request.width if request.width > 0 else 1024,
+                                          request.height if request.height > 0 else 576))
+                    kwargs["image"] = image
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.txt2vid:
+                # Generic text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            else:
+                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
+
+            # Export video
+            export_to_video(frames, request.dst, fps=fps)
+
+            return backend_pb2.Result(message="Video generated successfully", success=True)
+
+        except Exception as err:
+            print(f"Error generating video: {err}", file=sys.stderr)
+            traceback.print_exc()
+            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
+

 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
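The WanImageToVideoPipeline branch above derives the output resolution from the requested pixel budget while preserving the input image's aspect ratio and snapping both sides to the model's spatial grid. A standalone sketch of that arithmetic, assuming mod_value = 16 (a VAE spatial scale factor of 8 times a transformer patch size of 2, which is what common WAN checkpoints report):

# Standalone version of the resize math from GenerateVideo above.
# mod_value = vae_scale_factor_spatial * patch_size[1]; 8 * 2 = 16 is
# assumed here rather than read from a loaded pipeline.
def wan_target_size(img_w, img_h, req_w=832, req_h=480, mod_value=16):
    max_area = req_h * req_w              # total pixel budget from the request
    aspect_ratio = img_h / img_w          # preserve the input image's aspect
    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
    return width, height

# A 1920x1080 input with the default 832x480 budget lands on 848x480:
# both sides divisible by 16, total pixels close to the requested budget,
# and the 16:9 aspect ratio roughly preserved.
print(wan_target_size(1920, 1080))  # -> (848, 480)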

core/backend/video.go

Lines changed: 13 additions & 7 deletions

@@ -7,7 +7,7 @@ import (
 	model "github.com/mudler/LocalAI/pkg/model"
 )

-func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
+func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {

 	opts := ModelOptions(modelConfig, appConfig)
 	inferenceModel, err := loader.Load(

@@ -22,12 +22,18 @@ func VideoGeneration(height, width int32, prompt, startImage, endImage, dst stri
 		_, err := inferenceModel.GenerateVideo(
 			appConfig.Context,
 			&proto.GenerateVideoRequest{
-				Height:     height,
-				Width:      width,
-				Prompt:     prompt,
-				StartImage: startImage,
-				EndImage:   endImage,
-				Dst:        dst,
+				Height:         height,
+				Width:          width,
+				Prompt:         prompt,
+				NegativePrompt: negativePrompt,
+				StartImage:     startImage,
+				EndImage:       endImage,
+				NumFrames:      numFrames,
+				Fps:            fps,
+				Seed:           seed,
+				CfgScale:       cfgScale,
+				Step:           step,
+				Dst:            dst,
 			})
 		return err
 	}

core/http/endpoints/localai/video.go

Lines changed: 18 additions & 2 deletions

@@ -61,7 +61,7 @@ func downloadFile(url string) (string, error) {
 */
 // VideoEndpoint
 // @Summary Creates a video given a prompt.
-// @Param request body schema.OpenAIRequest true "query params"
+// @Param request body schema.VideoRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /video [post]
 func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {

@@ -166,7 +166,23 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi

 	baseURL := c.BaseURL()

-	fn, err := backend.VideoGeneration(height, width, input.Prompt, src, input.EndImage, output, ml, *config, appConfig)
+	fn, err := backend.VideoGeneration(
+		height,
+		width,
+		input.Prompt,
+		input.NegativePrompt,
+		src,
+		input.EndImage,
+		output,
+		input.NumFrames,
+		input.FPS,
+		input.Seed,
+		input.CFGScale,
+		input.Step,
+		ml,
+		*config,
+		appConfig,
+	)
 	if err != nil {
 		return err
 	}

core/schema/localai.go

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@ type GalleryResponse struct {
 type VideoRequest struct {
 	BasicModelRequest
 	Prompt         string `json:"prompt" yaml:"prompt"`
+	NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
 	StartImage     string `json:"start_image" yaml:"start_image"`
 	EndImage       string `json:"end_image" yaml:"end_image"`
 	Width          int32  `json:"width" yaml:"width"`

@@ -36,6 +37,7 @@ type VideoRequest struct {
 	FPS            int32   `json:"fps" yaml:"fps"`
 	Seed           int32   `json:"seed" yaml:"seed"`
 	CFGScale       float32 `json:"cfg_scale" yaml:"cfg_scale"`
+	Step           int32   `json:"step" yaml:"step"`
 	ResponseFormat string  `json:"response_format" yaml:"response_format"`
 }
