diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f08a83c1..f47f6238 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -82,14 +82,6 @@ jobs:
       - name: Check out repo
         uses: actions/checkout@v4
 
-      - name: Check out fairchem repository
-        if: ${{ matrix.model.name == 'fairchem' }}
-        uses: actions/checkout@v4
-        with:
-          repository: FAIR-Chem/fairchem
-          path: fairchem-repo
-          ref: fairchem_core-1.10.0
-
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
@@ -98,24 +90,14 @@ jobs:
       - name: Set up uv
         uses: astral-sh/setup-uv@v6
 
-      - name: Install fairchem repository and dependencies
+      - name: Install fairchem and dependencies
         if: ${{ matrix.model.name == 'fairchem' }}
         env:
           HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
         run: |
-          uv pip install huggingface_hub --system
-          if [ -n "$HF_TOKEN" ]; then
-            huggingface-cli login --token "$HF_TOKEN"
-          else
-            echo "HF_TOKEN is not set. Skipping login."
-          fi
-          if [ -f fairchem-repo/packages/requirements.txt ]; then
-            uv pip install -r fairchem-repo/packages/requirements.txt --system
-          fi
-          if [ -f fairchem-repo/packages/requirements-optional.txt ]; then
-            uv pip install -r fairchem-repo/packages/requirements-optional.txt --system
-          fi
-          uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system
+          uv pip install "torch>=2.6" --index-url https://download.pytorch.org/whl/cpu --system
+          uv pip install "fairchem-core>=2.2.0" --system
+          uv pip install "huggingface_hub[cli]" --system
           uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system
 
       - name: Install torch_sim with model dependencies
@@ -124,7 +106,12 @@ jobs:
           uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system
 
       - name: Run ${{ matrix.model.test_path }} tests
+        env:
+          HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
         run: |
+          if [ "${{ matrix.model.name }}" == "fairchem" ]; then
+            huggingface-cli login --token "$HF_TOKEN"
+          fi
           pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }}
 
       - name: Upload coverage to Codecov
@@ -168,4 +155,11 @@ jobs:
         uses: astral-sh/setup-uv@v6
 
       - name: Run example
-        run: uv run --with . ${{ matrix.example }}
+        env:
+          HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
+        run: |
+          if [[ "${{ matrix.example }}" == *"fairchem"* ]]; then
+            uv pip install "huggingface_hub[cli]" --system
+            huggingface-cli login --token "$HF_TOKEN"
+          fi
+          uv run --with . ${{ matrix.example }}
diff --git a/examples/scripts/1_Introduction/1.3_Fairchem.py b/examples/scripts/1_Introduction/1.3_fairchem.py
similarity index 62%
rename from examples/scripts/1_Introduction/1.3_Fairchem.py
rename to examples/scripts/1_Introduction/1.3_fairchem.py
index b6f8dd5b..6fc9a42d 100644
--- a/examples/scripts/1_Introduction/1.3_Fairchem.py
+++ b/examples/scripts/1_Introduction/1.3_fairchem.py
@@ -3,12 +3,10 @@
 # /// script
 # dependencies = [
-#     "fairchem-core==1.10.0",
+#     "fairchem-core>=2.2.0",
 # ]
 # ///
 
-import sys
-
 import torch
 from ase.build import bulk
 
@@ -19,38 +17,35 @@
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float32
 
-try:
-    from fairchem.core.models.model_registry import model_name_to_local_file
-except ImportError:
-    print("Skipping example due to missing fairchem dependency")
-    sys.exit(0)
-
-MODEL_PATH = model_name_to_local_file(
-    "EquiformerV2-31M-S2EF-OC20-All+MD", local_cache="."
-)
+# UMA = Universal Models for Atoms
+MODEL_NAME = "uma-s-1"
 
 # Create diamond cubic Silicon
 si_dc = bulk("Si", "diamond", a=5.43).repeat((2, 2, 2))
 atomic_numbers = si_dc.get_atomic_numbers()
 
 model = FairChemModel(
-    model=MODEL_PATH,
+    model=None,
+    model_name=MODEL_NAME,
+    task_name="omat",  # Open Materials task for crystalline systems
     cpu=False,
-    seed=0,
 )
 atoms_list = [si_dc, si_dc]
-state = ts.io.atoms_to_state(atoms_list)
+state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
 results = model(state)
 
 print(results["energy"].shape)
 print(results["forces"].shape)
-print(results["stress"].shape)
+if (stress := results.get("stress")) is not None:
+    print(stress.shape)
 
 print(f"Energy: {results['energy']}")
 print(f"Forces: {results['forces']}")
-print(f"Stress: {results['stress']}")
+if (stress := results.get("stress")) is not None:
+    print(f"{stress=}")
 
 # Check if the energy, forces, and stress are the same for the Si system across the batch
 print(torch.max(torch.abs(results["energy"][0] - results["energy"][1])))
 print(torch.max(torch.abs(results["forces"][0] - results["forces"][1])))
-print(torch.max(torch.abs(results["stress"][0] - results["stress"][1])))
+if (stress := results.get("stress")) is not None:
+    print(torch.max(torch.abs(stress[0] - stress[1])))
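Note for running the renamed example locally: the UMA checkpoint is gated on Hugging Face, which is why the CI above logs in with `huggingface-cli` before executing fairchem jobs. A minimal sketch of the equivalent local setup (reading the token from an `HF_TOKEN` environment variable mirrors the workflow's secret and is illustrative, not part of the example itself):

```python
# Hypothetical local setup for the gated "uma-s-1" weights: authenticate once
# with a Hugging Face token before running 1.3_fairchem.py.
import os

from huggingface_hub import login

if token := os.environ.get("HF_TOKEN"):
    login(token=token)  # same effect as `huggingface-cli login --token "$HF_TOKEN"`
else:
    print("HF_TOKEN is not set; downloads of gated UMA models will fail.")
```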
diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py
index b04ad9c8..f73b6a97 100644
--- a/tests/models/test_fairchem.py
+++ b/tests/models/test_fairchem.py
@@ -1,97 +1,224 @@
-import os
-
 import pytest
 import torch
 
-from tests.models.conftest import (
-    consistency_test_simstate_fixtures,
-    make_model_calculator_consistency_test,
-    make_validate_model_outputs_test,
-)
+from tests.models.conftest import make_validate_model_outputs_test
 
 try:
-    from fairchem.core import OCPCalculator
-    from fairchem.core.models.model_registry import model_name_to_local_file
+    from collections.abc import Callable
+
+    from ase.build import bulk, fcc100, molecule
     from huggingface_hub.utils._auth import get_token
 
+    import torch_sim as ts
     from torch_sim.models.fairchem import FairChemModel
 except ImportError:
     pytest.skip("FairChem not installed", allow_module_level=True)
 
 
-@pytest.fixture(scope="session")
-def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str:
-    tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
-    model_name = "EquiformerV2-31M-S2EF-OC20-All+MD"
-    return model_name_to_local_file(model_name, local_cache=str(tmp_path))
-
-
 @pytest.fixture
-def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
+def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel:
+    """UMA model for periodic boundary condition systems."""
     cpu = device.type == "cpu"
-    return FairChemModel(model=model_path_oc20, cpu=cpu, seed=0, pbc=True)
+    return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
 
 
-@pytest.fixture
-def eqv2_oc20_model_non_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
-    cpu = device.type == "cpu"
-    return FairChemModel(model=model_path_oc20, cpu=cpu, seed=0, pbc=False)
+# Removed calculator consistency tests since we're using predictor interface only
 
 
-if get_token():
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"])
+def test_task_initialization(task_name: str) -> None:
+    """Test that different UMA task names work correctly."""
+    model = FairChemModel(model=None, model_name="uma-s-1", task_name=task_name, cpu=True)
+    assert model.task_name.value == task_name
+    assert hasattr(model, "predictor")
 
 
-    @pytest.fixture(scope="session")
-    def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str:
-        tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
-        model_name = "EquiformerV2-31M-OMAT24-MP-sAlex"
-        return model_name_to_local_file(model_name, local_cache=str(tmp_path))
 
-    @pytest.fixture
-    def eqv2_omat24_model_pbc(
-        model_path_omat24: str, device: torch.device
-    ) -> FairChemModel:
-        cpu = device.type == "cpu"
-        return FairChemModel(model=model_path_omat24, cpu=cpu, seed=0, pbc=True)
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+@pytest.mark.parametrize(
+    ("task_name", "systems_func"),
+    [
+        (
+            "omat",
+            lambda: [
+                bulk("Si", "diamond", a=5.43),
+                bulk("Al", "fcc", a=4.05),
+                bulk("Fe", "bcc", a=2.87),
+                bulk("Cu", "fcc", a=3.61),
+            ],
+        ),
+        (
+            "omol",
+            lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")],
+        ),
+    ],
+)
+def test_homogeneous_batching(
+    task_name: str, systems_func: Callable, device: torch.device, dtype: torch.dtype
+) -> None:
+    """Test batching multiple systems with the same task."""
+    systems = systems_func()
+
+    # Add molecular properties for molecules
+    if task_name == "omol":
+        for mol in systems:
+            mol.info.update({"charge": 0, "spin": 1})
+
+    model = FairChemModel(
+        model=None, model_name="uma-s-1", task_name=task_name, cpu=device.type == "cpu"
+    )
+    state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
+    results = model(state)
+
+    # Check batch dimensions
+    assert results["energy"].shape == (4,)
+    assert results["forces"].shape[0] == sum(len(s) for s in systems)
+    assert results["forces"].shape[1] == 3
+
+    # Check that different systems have different energies
+    energies = results["energy"]
+    unique_energies = torch.unique(energies, dim=0)
+    assert len(unique_energies) > 1, "Different systems should have different energies"
+
+
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+def test_heterogeneous_tasks(device: torch.device, dtype: torch.dtype) -> None:
+    """Test different task types work with appropriate systems."""
+    # Test molecule, material, and catalysis systems separately
+    test_cases = [
+        ("omol", [molecule("H2O")]),
+        ("omat", [bulk("Pt", cubic=True)]),
+        ("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]),
+    ]
+
+    for task_name, systems in test_cases:
+        if task_name == "omol":
+            systems[0].info.update({"charge": 0, "spin": 1})
+
+        model = FairChemModel(
+            model=None,
+            model_name="uma-s-1",
+            task_name=task_name,
+            cpu=device.type == "cpu",
+        )
+        state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
+        results = model(state)
+
+        assert "energy" in results
+        assert "forces" in results
+        assert results["energy"].shape[0] == 1
+        assert results["forces"].dim() == 2
+        assert results["forces"].shape[1] == 3
+
+
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+@pytest.mark.parametrize(
+    ("systems_func", "expected_count"),
+    [
+        (lambda: [bulk("Si", "diamond", a=5.43)], 1),  # Single system
+        (
+            lambda: [
+                bulk("H", "bcc", a=2.0),
+                bulk("Li", "bcc", a=3.0),
+                bulk("Si", "diamond", a=5.43),
+                bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)),
+            ],
+            4,
+        ),  # Mixed sizes
+        (
+            lambda: [
+                bulk(element, "fcc", a=4.0)
+                for element in ["Al", "Cu", "Ni", "Pd", "Pt"] * 3
+            ],
+            15,
+        ),  # Large batch
+    ],
+)
+def test_batch_size_variations(
+    systems_func: Callable, expected_count: int, device: torch.device, dtype: torch.dtype
+) -> None:
+    """Test batching with different numbers and sizes of systems."""
+    systems = systems_func()
+
+    model = FairChemModel(
+        model=None, model_name="uma-s-1", task_name="omat", cpu=device.type == "cpu"
+    )
+    state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
+    results = model(state)
+
+    assert results["energy"].shape == (expected_count,)
+    assert results["forces"].shape[0] == sum(len(s) for s in systems)
+    assert results["forces"].shape[1] == 3
+    assert torch.isfinite(results["energy"]).all()
+    assert torch.isfinite(results["forces"]).all()
+
+
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+@pytest.mark.parametrize("compute_stress", [True, False])
+def test_stress_computation(
+    *, compute_stress: bool, device: torch.device, dtype: torch.dtype
+) -> None:
+    """Test stress tensor computation."""
+    systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]
+
+    model = FairChemModel(
+        model=None,
+        model_name="uma-s-1",
+        task_name="omat",
+        cpu=device.type == "cpu",
+        compute_stress=compute_stress,
+    )
+    state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
+    results = model(state)
+
+    if compute_stress:
+        assert "stress" in results
+        assert results["stress"].shape == (2, 3, 3)
+        assert torch.isfinite(results["stress"]).all()
+    else:
+        assert "stress" not in results
+
+
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
+)
+def test_device_consistency(dtype: torch.dtype) -> None:
+    """Test device consistency between model and data."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    cpu = device.type == "cpu"
+
+    model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
+    system = bulk("Si", "diamond", a=5.43)
+    state = ts.io.atoms_to_state([system], device=device, dtype=dtype)
+    results = model(state)
+    # Compare device types: torch.device("cuda") != torch.device("cuda:0")
+    assert results["energy"].device.type == device.type
+    assert results["forces"].device.type == device.type
 
-@pytest.fixture
-def ocp_calculator(model_path_oc20: str) -> OCPCalculator:
-    return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0)
-
-
-test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test(
-    test_name="fairchem_ocp",
-    model_fixture_name="eqv2_oc20_model_pbc",
-    calculator_fixture_name="ocp_calculator",
-    sim_state_names=consistency_test_simstate_fixtures[:-1],
-    energy_rtol=5e-4,  # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
-    energy_atol=5e-4,
-    force_rtol=5e-4,
-    force_atol=5e-4,
-    stress_rtol=5e-4,
-    stress_atol=5e-4,
-)
-test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test(
-    test_name="fairchem_non_pbc_benzene",
-    model_fixture_name="eqv2_oc20_model_non_pbc",
-    calculator_fixture_name="ocp_calculator",
-    sim_state_names=["benzene_sim_state"],
-    energy_rtol=5e-4,  # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
-    energy_atol=5e-4,
-    force_rtol=5e-4,
-    force_atol=5e-4,
-    stress_rtol=5e-4,
-    stress_atol=5e-4,
+@pytest.mark.skipif(
+    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
 )
+def test_empty_batch_error() -> None:
+    """Test that empty batches raise appropriate errors."""
+    model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=True)
+    with pytest.raises((ValueError, RuntimeError, IndexError)):
+        model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32))
 
 
-# Skip this test due to issues with how the older models
-# handled supercells (see related issue here: https://github.com/facebookresearch/fairchem/issues/428)
-
-test_fairchem_ocp_model_outputs = pytest.mark.skipif(
-    os.environ.get("HF_TOKEN") is None,
-    reason="Issues in graph construction of older models",
-)(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc"))
+test_fairchem_uma_model_outputs = pytest.mark.skipif(
+    get_token() is None,
+    reason="Requires HuggingFace authentication for UMA model access",
+)(make_validate_model_outputs_test(model_fixture_name="eqv2_uma_model_pbc"))
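The batching assertions above all hinge on the layout `ts.io.atoms_to_state` produces: atoms from every system are concatenated along the first axis, with a per-atom system index mapping them back. A small sketch, assuming only the `positions` and `system_idx` attributes that the wrapper below actually reads:

```python
# Sketch of the batched SimState layout the shape assertions rely on.
import torch
from ase.build import bulk

import torch_sim as ts

systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]
state = ts.io.atoms_to_state(systems, device=torch.device("cpu"), dtype=torch.float32)

print(state.positions.shape)  # [n_atoms_total, 3]: atoms concatenated across systems
print(state.system_idx)  # one entry per atom, e.g. tensor([0, 0, 1]) for 2+1 atoms
```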
diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py
index 77b1b0ba..3aae3f3a 100644
--- a/torch_sim/models/fairchem.py
+++ b/torch_sim/models/fairchem.py
@@ -1,28 +1,16 @@
-"""Wrapper for FairChem ecosystem models in TorchSim.
+"""FairChem model wrapper for torch_sim.
 
-This module provides a TorchSim wrapper of the FairChem models for computing
-energies, forces, and stresses of atomistic systems. It serves as a wrapper around
-the FairChem library, integrating it with the torch_sim framework to enable seamless
-simulation of atomistic systems with machine learning potentials.
+Provides a TorchSim-compatible interface to FairChem models for computing
+energies, forces, and stresses of atomistic systems.
 
-The FairChemModel class adapts FairChem models to the ModelInterface protocol,
-allowing them to be used within the broader torch_sim simulation framework.
-
-Notes:
-    This implementation requires FairChem to be installed and accessible.
-    It supports various model configurations through configuration files or
-    pretrained model checkpoints.
+Requires fairchem-core to be installed.
 """
 
-# ruff: noqa: T201
-
 from __future__ import annotations
 
-import copy
 import traceback
 import typing
 import warnings
-from types import MappingProxyType
 from typing import Any
 
 import torch
@@ -32,15 +20,10 @@
 
 try:
-    from fairchem.core.common.registry import registry
-    from fairchem.core.common.utils import (
-        load_config,
-        setup_imports,
-        setup_logging,
-        update_config,
-    )
-    from fairchem.core.models.model_registry import model_name_to_local_file
-    from torch_geometric.data import Batch, Data
+    from fairchem.core import pretrained_mlip
+    from fairchem.core.calculate.ase_calculator import UMATask
+    from fairchem.core.common.utils import setup_imports, setup_logging
+    from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch
 
 except ImportError as exc:
     warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2)
@@ -63,92 +46,56 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
 
     from torch_sim.typing import StateDict
 
-_DTYPE_DICT = {
-    torch.float16: "float16",
-    torch.float32: "float32",
-    torch.float64: "float64",
-}
-
 
 class FairChemModel(ModelInterface):
-    """Computes atomistic energies, forces and stresses using a FairChem model.
+    """FairChem model wrapper for computing atomistic properties.
 
-    This class wraps a FairChem model to compute energies, forces, and stresses for
-    atomistic systems. It handles model initialization, checkpoint loading, and
-    provides a forward pass that accepts a SimState object and returns model
-    predictions.
+    Wraps FairChem models to compute energies, forces, and stresses. Can be
+    initialized with a model checkpoint path or pretrained model name.
 
-    The model can be initialized either with a configuration file or a pretrained
-    checkpoint. It supports various model architectures and configurations supported by
-    FairChem.
+    Uses the fairchem-core 2.2.0+ predictor API for batch inference.
 
     Attributes:
-        neighbor_list_fn (Callable | None): Function to compute neighbor lists
-        config (dict): Complete model configuration dictionary
-        trainer: FairChem trainer object that contains the model
-        data_object (Batch): Data object containing system information
-        implemented_properties (list): Model outputs the model can compute
-        pbc (bool): Whether periodic boundary conditions are used
+        predictor: The FairChem predictor for batch inference
+        task_name (UMATask): Task type for the model
+        _device (torch.device): Device where computation is performed
         _dtype (torch.dtype): Data type used for computation
         _compute_stress (bool): Whether to compute stress tensor
-        _compute_forces (bool): Whether to compute forces
-        _device (torch.device): Device where computation is performed
-        _reshaped_props (dict): Properties that need reshaping after computation
+        implemented_properties (list): Properties the model can compute
 
     Examples:
        >>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True)
        >>> results = model(state)
    """
 
-    _reshaped_props = MappingProxyType(
-        {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
-    )
-
-    def __init__(  # noqa: C901, PLR0915
+    def __init__(
         self,
         model: str | Path | None,
         neighbor_list_fn: Callable | None = None,
         *,  # force remaining arguments to be keyword-only
-        config_yml: str | None = None,
         model_name: str | None = None,
-        local_cache: str | None = None,
-        trainer: str | None = None,
         cpu: bool = False,
-        seed: int | None = None,
         dtype: torch.dtype | None = None,
         compute_stress: bool = False,
-        pbc: bool = True,
-        disable_amp: bool = True,
+        task_name: UMATask | str | None = None,
    ) -> None:
-        """Initialize the FairChemModel with specified configuration.
-
-        Loads a FairChem model from either a checkpoint path or a configuration file.
-        Sets up the model parameters, trainer, and configuration for subsequent use
-        in energy and force calculations.
+        """Initialize the FairChem model.
 
         Args:
             model (str | Path | None): Path to model checkpoint file
             neighbor_list_fn (Callable | None): Function to compute neighbor lists
                 (not currently supported)
-            config_yml (str | None): Path to configuration YAML file
             model_name (str | None): Name of pretrained model to load
-            local_cache (str | None): Path to local model cache directory
-            trainer (str | None): Name of trainer class to use
             cpu (bool): Whether to use CPU instead of GPU for computation
-            seed (int | None): Random seed for reproducibility
             dtype (torch.dtype | None): Data type to use for computation
             compute_stress (bool): Whether to compute stress tensor
-            pbc (bool): Whether to use periodic boundary conditions
-            disable_amp (bool): Whether to disable AMP
+            task_name (UMATask | str | None): Task type (only needed for UMA models)
+
         Raises:
             RuntimeError: If both model_name and model are specified
-            NotImplementedError: If local_cache is not set when model_name is used
             NotImplementedError: If custom neighbor list function is provided
-            ValueError: If stress computation is requested but not supported by model
-
-        Notes:
-            Either config_yml or model must be provided. The model loads configuration
-            from the checkpoint if config_yml is not specified.
+            ValueError: If neither model nor model_name is provided
         """
         setup_imports()
         setup_logging()
@@ -158,7 +105,11 @@ def __init__(  # noqa: C901, PLR0915
         self._compute_stress = compute_stress
         self._compute_forces = True
         self._memory_scales_with = "n_atoms"
-        self.pbc = pbc
+
+        if neighbor_list_fn is not None:
+            raise NotImplementedError(
+                "Custom neighbor list is not supported for FairChemModel."
+            )
 
         if model_name is not None:
             if model is not None:
@@ -166,166 +117,42 @@ def __init__(  # noqa: C901, PLR0915
                 "model_name and checkpoint_path were both specified, "
                 "please use only one at a time"
             )
-            if local_cache is None:
-                raise NotImplementedError(
-                    "Local cache must be set when specifying a model name"
-                )
-            model = model_name_to_local_file(
-                model_name=model_name, local_cache=local_cache
-            )
+            model = model_name
 
-        # Either the config path or the checkpoint path needs to be provided
-        if not config_yml and model is None:
-            raise ValueError("Either config_yml or model must be provided")
-
-        checkpoint = None
-        if config_yml is not None:
-            if isinstance(config_yml, str):
-                config, duplicates_warning, duplicates_error = load_config(config_yml)
-                if len(duplicates_warning) > 0:
-                    print(
-                        "Overwritten config parameters from included configs "
-                        f"(non-included parameters take precedence): {duplicates_warning}"
-                    )
-                if len(duplicates_error) > 0:
-                    raise ValueError(
-                        "Conflicting (duplicate) parameters in simultaneously "
-                        f"included configs: {duplicates_error}"
-                    )
-            else:
-                config = config_yml
-
-            # Only keeps the train data that might have normalizer values
-            if isinstance(config["dataset"], list):
-                config["dataset"] = config["dataset"][0]
-            elif isinstance(config["dataset"], dict):
-                config["dataset"] = config["dataset"].get("train", None)
-        else:
-            # Loads the config from the checkpoint directly (always on CPU).
-            checkpoint = torch.load(model, map_location=torch.device("cpu"))
-            config = checkpoint["config"]
-
-        if trainer is not None:
-            config["trainer"] = trainer
-        else:
-            config["trainer"] = config.get("trainer", "ocp")
-
-        if "model_attributes" in config:
-            config["model_attributes"]["name"] = config.pop("model")
-            config["model"] = config["model_attributes"]
-
-        self.neighbor_list_fn = neighbor_list_fn
-
-        if neighbor_list_fn is None:
-            # Calculate the edge indices on the fly
-            config["model"]["otf_graph"] = True
-        else:
-            raise NotImplementedError(
-                "Custom neighbor list is not supported for FairChemModel."
-            )
+        if model is None:
+            raise ValueError("Either model or model_name must be provided")
 
-        if "backbone" in config["model"]:
-            config["model"]["backbone"]["use_pbc"] = pbc
-            config["model"]["backbone"]["use_pbc_single"] = False
-            if dtype is not None:
-                try:
-                    config["model"]["backbone"].update({"dtype": _DTYPE_DICT[dtype]})
-                    for key in config["model"]["heads"]:
-                        config["model"]["heads"][key].update(
-                            {"dtype": _DTYPE_DICT[dtype]}
-                        )
-                except KeyError:
-                    print(
-                        "WARNING: dtype not found in backbone, using default model dtype"
-                    )
-        else:
-            config["model"]["use_pbc"] = pbc
-            config["model"]["use_pbc_single"] = False
-            if dtype is not None:
-                try:
-                    config["model"].update({"dtype": _DTYPE_DICT[dtype]})
-                except KeyError:
-                    print(
-                        "WARNING: dtype not found in backbone, using default model dtype"
-                    )
-
-        ### backwards compatibility with OCP v<2.0
-        config = update_config(config)
-
-        self.config = copy.deepcopy(config)
-        self.config["checkpoint"] = str(model)
-        del config["dataset"]["src"]
-
-        self.trainer = registry.get_trainer_class(config["trainer"])(
-            task=config.get("task", {}),
-            model=config["model"],
-            dataset=[config["dataset"]],
-            outputs=config["outputs"],
-            loss_functions=config["loss_functions"],
-            evaluation_metrics=config["evaluation_metrics"],
-            optimizer=config["optim"],
-            identifier="",
-            slurm=config.get("slurm", {}),
-            local_rank=config.get("local_rank", 0),
-            is_debug=config.get("is_debug", True),
-            cpu=cpu,
-            amp=False if dtype is not None else config.get("amp", False),
-            inference_only=True,
-        )
-
-        if dtype is not None:
-            # Convert model parameters to specified dtype
-            self.trainer.model = self.trainer.model.to(dtype=self.dtype)
-
-        if model is not None:
-            self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint)
-
-        seed = seed if seed is not None else self.trainer.config["cmd"]["seed"]
-        if seed is None:
-            print(
-                "No seed has been set in model checkpoint or OCPCalculator! Results may "
-                "not be reproducible on re-run"
-            )
-        else:
-            self.trainer.set_seed(seed)
+        # Convert task_name to UMATask if it's a string (only for UMA models)
+        if isinstance(task_name, str):
+            task_name = UMATask(task_name)
 
-        if disable_amp:
-            self.trainer.scaler = None
+        # Use the predictor API for batched inference
+        device_str = "cpu" if cpu else "cuda" if torch.cuda.is_available() else "cpu"
+        self._device = torch.device(device_str)
+        self.task_name = task_name
 
-        self.implemented_properties = list(self.config["outputs"])
+        # Create a batch predictor for fast inference
+        self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str)
 
-        self._device = self.trainer.device
+        # Determine implemented properties. This is a simplified approach; a
+        # more careful implementation could inspect the model configuration.
+        self.implemented_properties = ["energy", "forces"]
+        if compute_stress:
+            self.implemented_properties.append("stress")
 
-        stress_output = "stress" in self.implemented_properties
-        if not stress_output and compute_stress:
-            raise NotImplementedError("Stress output not implemented for this model")
-
-    def load_checkpoint(
-        self, checkpoint_path: str, checkpoint: dict | None = None
-    ) -> None:
-        """Load an existing trained model checkpoint.
+    @property
+    def dtype(self) -> torch.dtype:
+        """Return the data type used by the model."""
+        return self._dtype
 
-        Loads model parameters from a checkpoint file or dictionary,
-        setting the model to inference mode.
-
-        Args:
-            checkpoint_path (str): Path to the trained model checkpoint file
-            checkpoint (dict | None): A pretrained checkpoint dictionary. If provided,
-                this dictionary is used instead of loading from checkpoint_path.
-
-        Notes:
-            If loading fails, a message is printed but no exception is raised.
-        """
-        try:
-            self.trainer.load_checkpoint(checkpoint_path, checkpoint, inference_only=True)
-        except NotImplementedError:
-            print("Unable to load checkpoint!")
+    @property
+    def device(self) -> torch.device:
+        """Return the device where the model is located."""
+        return self._device
 
     def forward(self, state: ts.SimState | StateDict) -> dict:
-        """Perform forward pass to compute energies, forces, and other properties.
-
-        Takes a simulation state and computes the properties implemented by the model,
-        such as energy, forces, and stresses.
+        """Compute energies, forces, and other properties.
 
         Args:
             state (SimState | StateDict): State object containing positions, cells,
@@ -336,12 +163,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
             dict: Dictionary of model predictions, which may include:
                 - energy (torch.Tensor): Energy with shape [batch_size]
                 - forces (torch.Tensor): Forces with shape [n_atoms, 3]
-                - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3],
-                  if compute_stress is True
-
-        Notes:
-            The state is automatically transferred to the model's device if needed.
-            All output tensors are detached from the computation graph.
+                - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3]
         """
         if isinstance(state, dict):
             state = ts.SimState(**state, masses=torch.ones_like(state["positions"]))
@@ -352,47 +174,57 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
 
         if state.system_idx is None:
             state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int)
 
-        if self.pbc != state.pbc:
-            raise ValueError(
-                "PBC mismatch between model and state. "
-                "For FairChemModel PBC needs to be defined in the model class."
-            )
+        # Convert SimState to AtomicData objects for efficient batch processing
+        from ase import Atoms
 
         natoms = torch.bincount(state.system_idx)
-        fixed = torch.zeros((state.system_idx.size(0), natoms.sum()), dtype=torch.int)
-        data_list = []
+        atomic_data_list = []
+
         for i, (n, c) in enumerate(
             zip(natoms, torch.cumsum(natoms, dim=0), strict=False)
         ):
-            data_list.append(
-                Data(
-                    pos=state.positions[c - n : c].clone(),
-                    cell=state.row_vector_cell[i, None].clone(),
-                    atomic_numbers=state.atomic_numbers[c - n : c].clone(),
-                    fixed=fixed[c - n : c].clone(),
-                    natoms=n,
-                    pbc=torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool),
-                )
+            # Extract system data
+            positions = state.positions[c - n : c].cpu().numpy()
+            atomic_numbers = state.atomic_numbers[c - n : c].cpu().numpy()
+            cell = (
+                state.row_vector_cell[i].cpu().numpy()
+                if state.row_vector_cell is not None
+                else None
             )
 
-        self.data_object = Batch.from_data_list(data_list)
-        if self.dtype is not None:
-            self.data_object.pos = self.data_object.pos.to(self.dtype)
-            self.data_object.cell = self.data_object.cell.to(self.dtype)
+            # Create an ASE Atoms object first
+            atoms = Atoms(
+                numbers=atomic_numbers,
+                positions=positions,
+                cell=cell,
+                pbc=state.pbc if cell is not None else False,
+            )
+
+            # Convert ASE Atoms to AtomicData (task_name only applies to UMA models)
+            if self.task_name is None:
+                atomic_data = AtomicData.from_ase(atoms)
+            else:
+                atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name)
+            atomic_data_list.append(atomic_data)
+
+        # Batch all systems for a single inference call
+        batch = atomicdata_list_to_batch(atomic_data_list)
+        batch = batch.to(self._device)
 
-        predictions = self.trainer.predict(
-            self.data_object, per_image=False, disable_tqdm=True
-        )
+        # Run batched prediction
+        predictions = self.predictor.predict(batch)
 
+        # Convert predictions to torch_sim format
         results = {}
+        results["energy"] = predictions["energy"].to(dtype=self._dtype)
+        results["forces"] = predictions["forces"].to(dtype=self._dtype)
 
-        for key in predictions:
-            _pred = predictions[key]
-            if key in self._reshaped_props:
-                _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze()
-            results[key] = _pred.detach()
+        # Handle stress if requested and available
+        if self._compute_stress and "stress" in predictions:
+            stress = predictions["stress"].to(dtype=self._dtype)
+            # Ensure stress has correct shape [batch_size, 3, 3]
+            if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list):
+                stress = stress.view(-1, 3, 3)
+            results["stress"] = stress
 
-        results["energy"] = results["energy"].squeeze(dim=1)
-        if results.get("stress") is not None and len(results["stress"].shape) == 2:
-            results["stress"] = results["stress"].unsqueeze(dim=0)
         return results
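End to end, the rewritten wrapper reduces to: build a batched `SimState`, convert each system to `AtomicData`, run one batched `predict` call, and repackage the outputs. A minimal usage sketch mirroring the tests above (assumes `fairchem-core>=2.2.0` is installed and Hugging Face access to `uma-s-1` is configured; shapes follow the stress test's assertions):

```python
# Usage sketch of the predictor-based FairChemModel, as exercised by the tests.
import torch
from ase.build import bulk

import torch_sim as ts
from torch_sim.models.fairchem import FairChemModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FairChemModel(
    model=None,
    model_name="uma-s-1",
    task_name="omat",  # materials task; "omol"/"oc20" cover molecules/catalysis
    cpu=device.type == "cpu",
    compute_stress=True,
)

state = ts.io.atoms_to_state(
    [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)],
    device=device,
    dtype=torch.float32,
)
results = model(state)
print(results["energy"].shape)  # [2]: one energy per system
print(results["forces"].shape)  # [n_atoms_total, 3]
print(results["stress"].shape)  # [2, 3, 3]
```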