11 changes: 7 additions & 4 deletions maestro/trainer/models/florence_2/checkpoints.py
@@ -2,7 +2,6 @@
from enum import Enum
from typing import Optional

-import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoProcessor

@@ -23,7 +22,7 @@ class OptimizationStrategy(Enum):
def load_model(
model_id_or_path: str = DEFAULT_FLORENCE2_MODEL_ID,
revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
-device: str | torch.device = "auto",
+device_map: Optional[str] = "auto",
optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
cache_dir: Optional[str] = None,
) -> tuple[AutoProcessor, AutoModelForCausalLM]:
@@ -32,7 +31,10 @@ def load_model(
Args:
model_id_or_path (str): The identifier or path of the Florence 2 model to load.
revision (str): The specific model revision to use.
-device (torch.device): The device to load the model onto.
+device_map (Optional[str]): Device map for the model:
+  - "auto": Places the model on a single available device (default)
+  - String like "cpu", "cuda:0", or "mps" for a specific device
+  - Note: Florence-2 does not support dict mapping
optimization_strategy (OptimizationStrategy): The optimization strategy to apply to the model.
cache_dir (Optional[str]): Directory to cache the downloaded model files.

@@ -43,7 +45,8 @@
Raises:
ValueError: If the model or processor cannot be loaded.
"""
-device = parse_device_spec(device)
+
+device = parse_device_spec(device_map)
processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True, revision=revision)

if optimization_strategy == OptimizationStrategy.LORA:
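With this change, Florence-2 stays a single-device model: `device_map` accepts only "auto" or a device string, which `parse_device_spec` resolves before loading. A minimal usage sketch (assuming the signature above and maestro's module layout; this snippet is illustrative, not code from the PR):

```python
# Sketch only: assumes the Florence-2 load_model signature introduced above.
from maestro.trainer.models.florence_2.checkpoints import (
    OptimizationStrategy,
    load_model,
)

# Default device_map="auto" resolves to a single available device.
processor, model = load_model()

# Pin the model to one device with a string spec; dict-based
# module-to-device maps are not supported for Florence-2.
processor, model = load_model(
    device_map="cuda:0",
    optimization_strategy=OptimizationStrategy.LORA,
)
```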
18 changes: 9 additions & 9 deletions maestro/trainer/models/paligemma_2/checkpoints.py
@@ -1,13 +1,11 @@
import os
from enum import Enum
-from typing import Optional
+from typing import Optional, Union

import torch
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, PaliGemmaProcessor

-from maestro.trainer.common.utils.device import parse_device_spec
-
DEFAULT_PALIGEMMA2_MODEL_ID = "google/paligemma2-3b-pt-224"
DEFAULT_PALIGEMMA2_MODEL_REVISION = "refs/heads/main"

@@ -24,7 +22,7 @@ class OptimizationStrategy(Enum):
def load_model(
model_id_or_path: str = DEFAULT_PALIGEMMA2_MODEL_ID,
revision: str = DEFAULT_PALIGEMMA2_MODEL_REVISION,
-device: str | torch.device = "auto",
+device_map: Optional[Union[str, dict]] = None,
optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
cache_dir: Optional[str] = None,
) -> tuple[PaliGemmaProcessor, PaliGemmaForConditionalGeneration]:
@@ -33,7 +31,10 @@ def load_model(
Args:
model_id_or_path (str): The identifier or path of the model to load.
revision (str): The specific model revision to use.
-device (torch.device): The device to load the model onto.
+device_map (Optional[Union[str, dict]]): Device map for the model:
+  - None: Uses "auto" for automatic distribution across available devices (default)
+  - String like "cpu", "cuda:0", or "mps" for a specific device
+  - Dict for custom module-to-device mapping (e.g., {"": "cuda:0"})
optimization_strategy (OptimizationStrategy): The optimization strategy to apply to the model.
cache_dir (Optional[str]): Directory to cache the downloaded model files.

@@ -44,7 +45,6 @@
Raises:
ValueError: If the model or processor cannot be loaded.
"""
-device = parse_device_spec(device)
processor = PaliGemmaProcessor.from_pretrained(model_id_or_path, trust_remote_code=True, revision=revision)

if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}:
@@ -66,7 +66,7 @@ def load_model(
pretrained_model_name_or_path=model_id_or_path,
revision=revision,
trust_remote_code=True,
-device_map="auto",
+device_map=device_map if device_map else "auto",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
@@ -78,9 +78,9 @@ def load_model(
pretrained_model_name_or_path=model_id_or_path,
revision=revision,
trust_remote_code=True,
-device_map="auto",
+device_map=device_map if device_map else "auto",
cache_dir=cache_dir,
-).to(device)
+)

if optimization_strategy == OptimizationStrategy.FREEZE:
for param in model.vision_tower.parameters():
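The PaliGemma-2 loader, by contrast, forwards `device_map` straight to transformers' `from_pretrained`, so all three documented forms work. A usage sketch under the same assumptions as above:

```python
# Sketch only: assumes the PaliGemma-2 load_model signature introduced above.
from maestro.trainer.models.paligemma_2.checkpoints import load_model

processor, model = load_model()                     # None falls back to "auto"
processor, model = load_model(device_map="cuda:0")  # single-device string spec
processor, model = load_model(device_map={"": "cuda:0"})  # dict: whole model on cuda:0
```

Note that `device_map if device_map else "auto"` is a truthiness check, so any falsy value, including an empty dict `{}`, is treated as unset and replaced with "auto".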
17 changes: 8 additions & 9 deletions maestro/trainer/models/qwen_2_5_vl/checkpoints.py
@@ -1,13 +1,11 @@
import os
from enum import Enum
-from typing import Optional
+from typing import Optional, Union

import torch
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

-from maestro.trainer.common.utils.device import parse_device_spec
-
DEFAULT_QWEN2_5_VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
DEFAULT_QWEN2_5_VL_MODEL_REVISION = "refs/heads/main"

@@ -23,7 +21,7 @@ class OptimizationStrategy(Enum):
def load_model(
model_id_or_path: str = DEFAULT_QWEN2_5_VL_MODEL_ID,
revision: str = DEFAULT_QWEN2_5_VL_MODEL_REVISION,
-device: str | torch.device = "auto",
+device_map: Optional[Union[str, dict]] = None,
optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
cache_dir: Optional[str] = None,
min_pixels: int = 256 * 28 * 28,
@@ -35,7 +33,10 @@ def load_model(
Args:
model_id_or_path (str): The model name or path.
revision (str): The model revision to load.
-device (str | torch.device): The device to load the model onto.
+device_map (Optional[Union[str, dict]]): Device map for the model:
+  - None: Uses "auto" for automatic distribution across available devices (default)
+  - String like "cpu", "cuda:0", or "mps" for a specific device
+  - Dict for custom module-to-device mapping (e.g., {"": "cuda:0"})
optimization_strategy (OptimizationStrategy): LORA, QLORA, or NONE.
cache_dir (Optional[str]): Directory to cache downloaded model files.
min_pixels (int): Minimum number of pixels allowed in the resized image.
@@ -45,7 +46,6 @@
(Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration):
A tuple containing the loaded processor and model.
"""
-device = parse_device_spec(device)
processor = Qwen2_5_VLProcessor.from_pretrained(
model_id_or_path,
revision=revision,
@@ -82,7 +82,7 @@ def load_model(
model_id_or_path,
revision=revision,
trust_remote_code=True,
-device_map="auto",
+device_map=device_map if device_map else "auto",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
@@ -94,11 +94,10 @@ def load_model(
model_id_or_path,
revision=revision,
trust_remote_code=True,
-device_map="auto",
+device_map=device_map if device_map else "auto",
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)
-model.to(device)

return processor, model

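The Qwen2.5-VL loader follows the same pattern on both the quantized and unquantized paths, and no longer moves the model afterwards with `.to(device)`. A sketch, again assuming this diff's signature:

```python
# Sketch only: assumes the Qwen2.5-VL load_model signature introduced above.
from maestro.trainer.models.qwen_2_5_vl.checkpoints import (
    OptimizationStrategy,
    load_model,
)

# With QLORA, transformers places the quantized weights according to
# device_map; None again falls back to "auto".
processor, model = load_model(
    device_map="cuda:0",
    optimization_strategy=OptimizationStrategy.QLORA,
)
```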