20 changes: 11 additions & 9 deletions backend/backend.proto
@@ -312,15 +312,17 @@ message GenerateImageRequest {

 message GenerateVideoRequest {
   string prompt = 1;
-  string start_image = 2; // Path or base64 encoded image for the start frame
-  string end_image = 3; // Path or base64 encoded image for the end frame
-  int32 width = 4;
-  int32 height = 5;
-  int32 num_frames = 6; // Number of frames to generate
-  int32 fps = 7; // Frames per second
-  int32 seed = 8;
-  float cfg_scale = 9; // Classifier-free guidance scale
-  string dst = 10; // Output path for the generated video
+  string negative_prompt = 2; // Negative prompt for video generation
+  string start_image = 3; // Path or base64 encoded image for the start frame
+  string end_image = 4; // Path or base64 encoded image for the end frame
+  int32 width = 5;
+  int32 height = 6;
+  int32 num_frames = 7; // Number of frames to generate
+  int32 fps = 8; // Frames per second
+  int32 seed = 9;
+  float cfg_scale = 10; // Classifier-free guidance scale
+  int32 step = 11; // Number of inference steps
+  string dst = 12; // Output path for the generated video
 }

 message TTSRequest {
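The renumbered message keeps the original fields in order and adds the two new knobs (`negative_prompt` as field 2 forces the renumbering; `step` lands at field 11). A minimal client-side sketch of how a caller might populate it, assuming Python stubs generated from `backend.proto` and a `Backend` gRPC service exposing `GenerateVideo`; the channel address and stub name are illustrative, and the values mirror the diffusers backend defaults shown further down:

```python
import grpc

import backend_pb2
import backend_pb2_grpc

# Hypothetical caller; field names come straight from the revised proto.
channel = grpc.insecure_channel("localhost:50051")
stub = backend_pb2_grpc.BackendStub(channel)
result = stub.GenerateVideo(backend_pb2.GenerateVideoRequest(
    prompt="a red panda walking through a bamboo forest",
    negative_prompt="blurry, low quality",  # new field 2
    width=1280,
    height=720,
    num_frames=81,
    fps=16,
    seed=42,
    cfg_scale=4.0,
    step=40,           # new field 11: number of inference steps
    dst="/tmp/out.mp4",
))
print(result.success, result.message)
```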
132 changes: 121 additions & 11 deletions backend/python/diffusers/backend.py
@@ -18,7 +18,7 @@
 import grpc

 from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video
@@ -72,13 +72,6 @@ def is_float(s):
     except ValueError:
         return False

-def is_int(s):
-    try:
-        int(s)
-        return True
-    except ValueError:
-        return False
-
 # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
 # Credits to https://github.com/neggles
 # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@@ -184,9 +177,10 @@ def LoadModel(self, request, context):
                 key, value = opt.split(":")
                 # if value is a number, convert it to the appropriate type
                 if is_float(value):
-                    value = float(value)
-                elif is_int(value):
-                    value = int(value)
+                    if float(value).is_integer():
+                        value = int(value)
+                    else:
+                        value = float(value)
                 self.options[key] = value

# From options, extract if present "torch_dtype" and set it to the appropriate type
@@ -334,6 +328,32 @@ def LoadModel(self, request, context):
                                                  torch_dtype=torch.bfloat16)
                 self.pipe.vae.to(torch.bfloat16)
                 self.pipe.text_encoder.to(torch.bfloat16)
+            elif request.PipelineType == "WanPipeline":
+                # WAN2.2 pipeline requires special VAE handling
+                vae = AutoencoderKLWan.from_pretrained(
+                    request.Model,
+                    subfolder="vae",
+                    torch_dtype=torch.float32
+                )
+                self.pipe = WanPipeline.from_pretrained(
+                    request.Model,
+                    vae=vae,
+                    torch_dtype=torchType
+                )
+                self.txt2vid = True  # WAN2.2 is a text-to-video pipeline
+            elif request.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video pipeline
+                vae = AutoencoderKLWan.from_pretrained(
+                    request.Model,
+                    subfolder="vae",
+                    torch_dtype=torch.float32
+                )
+                self.pipe = WanImageToVideoPipeline.from_pretrained(
+                    request.Model,
+                    vae=vae,
+                    torch_dtype=torchType
+                )
+                self.img2vid = True  # WAN2.2 image-to-video pipeline

             if CLIPSKIP and request.CLIPSkip != 0:
                 self.clip_skip = request.CLIPSkip
@@ -575,6 +595,96 @@ def GenerateImage(self, request, context):

         return backend_pb2.Result(message="Media generated", success=True)

+    def GenerateVideo(self, request, context):
+        try:
+            prompt = request.prompt
+            if not prompt:
+                return backend_pb2.Result(success=False, message="No prompt provided for video generation")
+
+            # Set default values from request or use defaults
+            num_frames = request.num_frames if request.num_frames > 0 else 81
+            fps = request.fps if request.fps > 0 else 16
+            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
+            num_inference_steps = request.step if request.step > 0 else 40
+
+            # Prepare generation parameters
+            kwargs = {
+                "prompt": prompt,
+                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
+                "height": request.height if request.height > 0 else 720,
+                "width": request.width if request.width > 0 else 1280,
+                "num_frames": num_frames,
+                "guidance_scale": cfg_scale,
+                "num_inference_steps": num_inference_steps,
+            }
+
+            # Add custom options from self.options (including guidance_scale_2 if specified)
+            kwargs.update(self.options)
+
+            # Set seed if provided
+            if request.seed > 0:
+                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
+
+            # Handle start and end images for video generation
+            if request.start_image:
+                kwargs["start_image"] = load_image(request.start_image)
+            if request.end_image:
+                kwargs["end_image"] = load_image(request.end_image)
+
+            print(f"Generating video with {kwargs=}", file=sys.stderr)
+
+            # Generate video frames based on pipeline type
+            if self.PipelineType == "WanPipeline":
+                # WAN2.2 text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]  # WAN2.2 returns frames in this format
+            elif self.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video generation
+                if request.start_image:
+                    # Load and resize the input image according to WAN2.2 requirements
+                    image = load_image(request.start_image)
+                    # Use request dimensions or defaults, but respect WAN2.2 constraints
+                    request_height = request.height if request.height > 0 else 480
+                    request_width = request.width if request.width > 0 else 832
+                    max_area = request_height * request_width
+                    aspect_ratio = image.height / image.width
+                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
+                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    image = image.resize((width, height))
+                    kwargs["image"] = image
+                    kwargs["height"] = height
+                    kwargs["width"] = width
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.img2vid:
+                # Generic image-to-video generation
+                if request.start_image:
+                    image = load_image(request.start_image)
+                    image = image.resize((request.width if request.width > 0 else 1024,
+                                          request.height if request.height > 0 else 576))
+                    kwargs["image"] = image
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.txt2vid:
+                # Generic text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            else:
+                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
+
+            # Export video
+            export_to_video(frames, request.dst, fps=fps)
+
+            return backend_pb2.Result(message="Video generated successfully", success=True)
+
+        except Exception as err:
+            print(f"Error generating video: {err}", file=sys.stderr)
+            traceback.print_exc()
+            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
+

 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
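The aspect-ratio fit in the `WanImageToVideoPipeline` branch is the one non-obvious piece of arithmetic above: it treats the requested width × height as a pixel budget, rescales the start image to that budget at its own aspect ratio, and snaps both sides to a multiple of `mod_value`. A self-contained sketch of the same computation; the `mod_value` of 16 is an assumption here (spatial VAE factor 8 × transformer patch size 2), whereas the backend reads both factors off the loaded pipeline:

```python
# Fit an image's aspect ratio into a pixel budget, snapping each side to a
# multiple of mod_value, mirroring GenerateVideo's WAN2.2 resize above.
def wan_fit(img_w: int, img_h: int, req_w: int = 832, req_h: int = 480,
            mod_value: int = 16) -> tuple[int, int]:
    max_area = req_w * req_h          # pixel budget taken from the request
    aspect = img_h / img_w            # preserve the input image's shape
    height = round((max_area * aspect) ** 0.5 / mod_value) * mod_value
    width = round((max_area / aspect) ** 0.5 / mod_value) * mod_value
    return width, height

# A 1920x1080 start image against the default 832x480 budget:
print(wan_fit(1920, 1080))  # (848, 480): ~16:9 preserved, both sides % 16 == 0
```

Note that the result can slightly overshoot the nominal budget (848 × 480 > 832 × 480) because each side rounds to the nearest multiple rather than truncating.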
3 changes: 2 additions & 1 deletion backend/python/diffusers/requirements-cpu.txt
@@ -8,4 +8,5 @@ compel
 peft
 sentencepiece
 torch==2.7.1
-optimum-quanto
+optimum-quanto
+ftfy
7 changes: 4 additions & 3 deletions backend/python/diffusers/requirements-cublas11.txt
@@ -1,11 +1,12 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.7.1+cu118
+torchvision==0.22.1+cu118
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
-torchvision==0.22.1
 accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
-torch==2.7.1
+optimum-quanto
+ftfy
7 changes: 4 additions & 3 deletions backend/python/diffusers/requirements-cublas12.txt
@@ -1,10 +1,11 @@
+torch==2.7.1
+torchvision==0.22.1
 --extra-index-url https://download.pytorch.org/whl/cu121
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
-torchvision
 accelerate
 compel
 peft
 sentencepiece
 optimum-quanto
-torch
+ftfy
3 changes: 2 additions & 1 deletion backend/python/diffusers/requirements-hipblas.txt
@@ -8,4 +8,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+optimum-quanto
+ftfy
3 changes: 2 additions & 1 deletion backend/python/diffusers/requirements-intel.txt
@@ -12,4 +12,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+optimum-quanto
+ftfy
3 changes: 2 additions & 1 deletion backend/python/diffusers/requirements-l4t.txt
@@ -8,4 +8,5 @@ peft
 optimum-quanto
 numpy<2
 sentencepiece
-torchvision
+torchvision
+ftfy
3 changes: 2 additions & 1 deletion backend/python/diffusers/requirements-mps.txt
@@ -7,4 +7,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+optimum-quanto
+ftfy
15 changes: 4 additions & 11 deletions backend/python/mlx-audio/backend.py
@@ -40,14 +40,6 @@ def _is_float(self, s):
         except ValueError:
             return False

-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.
@@ -89,9 +81,10 @@ async def LoadModel(self, request, context):

             # Convert numeric values to appropriate types
             if self._is_float(value):
-                value = float(value)
-            elif self._is_int(value):
-                value = int(value)
+                if float(value).is_integer():
+                    value = int(value)
+                else:
+                    value = float(value)
             elif value.lower() in ["true", "false"]:
                 value = value.lower() == "true"

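The same option-coercion rewrite lands in all three MLX backends (mlx-audio above, mlx-vlm and mlx below) and, modulo helper naming, in the diffusers backend: a single `float()` parse decides whether a value is numeric, and integral floats are narrowed to `int`, which is what lets the separate `_is_int()` helper disappear. A condensed, self-contained sketch of the resulting behavior (not the literal backend code):

```python
def coerce(value: str):
    """Mirror the rewritten option parsing: str -> int/float when numeric."""
    try:
        f = float(value)
    except ValueError:
        return value                  # non-numeric: keep the raw string
    return int(f) if f.is_integer() else f

print(coerce("42"))   # 42 (int)
print(coerce("4.5"))  # 4.5 (float)
print(coerce("1e3"))  # 1000 -- integral floats now narrow to int
print(coerce("wan"))  # 'wan'
```

One behavioral nuance of the new path: strings like "1e3" or "4.0", which the old `_is_int()` route would have left as floats, now come out as ints.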
15 changes: 4 additions & 11 deletions backend/python/mlx-vlm/backend.py
@@ -40,14 +40,6 @@ def _is_float(self, s):
         except ValueError:
             return False

-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.
@@ -89,9 +81,10 @@ async def LoadModel(self, request, context):

             # Convert numeric values to appropriate types
             if self._is_float(value):
-                value = float(value)
-            elif self._is_int(value):
-                value = int(value)
+                if float(value).is_integer():
+                    value = int(value)
+                else:
+                    value = float(value)
             elif value.lower() in ["true", "false"]:
                 value = value.lower() == "true"

15 changes: 4 additions & 11 deletions backend/python/mlx/backend.py
@@ -38,14 +38,6 @@ def _is_float(self, s):
         except ValueError:
             return False

-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.
@@ -87,9 +79,10 @@ async def LoadModel(self, request, context):

             # Convert numeric values to appropriate types
             if self._is_float(value):
-                value = float(value)
-            elif self._is_int(value):
-                value = int(value)
+                if float(value).is_integer():
+                    value = int(value)
+                else:
+                    value = float(value)
             elif value.lower() in ["true", "false"]:
                 value = value.lower() == "true"

20 changes: 13 additions & 7 deletions core/backend/video.go
@@ -7,7 +7,7 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)

func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {

opts := ModelOptions(modelConfig, appConfig)
inferenceModel, err := loader.Load(
@@ -22,12 +22,18 @@ func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) {
 		_, err := inferenceModel.GenerateVideo(
 			appConfig.Context,
 			&proto.GenerateVideoRequest{
-				Height:     height,
-				Width:      width,
-				Prompt:     prompt,
-				StartImage: startImage,
-				EndImage:   endImage,
-				Dst:        dst,
+				Height:         height,
+				Width:          width,
+				Prompt:         prompt,
+				NegativePrompt: negativePrompt,
+				StartImage:     startImage,
+				EndImage:       endImage,
+				NumFrames:      numFrames,
+				Fps:            fps,
+				Seed:           seed,
+				CfgScale:       cfgScale,
+				Step:           step,
+				Dst:            dst,
 			})
 		return err
 	}