diff --git a/maestro/trainer/models/florence_2/checkpoints.py b/maestro/trainer/models/florence_2/checkpoints.py
index 0c93ba1c..39e19d46 100644
--- a/maestro/trainer/models/florence_2/checkpoints.py
+++ b/maestro/trainer/models/florence_2/checkpoints.py
@@ -2,7 +2,6 @@
 from enum import Enum
 from typing import Optional
 
-import torch
 from peft import LoraConfig, get_peft_model
 from transformers import AutoModelForCausalLM, AutoProcessor
 
@@ -23,7 +22,7 @@ class OptimizationStrategy(Enum):
 def load_model(
     model_id_or_path: str = DEFAULT_FLORENCE2_MODEL_ID,
     revision: str = DEFAULT_FLORENCE2_MODEL_REVISION,
-    device: str | torch.device = "auto",
+    device_map: Optional[str] = "auto",
     optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
     cache_dir: Optional[str] = None,
 ) -> tuple[AutoProcessor, AutoModelForCausalLM]:
@@ -32,7 +31,10 @@ def load_model(
     Args:
         model_id_or_path (str): The identifier or path of the Florence 2 model to load.
         revision (str): The specific model revision to use.
-        device (torch.device): The device to load the model onto.
+        device_map (Optional[str]): Device map for the model:
+            - "auto": places the model on a single available device (default)
+            - A string like "cpu", "cuda:0", or "mps" for a specific device
+            - Note: Florence-2 does not support dict mappings
         optimization_strategy (OptimizationStrategy): The optimization strategy to apply to the model.
         cache_dir (Optional[str]): Directory to cache the downloaded model files.
 
@@ -43,7 +45,8 @@ def load_model(
     Raises:
         ValueError: If the model or processor cannot be loaded.
     """
-    device = parse_device_spec(device)
+
+    device = parse_device_spec(device_map)
     processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True, revision=revision)
 
     if optimization_strategy == OptimizationStrategy.LORA:
diff --git a/maestro/trainer/models/paligemma_2/checkpoints.py b/maestro/trainer/models/paligemma_2/checkpoints.py
index e7582c2d..a29a64da 100644
--- a/maestro/trainer/models/paligemma_2/checkpoints.py
+++ b/maestro/trainer/models/paligemma_2/checkpoints.py
@@ -1,13 +1,11 @@
 import os
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from peft import LoraConfig, get_peft_model
 from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, PaliGemmaProcessor
 
-from maestro.trainer.common.utils.device import parse_device_spec
-
 DEFAULT_PALIGEMMA2_MODEL_ID = "google/paligemma2-3b-pt-224"
 DEFAULT_PALIGEMMA2_MODEL_REVISION = "refs/heads/main"
 
@@ -24,7 +22,7 @@ class OptimizationStrategy(Enum):
 def load_model(
     model_id_or_path: str = DEFAULT_PALIGEMMA2_MODEL_ID,
     revision: str = DEFAULT_PALIGEMMA2_MODEL_REVISION,
-    device: str | torch.device = "auto",
+    device_map: Optional[Union[str, dict]] = None,
     optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
     cache_dir: Optional[str] = None,
 ) -> tuple[PaliGemmaProcessor, PaliGemmaForConditionalGeneration]:
@@ -33,7 +31,10 @@ def load_model(
     Args:
         model_id_or_path (str): The identifier or path of the model to load.
         revision (str): The specific model revision to use.
-        device (torch.device): The device to load the model onto.
+        device_map (Optional[Union[str, dict]]): Device map for the model:
+            - None: uses "auto" for automatic distribution across available devices (default)
+            - A string like "cpu", "cuda:0", or "mps" for a specific device
+            - A dict for custom module-to-device mapping (e.g., {"": "cuda:0"})
         optimization_strategy (OptimizationStrategy): The optimization strategy to apply to the model.
         cache_dir (Optional[str]): Directory to cache the downloaded model files.
 
@@ -44,7 +45,6 @@ def load_model(
     Raises:
         ValueError: If the model or processor cannot be loaded.
     """
-    device = parse_device_spec(device)
     processor = PaliGemmaProcessor.from_pretrained(model_id_or_path, trust_remote_code=True, revision=revision)
 
     if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}:
@@ -66,7 +66,7 @@ def load_model(
             pretrained_model_name_or_path=model_id_or_path,
             revision=revision,
             trust_remote_code=True,
-            device_map="auto",
+            device_map=device_map if device_map else "auto",
             quantization_config=bnb_config,
             torch_dtype=torch.bfloat16,
             cache_dir=cache_dir,
@@ -78,9 +78,9 @@ def load_model(
             pretrained_model_name_or_path=model_id_or_path,
             revision=revision,
             trust_remote_code=True,
-            device_map="auto",
+            device_map=device_map if device_map else "auto",
             cache_dir=cache_dir,
-        ).to(device)
+        )
 
     if optimization_strategy == OptimizationStrategy.FREEZE:
         for param in model.vision_tower.parameters():
diff --git a/maestro/trainer/models/qwen_2_5_vl/checkpoints.py b/maestro/trainer/models/qwen_2_5_vl/checkpoints.py
index 359d2a25..d40818bf 100644
--- a/maestro/trainer/models/qwen_2_5_vl/checkpoints.py
+++ b/maestro/trainer/models/qwen_2_5_vl/checkpoints.py
@@ -1,13 +1,11 @@
 import os
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from peft import LoraConfig, get_peft_model
 from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
 
-from maestro.trainer.common.utils.device import parse_device_spec
-
 DEFAULT_QWEN2_5_VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
 DEFAULT_QWEN2_5_VL_MODEL_REVISION = "refs/heads/main"
 
@@ -23,7 +21,7 @@ class OptimizationStrategy(Enum):
 def load_model(
     model_id_or_path: str = DEFAULT_QWEN2_5_VL_MODEL_ID,
     revision: str = DEFAULT_QWEN2_5_VL_MODEL_REVISION,
-    device: str | torch.device = "auto",
+    device_map: Optional[Union[str, dict]] = None,
     optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
     cache_dir: Optional[str] = None,
     min_pixels: int = 256 * 28 * 28,
@@ -35,7 +33,10 @@ def load_model(
     Args:
         model_id_or_path (str): The model name or path.
        revision (str): The model revision to load.
-        device (str | torch.device): The device to load the model onto.
+        device_map (Optional[Union[str, dict]]): Device map for the model:
+            - None: uses "auto" for automatic distribution across available devices (default)
+            - A string like "cpu", "cuda:0", or "mps" for a specific device
+            - A dict for custom module-to-device mapping (e.g., {"": "cuda:0"})
         optimization_strategy (OptimizationStrategy): LORA, QLORA, or NONE.
         cache_dir (Optional[str]): Directory to cache downloaded model files.
         min_pixels (int): Minimum number of pixels allowed in the resized image.
@@ -45,7 +46,6 @@ def load_model(
         (Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration): A tuple containing the loaded processor
             and model.
     """
-    device = parse_device_spec(device)
     processor = Qwen2_5_VLProcessor.from_pretrained(
         model_id_or_path,
         revision=revision,
@@ -82,7 +82,7 @@ def load_model(
         model_id_or_path,
         revision=revision,
         trust_remote_code=True,
-        device_map="auto",
+        device_map=device_map if device_map else "auto",
         quantization_config=bnb_config,
         torch_dtype=torch.bfloat16,
         cache_dir=cache_dir,
@@ -94,11 +94,10 @@ def load_model(
         model_id_or_path,
         revision=revision,
         trust_remote_code=True,
-        device_map="auto",
+        device_map=device_map if device_map else "auto",
         torch_dtype=torch.bfloat16,
         cache_dir=cache_dir,
     )
-    model.to(device)
 
     return processor, model