|
18 | 18 | import grpc
19 | 19 | 
20 | 20 | from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
21 |    | -     EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
   | 21 | +     EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
22 | 22 | from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
23 | 23 | from diffusers.pipelines.stable_diffusion import safety_checker
24 | 24 | from diffusers.utils import load_image, export_to_video
@@ -334,6 +334,32 @@ def LoadModel(self, request, context):
|
334 | 334 |                 torch_dtype=torch.bfloat16)
335 | 335 |             self.pipe.vae.to(torch.bfloat16)
336 | 336 |             self.pipe.text_encoder.to(torch.bfloat16)
| 337 | +         elif request.PipelineType == "WanPipeline":
| 338 | +             # WAN2.2 pipeline requires special VAE handling
| 339 | +             vae = AutoencoderKLWan.from_pretrained(
| 340 | +                 request.Model,
| 341 | +                 subfolder="vae",
| 342 | +                 torch_dtype=torch.float32
| 343 | +             )
| 344 | +             self.pipe = WanPipeline.from_pretrained(
| 345 | +                 request.Model,
| 346 | +                 vae=vae,
| 347 | +                 torch_dtype=torchType
| 348 | +             )
| 349 | +             self.txt2vid = True  # WAN2.2 is a text-to-video pipeline
| 350 | +         elif request.PipelineType == "WanImageToVideoPipeline":
| 351 | +             # WAN2.2 image-to-video pipeline
| 352 | +             vae = AutoencoderKLWan.from_pretrained(
| 353 | +                 request.Model,
| 354 | +                 subfolder="vae",
| 355 | +                 torch_dtype=torch.float32
| 356 | +             )
| 357 | +             self.pipe = WanImageToVideoPipeline.from_pretrained(
| 358 | +                 request.Model,
| 359 | +                 vae=vae,
| 360 | +                 torch_dtype=torchType
| 361 | +             )
| 362 | +             self.img2vid = True  # WAN2.2 image-to-video pipeline
337 | 363 | 
338 | 364 |         if CLIPSKIP and request.CLIPSkip != 0:
339 | 365 |             self.clip_skip = request.CLIPSkip
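
The loading pattern above follows the upstream diffusers recipe for Wan: the Wan VAE is loaded separately in float32 (it is sensitive to lower precision) and handed to the pipeline, while the remaining weights use the requested dtype. A minimal standalone sketch of the same recipe, where the model id, dtype, and device are illustrative assumptions rather than values taken from this backend:

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

# Assumed model id for illustration; any Wan 2.2 diffusers checkpoint with a "vae" subfolder should work.
model_id = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"

# Keep the VAE in float32, as in the LoadModel branch above; load the rest in bfloat16.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Generate a short clip and write it to disk.
frames = pipe(
    prompt="A cat surfing a wave at sunset",
    height=720,
    width=1280,
    num_frames=81,
    guidance_scale=4.0,
    num_inference_steps=40,
).frames[0]
export_to_video(frames, "output.mp4", fps=16)
```

WanImageToVideoPipeline is loaded the same way, with the conditioning image passed as `image` at call time.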
|
@@ -575,6 +601,96 @@ def GenerateImage(self, request, context):
|
575 | 601 | 
576 | 602 |         return backend_pb2.Result(message="Media generated", success=True)
577 | 603 | 
| 604 | +    def GenerateVideo(self, request, context):
| 605 | +        try:
| 606 | +            prompt = request.prompt
| 607 | +            if not prompt:
| 608 | +                return backend_pb2.Result(success=False, message="No prompt provided for video generation")
| 609 | +
| 610 | +            # Set default values from request or use defaults
| 611 | +            num_frames = request.num_frames if request.num_frames > 0 else 81
| 612 | +            fps = request.fps if request.fps > 0 else 16
| 613 | +            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
| 614 | +            num_inference_steps = request.step if request.step > 0 else 40
| 615 | +
| 616 | +            # Prepare generation parameters
| 617 | +            kwargs = {
| 618 | +                "prompt": prompt,
| 619 | +                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
| 620 | +                "height": request.height if request.height > 0 else 720,
| 621 | +                "width": request.width if request.width > 0 else 1280,
| 622 | +                "num_frames": num_frames,
| 623 | +                "guidance_scale": cfg_scale,
| 624 | +                "num_inference_steps": num_inference_steps,
| 625 | +            }
| 626 | +
| 627 | +            # Add custom options from self.options (including guidance_scale_2 if specified)
| 628 | +            kwargs.update(self.options)
| 629 | +
| 630 | +            # Set seed if provided
| 631 | +            if request.seed > 0:
| 632 | +                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
| 633 | +
| 634 | +            # Load optional start/end images for video generation; they are passed to
| 635 | +            # the pipelines below as the `image` argument instead of being forwarded
| 636 | +            # in kwargs, since the pipeline calls take no `start_image`/`end_image` keywords.
| 637 | +            start_image = load_image(request.start_image) if request.start_image else None
| 638 | +            end_image = load_image(request.end_image) if request.end_image else None
| 639 | +
| 640 | +            print(f"Generating video with {kwargs=}", file=sys.stderr)
| 641 | +
| 642 | +            # Generate video frames based on pipeline type
| 643 | +            if self.PipelineType == "WanPipeline":
| 644 | +                # WAN2.2 text-to-video generation
| 645 | +                output = self.pipe(**kwargs)
| 646 | +                frames = output.frames[0]  # WAN2.2 returns frames in this format
| 647 | +            elif self.PipelineType == "WanImageToVideoPipeline":
| 648 | +                # WAN2.2 image-to-video generation
| 649 | +                if start_image is not None:
| 650 | +                    # Resize the input image according to WAN2.2 requirements
| 651 | +                    image = start_image
| 652 | +                    # Use request dimensions or defaults, but respect WAN2.2 constraints
| 653 | +                    request_height = request.height if request.height > 0 else 480
| 654 | +                    request_width = request.width if request.width > 0 else 832
| 655 | +                    max_area = request_height * request_width
| 656 | +                    aspect_ratio = image.height / image.width
| 657 | +                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
| 658 | +                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
| 659 | +                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
| 660 | +                    image = image.resize((width, height))
| 661 | +                    kwargs["image"] = image
| 662 | +                    kwargs["height"] = height
| 663 | +                    kwargs["width"] = width
| 664 | +
| 665 | +                output = self.pipe(**kwargs)
| 666 | +                frames = output.frames[0]
| 667 | +            elif self.img2vid:
| 668 | +                # Generic image-to-video generation
| 669 | +                if start_image is not None:
| 670 | +                    image = start_image
| 671 | +                    image = image.resize((request.width if request.width > 0 else 1024,
| 672 | +                                          request.height if request.height > 0 else 576))
| 673 | +                    kwargs["image"] = image
| 674 | +
| 675 | +                output = self.pipe(**kwargs)
| 676 | +                frames = output.frames[0]
| 677 | +            elif self.txt2vid:
| 678 | +                # Generic text-to-video generation
| 679 | +                output = self.pipe(**kwargs)
| 680 | +                frames = output.frames[0]
| 681 | +            else:
| 682 | +                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
| 683 | +
| 684 | +            # Export video
| 685 | +            export_to_video(frames, request.dst, fps=fps)
| 686 | +
| 687 | +            return backend_pb2.Result(message="Video generated successfully", success=True)
| 688 | +
| 689 | +        except Exception as err:
| 690 | +            print(f"Error generating video: {err}", file=sys.stderr)
| 691 | +            traceback.print_exc()
| 692 | +            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
| 693 | +
578 | 694 | 
579 | 695 | def serve(address):
580 | 696 |     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
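
The WanImageToVideoPipeline branch above resizes the conditioning image so that the pixel count stays close to the requested height × width budget, the aspect ratio is preserved, and both sides are multiples of the pipeline's spatial granularity (`vae_scale_factor_spatial * patch_size`). A small worked example of that calculation, where the `mod_value` of 16 and the sample sizes are assumptions for illustration (the backend derives `mod_value` from the loaded pipeline instead):

```python
# Worked example of the resize rule used in the WanImageToVideoPipeline branch.
mod_value = 16                        # assumed: vae_scale_factor_spatial (8) * patch_size (2)
max_area = 480 * 832                  # pixel budget from the request defaults above
img_width, img_height = 1280, 720     # example input image size
aspect_ratio = img_height / img_width # 0.5625

height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
print(height, width)  # 480 848: area close to the budget, aspect ratio roughly kept, both divisible by 16
```

Rounding each side to a multiple of `mod_value` keeps the latent dimensions integral after VAE downsampling and transformer patchification.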
|
|