From b1dec86e831b6153e061bc5de1837656c39a0ae7 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Wed, 11 Jun 2025 17:37:53 -0400
Subject: [PATCH 01/10] fix 1.3_Fairchem.py for fairchem-core==1.10.0

---
 examples/scripts/1_Introduction/1.3_Fairchem.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/examples/scripts/1_Introduction/1.3_Fairchem.py b/examples/scripts/1_Introduction/1.3_Fairchem.py
index b6f8dd5b..4fb4ac0f 100644
--- a/examples/scripts/1_Introduction/1.3_Fairchem.py
+++ b/examples/scripts/1_Introduction/1.3_Fairchem.py
@@ -38,19 +38,22 @@
     seed=0,
 )
 atoms_list = [si_dc, si_dc]
-state = ts.io.atoms_to_state(atoms_list)
+state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
 results = model(state)

 print(results["energy"].shape)
 print(results["forces"].shape)
-print(results["stress"].shape)
+if (stress := results.get("stress")) is not None:
+    print(stress.shape)

 print(f"Energy: {results['energy']}")
 print(f"Forces: {results['forces']}")
-print(f"Stress: {results['stress']}")
+if (stress := results.get("stress")) is not None:
+    print(f"{stress=}")

 # Check if the energy, forces, and stress are the same for the Si system across the batch
 print(torch.max(torch.abs(results["energy"][0] - results["energy"][1])))
 print(torch.max(torch.abs(results["forces"][0] - results["forces"][1])))
-print(torch.max(torch.abs(results["stress"][0] - results["stress"][1])))
+if (stress := results.get("stress")) is not None:
+    print(torch.max(torch.abs(stress[0] - stress[1])))

From f32083d81bdef4ed3c654daaa27728fbcbbd29e9 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Wed, 11 Jun 2025 17:40:08 -0400
Subject: [PATCH 02/10] bump FairChem to v2.2.0 and resolve breaking changes in torch_sim/models/fairchem.py

- Modernize FairChemModel to use new FAIRChemCalculator.from_model_checkpoint() API
- Replace deprecated imports (load_config, update_config, model_registry)
- Simplify model loading with direct pretrained model name
- Add required task_name parameter for model initialization
- Remove unused imports and parameters (available_models, local_cache)
---
 .../scripts/1_Introduction/1.3_Fairchem.py |  19 +-
 torch_sim/models/fairchem.py               | 363 ++++++------------
 2 files changed, 131 insertions(+), 251 deletions(-)

diff --git a/examples/scripts/1_Introduction/1.3_Fairchem.py b/examples/scripts/1_Introduction/1.3_Fairchem.py
index 4fb4ac0f..798c284a 100644
--- a/examples/scripts/1_Introduction/1.3_Fairchem.py
+++ b/examples/scripts/1_Introduction/1.3_Fairchem.py
@@ -3,12 +3,10 @@

 # /// script
 # dependencies = [
-#     "fairchem-core==1.10.0",
+#     "fairchem-core>=2.2.0",
 # ]
 # ///

-import sys
-
 import torch
 from ase.build import bulk

@@ -19,21 +17,16 @@
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float32

-try:
-    from fairchem.core.models.model_registry import model_name_to_local_file
-except ImportError:
-    print("Skipping example due to missing fairchem dependency")
-    sys.exit(0)
-
-MODEL_PATH = model_name_to_local_file(
-    "EquiformerV2-31M-S2EF-OC20-All+MD", local_cache="."
-) +# UMA = Unified Machine Learning for Atomistic simulations +MODEL_NAME = "uma-s-1" # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43).repeat((2, 2, 2)) atomic_numbers = si_dc.get_atomic_numbers() model = FairChemModel( - model=MODEL_PATH, + model=None, + model_name=MODEL_NAME, + task_name="omat", # Open Materials task for crystalline systems cpu=False, seed=0, ) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 77b1b0ba..7702e4ba 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -18,11 +18,9 @@ from __future__ import annotations -import copy import traceback import typing import warnings -from types import MappingProxyType from typing import Any import torch @@ -32,15 +30,12 @@ try: - from fairchem.core.common.registry import registry - from fairchem.core.common.utils import ( - load_config, - setup_imports, - setup_logging, - update_config, + from fairchem.core.calculate.ase_calculator import ( + FAIRChemCalculator, + InferenceSettings, + UMATask, ) - from fairchem.core.models.model_registry import model_name_to_local_file - from torch_geometric.data import Batch, Data + from fairchem.core.common.utils import setup_imports, setup_logging except ImportError as exc: warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) @@ -63,12 +58,6 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from torch_sim.typing import StateDict -_DTYPE_DICT = { - torch.float16: "float16", - torch.float32: "float32", - torch.float64: "float64", -} - class FairChemModel(ModelInterface): """Computes atomistic energies, forces and stresses using a FairChem model. @@ -82,73 +71,60 @@ class FairChemModel(ModelInterface): checkpoint. It supports various model architectures and configurations supported by FairChem. + This version uses the modern fairchem-core-2.2.0+ API with FAIRChemCalculator. 
+ Attributes: - neighbor_list_fn (Callable | None): Function to compute neighbor lists - config (dict): Complete model configuration dictionary - trainer: FairChem trainer object that contains the model - data_object (Batch): Data object containing system information - implemented_properties (list): Model outputs the model can compute - pbc (bool): Whether periodic boundary conditions are used + calculator (FAIRChemCalculator): The underlying FairChem calculator + _device (torch.device): Device where computation is performed _dtype (torch.dtype): Data type used for computation _compute_stress (bool): Whether to compute stress tensor - _compute_forces (bool): Whether to compute forces - _device (torch.device): Device where computation is performed - _reshaped_props (dict): Properties that need reshaping after computation + implemented_properties (list): Model outputs the model can compute Examples: >>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True) >>> results = model(state) """ - _reshaped_props = MappingProxyType( - {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} - ) - - def __init__( # noqa: C901, PLR0915 + def __init__( self, model: str | Path | None, neighbor_list_fn: Callable | None = None, *, # force remaining arguments to be keyword-only - config_yml: str | None = None, model_name: str | None = None, - local_cache: str | None = None, - trainer: str | None = None, cpu: bool = False, - seed: int | None = None, + seed: int = 41, dtype: torch.dtype | None = None, compute_stress: bool = False, - pbc: bool = True, - disable_amp: bool = True, + task_name: UMATask | str | None = None, + inference_settings: InferenceSettings | str = "default", + overrides: dict | None = None, ) -> None: """Initialize the FairChemModel with specified configuration. - Loads a FairChem model from either a checkpoint path or a configuration file. - Sets up the model parameters, trainer, and configuration for subsequent use - in energy and force calculations. + Uses the modern FAIRChemCalculator.from_model_checkpoint API for simplified + model loading and configuration. Args: model (str | Path | None): Path to model checkpoint file neighbor_list_fn (Callable | None): Function to compute neighbor lists (not currently supported) - config_yml (str | None): Path to configuration YAML file model_name (str | None): Name of pretrained model to load - local_cache (str | None): Path to local model cache directory - trainer (str | None): Name of trainer class to use cpu (bool): Whether to use CPU instead of GPU for computation - seed (int | None): Random seed for reproducibility + seed (int): Random seed for reproducibility dtype (torch.dtype | None): Data type to use for computation compute_stress (bool): Whether to compute stress tensor - pbc (bool): Whether to use periodic boundary conditions - disable_amp (bool): Whether to disable AMP + task_name (UMATask | str | None): Task type for the model + inference_settings (InferenceSettings | str): Inference configuration + overrides (dict | None): Configuration overrides + Raises: RuntimeError: If both model_name and model are specified - NotImplementedError: If local_cache is not set when model_name is used NotImplementedError: If custom neighbor list function is provided - ValueError: If stress computation is requested but not supported by model + ValueError: If neither model nor model_name is provided Notes: - Either config_yml or model must be provided. The model loads configuration - from the checkpoint if config_yml is not specified. 
+ This uses the new fairchem-core-2.2.0+ API which is much simpler than + the previous versions. """ setup_imports() setup_logging() @@ -158,7 +134,11 @@ def __init__( # noqa: C901, PLR0915 self._compute_stress = compute_stress self._compute_forces = True self._memory_scales_with = "n_atoms" - self.pbc = pbc + + if neighbor_list_fn is not None: + raise NotImplementedError( + "Custom neighbor list is not supported for FairChemModel." + ) if model_name is not None: if model is not None: @@ -166,160 +146,47 @@ def __init__( # noqa: C901, PLR0915 "model_name and checkpoint_path were both specified, " "please use only one at a time" ) - if local_cache is None: - raise NotImplementedError( - "Local cache must be set when specifying a model name" - ) - model = model_name_to_local_file( - model_name=model_name, local_cache=local_cache - ) - - # Either the config path or the checkpoint path needs to be provided - if not config_yml and model is None: - raise ValueError("Either config_yml or model must be provided") - - checkpoint = None - if config_yml is not None: - if isinstance(config_yml, str): - config, duplicates_warning, duplicates_error = load_config(config_yml) - if len(duplicates_warning) > 0: - print( - "Overwritten config parameters from included configs " - f"(non-included parameters take precedence): {duplicates_warning}" - ) - if len(duplicates_error) > 0: - raise ValueError( - "Conflicting (duplicate) parameters in simultaneously " - f"included configs: {duplicates_error}" - ) - else: - config = config_yml - - # Only keeps the train data that might have normalizer values - if isinstance(config["dataset"], list): - config["dataset"] = config["dataset"][0] - elif isinstance(config["dataset"], dict): - config["dataset"] = config["dataset"].get("train", None) - else: - # Loads the config from the checkpoint directly (always on CPU). - checkpoint = torch.load(model, map_location=torch.device("cpu")) - config = checkpoint["config"] - - if trainer is not None: - config["trainer"] = trainer - else: - config["trainer"] = config.get("trainer", "ocp") - - if "model_attributes" in config: - config["model_attributes"]["name"] = config.pop("model") - config["model"] = config["model_attributes"] - - self.neighbor_list_fn = neighbor_list_fn - - if neighbor_list_fn is None: - # Calculate the edge indices on the fly - config["model"]["otf_graph"] = True - else: - raise NotImplementedError( - "Custom neighbor list is not supported for FairChemModel." 
- ) - - if "backbone" in config["model"]: - config["model"]["backbone"]["use_pbc"] = pbc - config["model"]["backbone"]["use_pbc_single"] = False - if dtype is not None: - try: - config["model"]["backbone"].update({"dtype": _DTYPE_DICT[dtype]}) - for key in config["model"]["heads"]: - config["model"]["heads"][key].update( - {"dtype": _DTYPE_DICT[dtype]} - ) - except KeyError: - print( - "WARNING: dtype not found in backbone, using default model dtype" - ) - else: - config["model"]["use_pbc"] = pbc - config["model"]["use_pbc_single"] = False - if dtype is not None: - try: - config["model"].update({"dtype": _DTYPE_DICT[dtype]}) - except KeyError: - print( - "WARNING: dtype not found in backbone, using default model dtype" - ) - - ### backwards compatibility with OCP v<2.0 - config = update_config(config) - - self.config = copy.deepcopy(config) - self.config["checkpoint"] = str(model) - del config["dataset"]["src"] - - self.trainer = registry.get_trainer_class(config["trainer"])( - task=config.get("task", {}), - model=config["model"], - dataset=[config["dataset"]], - outputs=config["outputs"], - loss_functions=config["loss_functions"], - evaluation_metrics=config["evaluation_metrics"], - optimizer=config["optim"], - identifier="", - slurm=config.get("slurm", {}), - local_rank=config.get("local_rank", 0), - is_debug=config.get("is_debug", True), - cpu=cpu, - amp=False if dtype is not None else config.get("amp", False), - inference_only=True, + # For fairchem-core-2.2.0+, model_name can be used directly + # as it supports pretrained model names from available_models + model = model_name + + if model is None: + raise ValueError("Either model or model_name must be provided") + + # Convert task_name to UMATask if it's a string + if isinstance(task_name, str): + task_name = UMATask(task_name) + + # Use the new simplified API + device_str = "cpu" if cpu else "cuda" if torch.cuda.is_available() else "cpu" + + self.calculator = FAIRChemCalculator.from_model_checkpoint( + name_or_path=str(model), + task_name=task_name, + inference_settings=inference_settings, + overrides=overrides, + device=device_str, + seed=seed, ) - if dtype is not None: - # Convert model parameters to specified dtype - self.trainer.model = self.trainer.model.to(dtype=self.dtype) + self._device = torch.device(device_str) - if model is not None: - self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint) + # Determine implemented properties from the calculator + # This is a simplified approach - in practice you might want to + # inspect the model configuration more carefully + self.implemented_properties = ["energy", "forces"] + if compute_stress: + self.implemented_properties.append("stress") - seed = seed if seed is not None else self.trainer.config["cmd"]["seed"] - if seed is None: - print( - "No seed has been set in model checkpoint or OCPCalculator! Results may " - "not be reproducible on re-run" - ) - else: - self.trainer.set_seed(seed) - - if disable_amp: - self.trainer.scaler = None - - self.implemented_properties = list(self.config["outputs"]) - - self._device = self.trainer.device - - stress_output = "stress" in self.implemented_properties - if not stress_output and compute_stress: - raise NotImplementedError("Stress output not implemented for this model") - - def load_checkpoint( - self, checkpoint_path: str, checkpoint: dict | None = None - ) -> None: - """Load an existing trained model checkpoint. 
+ @property + def dtype(self) -> torch.dtype: + """Return the data type used by the model.""" + return self._dtype - Loads model parameters from a checkpoint file or dictionary, - setting the model to inference mode. - - Args: - checkpoint_path (str): Path to the trained model checkpoint file - checkpoint (dict | None): A pretrained checkpoint dictionary. If provided, - this dictionary is used instead of loading from checkpoint_path. - - Notes: - If loading fails, a message is printed but no exception is raised. - """ - try: - self.trainer.load_checkpoint(checkpoint_path, checkpoint, inference_only=True) - except NotImplementedError: - print("Unable to load checkpoint!") + @property + def device(self) -> torch.device: + """Return the device where the model is located.""" + return self._device def forward(self, state: ts.SimState | StateDict) -> dict: """Perform forward pass to compute energies, forces, and other properties. @@ -336,12 +203,11 @@ def forward(self, state: ts.SimState | StateDict) -> dict: dict: Dictionary of model predictions, which may include: - energy (torch.Tensor): Energy with shape [batch_size] - forces (torch.Tensor): Forces with shape [n_atoms, 3] - - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], - if compute_stress is True + - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] Notes: - The state is automatically transferred to the model's device if needed. - All output tensors are detached from the computation graph. + This implementation uses the FAIRChemCalculator which expects ASE Atoms + objects. The conversion is handled internally. """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) @@ -349,50 +215,71 @@ def forward(self, state: ts.SimState | StateDict) -> dict: if state.device != self._device: state = state.to(self._device) - if state.system_idx is None: - state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int) + # Convert torch_sim SimState to ASE Atoms objects for FAIRChemCalculator + from ase import Atoms - if self.pbc != state.pbc: - raise ValueError( - "PBC mismatch between model and state. " - "For FairChemModel PBC needs to be defined in the model class." 
- ) + if state.batch is None: + state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int) + + natoms = torch.bincount(state.batch) + atoms_list = [] - natoms = torch.bincount(state.system_idx) - fixed = torch.zeros((state.system_idx.size(0), natoms.sum()), dtype=torch.int) - data_list = [] for i, (n, c) in enumerate( zip(natoms, torch.cumsum(natoms, dim=0), strict=False) ): - data_list.append( - Data( - pos=state.positions[c - n : c].clone(), - cell=state.row_vector_cell[i, None].clone(), - atomic_numbers=state.atomic_numbers[c - n : c].clone(), - fixed=fixed[c - n : c].clone(), - natoms=n, - pbc=torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool), - ) + positions = state.positions[c - n : c].cpu().numpy() + atomic_numbers = state.atomic_numbers[c - n : c].cpu().numpy() + cell = ( + state.row_vector_cell[i].cpu().numpy() + if state.row_vector_cell is not None + else None ) - self.data_object = Batch.from_data_list(data_list) - - if self.dtype is not None: - self.data_object.pos = self.data_object.pos.to(self.dtype) - self.data_object.cell = self.data_object.cell.to(self.dtype) - predictions = self.trainer.predict( - self.data_object, per_image=False, disable_tqdm=True - ) + atoms = Atoms( + numbers=atomic_numbers, + positions=positions, + cell=cell, + pbc=state.pbc if cell is not None else False, + ) + atoms_list.append(atoms) + # Use FAIRChemCalculator to compute properties results = {} + energies = [] + forces_list = [] + stress_list = [] + + for atoms in atoms_list: + atoms.calc = self.calculator + + # Get energy + energy = atoms.get_potential_energy() + energies.append(energy) + + # Get forces + forces = atoms.get_forces() + forces_list.append( + torch.from_numpy(forces).to(self._device, dtype=self._dtype) + ) + + # Get stress if requested + if self._compute_stress: + try: + stress = atoms.get_stress(voigt=False) # 3x3 tensor + stress_list.append( + torch.from_numpy(stress).to(self._device, dtype=self._dtype) + ) + except (RuntimeError, AttributeError, NotImplementedError): + # If stress computation fails, fill with zeros + stress_list.append( + torch.zeros(3, 3, device=self._device, dtype=self._dtype) + ) + + # Combine results + results["energy"] = torch.tensor(energies, device=self._device, dtype=self._dtype) + results["forces"] = torch.cat(forces_list, dim=0) - for key in predictions: - _pred = predictions[key] - if key in self._reshaped_props: - _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze() - results[key] = _pred.detach() + if self._compute_stress and stress_list: + results["stress"] = torch.stack(stress_list, dim=0) - results["energy"] = results["energy"].squeeze(dim=1) - if results.get("stress") is not None and len(results["stress"].shape) == 2: - results["stress"] = results["stress"].unsqueeze(dim=0) return results From b48acf6a9ff7f80af2f41167291e8d9b1929b5a7 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 11 Jun 2025 17:46:39 -0400 Subject: [PATCH 03/10] refactor FairChem tests to use UMA model - Replace OCPCalculator with FAIRChemCalculator in test fixtures - Remove unused model path fixtures and simplify model initialization - update test parameters for UMA model tolerances --- tests/models/test_fairchem.py | 77 ++++++++++++++--------------------- 1 file changed, 30 insertions(+), 47 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index b04ad9c8..ac0af7df 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,5 +1,3 @@ -import os - import pytest import 
torch @@ -11,8 +9,7 @@ try: - from fairchem.core import OCPCalculator - from fairchem.core.models.model_registry import model_name_to_local_file + from fairchem.core.calculate.ase_calculator import FAIRChemCalculator from huggingface_hub.utils._auth import get_token from torch_sim.models.fairchem import FairChemModel @@ -21,52 +18,41 @@ pytest.skip("FairChem not installed", allow_module_level=True) -@pytest.fixture(scope="session") -def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str: - tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - model_name = "EquiformerV2-31M-S2EF-OC20-All+MD" - return model_name_to_local_file(model_name, local_cache=str(tmp_path)) - - @pytest.fixture -def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel: +def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel: + """Use the UMA model which is available in fairchem-core-2.2.0+.""" cpu = device.type == "cpu" - return FairChemModel(model=model_path_oc20, cpu=cpu, seed=0, pbc=True) + return FairChemModel( + model=None, model_name="uma-s-1", task_name="omat", cpu=cpu, seed=0 + ) @pytest.fixture -def eqv2_oc20_model_non_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel: +def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel: + """Use the UMA model for non-PBC systems.""" cpu = device.type == "cpu" - return FairChemModel(model=model_path_oc20, cpu=cpu, seed=0, pbc=False) - - -if get_token(): - - @pytest.fixture(scope="session") - def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str: - tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - model_name = "EquiformerV2-31M-OMAT24-MP-sAlex" - return model_name_to_local_file(model_name, local_cache=str(tmp_path)) - - @pytest.fixture - def eqv2_omat24_model_pbc( - model_path_omat24: str, device: torch.device - ) -> FairChemModel: - cpu = device.type == "cpu" - return FairChemModel(model=model_path_omat24, cpu=cpu, seed=0, pbc=True) + return FairChemModel( + model=None, model_name="uma-s-1", task_name="omat", cpu=cpu, seed=0 + ) @pytest.fixture -def ocp_calculator(model_path_oc20: str) -> OCPCalculator: - return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0) +def fairchem_calculator() -> FAIRChemCalculator: + """FAIRChemCalculator using the UMA model.""" + return FAIRChemCalculator.from_model_checkpoint( + name_or_path="uma-s-1", + task_name="omat", + device="cpu", + seed=0, + ) test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test( - test_name="fairchem_ocp", - model_fixture_name="eqv2_oc20_model_pbc", - calculator_fixture_name="ocp_calculator", + test_name="fairchem_uma", + model_fixture_name="eqv2_uma_model_pbc", + calculator_fixture_name="fairchem_calculator", sim_state_names=consistency_test_simstate_fixtures[:-1], - energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models + energy_rtol=5e-4, # NOTE: UMA model tolerances energy_atol=5e-4, force_rtol=5e-4, force_atol=5e-4, @@ -76,10 +62,10 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator: test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( test_name="fairchem_non_pbc_benzene", - model_fixture_name="eqv2_oc20_model_non_pbc", - calculator_fixture_name="ocp_calculator", + model_fixture_name="eqv2_uma_model_non_pbc", + calculator_fixture_name="fairchem_calculator", sim_state_names=["benzene_sim_state"], - energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models + energy_rtol=5e-4, # NOTE: 
UMA model tolerances energy_atol=5e-4, force_rtol=5e-4, force_atol=5e-4, @@ -88,10 +74,7 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator: ) -# Skip this test due to issues with how the older models -# handled supercells (see related issue here: https://github.com/facebookresearch/fairchem/issues/428) - -test_fairchem_ocp_model_outputs = pytest.mark.skipif( - os.environ.get("HF_TOKEN") is None, - reason="Issues in graph construction of older models", -)(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) +test_fairchem_uma_model_outputs = pytest.mark.skipif( + get_token() is None, + reason="Requires HuggingFace authentication for UMA model access", +)(make_validate_model_outputs_test(model_fixture_name="eqv2_uma_model_pbc")) From ce01e5ee7269c7033484026851f958dd5c04d7f1 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 11 Jun 2025 17:53:38 -0400 Subject: [PATCH 04/10] try loading UMA model in CI (requires huggingface auth) using my personal HF token for now --- .github/workflows/test.yml | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f08a83c1..9f440ef5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -82,14 +82,6 @@ jobs: - name: Check out repo uses: actions/checkout@v4 - - name: Check out fairchem repository - if: ${{ matrix.model.name == 'fairchem' }} - uses: actions/checkout@v4 - with: - repository: FAIR-Chem/fairchem - path: fairchem-repo - ref: fairchem_core-1.10.0 - - name: Set up Python uses: actions/setup-python@v5 with: @@ -98,24 +90,13 @@ jobs: - name: Set up uv uses: astral-sh/setup-uv@v6 - - name: Install fairchem repository and dependencies + - name: Install fairchem and dependencies if: ${{ matrix.model.name == 'fairchem' }} - env: - HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | - uv pip install huggingface_hub --system - if [ -n "$HF_TOKEN" ]; then - huggingface-cli login --token "$HF_TOKEN" - else - echo "HF_TOKEN is not set. Skipping login." 
- fi - if [ -f fairchem-repo/packages/requirements.txt ]; then - uv pip install -r fairchem-repo/packages/requirements.txt --system - fi - if [ -f fairchem-repo/packages/requirements-optional.txt ]; then - uv pip install -r fairchem-repo/packages/requirements-optional.txt --system - fi - uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system + uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system + uv pip install "torch-scatter" -f https://data.pyg.org/whl/torch-2.6.0+cpu.html --system + uv pip install "torch-sparse" -f https://data.pyg.org/whl/torch-2.6.0+cpu.html --system + uv pip install "fairchem-core>=2.2.0" --system uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system - name: Install torch_sim with model dependencies @@ -124,6 +105,8 @@ jobs: uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system - name: Run ${{ matrix.model.test_path }} tests + env: + HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }} From b9e0a6cb0f7d881f74aef67dccf235eaf90f076a Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 11 Jun 2025 17:59:01 -0400 Subject: [PATCH 05/10] include HF login in fairchem CI workflow --- .github/workflows/test.yml | 15 ++++++++++++++- .../{1.3_Fairchem.py => 1.3_fairchem.py} | 0 2 files changed, 14 insertions(+), 1 deletion(-) rename examples/scripts/1_Introduction/{1.3_Fairchem.py => 1.3_fairchem.py} (100%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9f440ef5..fe0b2d08 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -92,11 +92,14 @@ jobs: - name: Install fairchem and dependencies if: ${{ matrix.model.name == 'fairchem' }} + env: + HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system uv pip install "torch-scatter" -f https://data.pyg.org/whl/torch-2.6.0+cpu.html --system uv pip install "torch-sparse" -f https://data.pyg.org/whl/torch-2.6.0+cpu.html --system uv pip install "fairchem-core>=2.2.0" --system + uv pip install "huggingface_hub[cli]" --system uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system - name: Install torch_sim with model dependencies @@ -108,6 +111,9 @@ jobs: env: HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | + if [ "${{ matrix.model.name }}" == "fairchem" ]; then + huggingface-cli login --token "$HF_TOKEN" + fi pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }} - name: Upload coverage to Codecov @@ -151,4 +157,11 @@ jobs: uses: astral-sh/setup-uv@v6 - name: Run example - run: uv run --with . ${{ matrix.example }} + env: + HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} + run: | + if [[ "${{ matrix.example }}" == *"fairchem"* ]]; then + uv pip install "huggingface_hub[cli]" --system + huggingface-cli login --token "$HF_TOKEN" + fi + uv run --with . 
${{ matrix.example }} diff --git a/examples/scripts/1_Introduction/1.3_Fairchem.py b/examples/scripts/1_Introduction/1.3_fairchem.py similarity index 100% rename from examples/scripts/1_Introduction/1.3_Fairchem.py rename to examples/scripts/1_Introduction/1.3_fairchem.py From 2e74ef2256809a0f1ed98027394f298b793ba1e4 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 11 Jun 2025 18:28:33 -0400 Subject: [PATCH 06/10] https://github.com/Radical-AI/torch-sim/pull/211#discussion_r2141207587 --- .github/workflows/test.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fe0b2d08..f47f6238 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -95,9 +95,7 @@ jobs: env: HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | - uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system - uv pip install "torch-scatter" -f https://data.pyg.org/whl/torch-2.6.0+cpu.html --system - uv pip install "torch-sparse" -f https://data.pyg.org/whl/torch-2.6.0+cpu.html --system + uv pip install "torch>=2.6" --index-url https://download.pytorch.org/whl/cpu --system uv pip install "fairchem-core>=2.2.0" --system uv pip install "huggingface_hub[cli]" --system uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system From 419d18e1f64056701ec402fb9380d39f7d0b96db Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 11 Jun 2025 18:41:37 -0400 Subject: [PATCH 07/10] perf: use new fairchem batch inference API to avoid going through ASE calculator interface see https://fair-chem.github.io/core/common_tasks/batch_inference.html --- .../scripts/1_Introduction/1.3_fairchem.py | 1 - tests/models/test_fairchem.py | 52 +------- torch_sim/models/fairchem.py | 120 +++++++----------- 3 files changed, 49 insertions(+), 124 deletions(-) diff --git a/examples/scripts/1_Introduction/1.3_fairchem.py b/examples/scripts/1_Introduction/1.3_fairchem.py index 798c284a..6fc9a42d 100644 --- a/examples/scripts/1_Introduction/1.3_fairchem.py +++ b/examples/scripts/1_Introduction/1.3_fairchem.py @@ -28,7 +28,6 @@ model_name=MODEL_NAME, task_name="omat", # Open Materials task for crystalline systems cpu=False, - seed=0, ) atoms_list = [si_dc, si_dc] state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index ac0af7df..d0b4bb14 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,15 +1,10 @@ import pytest import torch -from tests.models.conftest import ( - consistency_test_simstate_fixtures, - make_model_calculator_consistency_test, - make_validate_model_outputs_test, -) +from tests.models.conftest import make_validate_model_outputs_test try: - from fairchem.core.calculate.ase_calculator import FAIRChemCalculator from huggingface_hub.utils._auth import get_token from torch_sim.models.fairchem import FairChemModel @@ -22,56 +17,17 @@ def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel: """Use the UMA model which is available in fairchem-core-2.2.0+.""" cpu = device.type == "cpu" - return FairChemModel( - model=None, model_name="uma-s-1", task_name="omat", cpu=cpu, seed=0 - ) + return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) @pytest.fixture def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel: """Use the UMA model for non-PBC systems.""" cpu = device.type == "cpu" - return FairChemModel( - model=None, 
model_name="uma-s-1", task_name="omat", cpu=cpu, seed=0 - ) + return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) -@pytest.fixture -def fairchem_calculator() -> FAIRChemCalculator: - """FAIRChemCalculator using the UMA model.""" - return FAIRChemCalculator.from_model_checkpoint( - name_or_path="uma-s-1", - task_name="omat", - device="cpu", - seed=0, - ) - - -test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test( - test_name="fairchem_uma", - model_fixture_name="eqv2_uma_model_pbc", - calculator_fixture_name="fairchem_calculator", - sim_state_names=consistency_test_simstate_fixtures[:-1], - energy_rtol=5e-4, # NOTE: UMA model tolerances - energy_atol=5e-4, - force_rtol=5e-4, - force_atol=5e-4, - stress_rtol=5e-4, - stress_atol=5e-4, -) - -test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( - test_name="fairchem_non_pbc_benzene", - model_fixture_name="eqv2_uma_model_non_pbc", - calculator_fixture_name="fairchem_calculator", - sim_state_names=["benzene_sim_state"], - energy_rtol=5e-4, # NOTE: UMA model tolerances - energy_atol=5e-4, - force_rtol=5e-4, - force_atol=5e-4, - stress_rtol=5e-4, - stress_atol=5e-4, -) +# Removed calculator consistency tests since we're using predictor interface only test_fairchem_uma_model_outputs = pytest.mark.skipif( diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 7702e4ba..869931a3 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -30,12 +30,10 @@ try: - from fairchem.core.calculate.ase_calculator import ( - FAIRChemCalculator, - InferenceSettings, - UMATask, - ) + from fairchem.core import pretrained_mlip + from fairchem.core.calculate.ase_calculator import UMATask from fairchem.core.common.utils import setup_imports, setup_logging + from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch except ImportError as exc: warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) @@ -71,10 +69,11 @@ class FairChemModel(ModelInterface): checkpoint. It supports various model architectures and configurations supported by FairChem. - This version uses the modern fairchem-core-2.2.0+ API with FAIRChemCalculator. + This version uses the efficient fairchem-core-2.2.0+ predictor API. Attributes: - calculator (FAIRChemCalculator): The underlying FairChem calculator + predictor: The FairChem predictor for batch inference + task_name (UMATask): Task type for the model _device (torch.device): Device where computation is performed _dtype (torch.dtype): Data type used for computation _compute_stress (bool): Whether to compute stress tensor @@ -92,17 +91,13 @@ def __init__( *, # force remaining arguments to be keyword-only model_name: str | None = None, cpu: bool = False, - seed: int = 41, dtype: torch.dtype | None = None, compute_stress: bool = False, task_name: UMATask | str | None = None, - inference_settings: InferenceSettings | str = "default", - overrides: dict | None = None, ) -> None: """Initialize the FairChemModel with specified configuration. - Uses the modern FAIRChemCalculator.from_model_checkpoint API for simplified - model loading and configuration. + Uses the efficient FairChem predictor interface for optimal performance. 
Args: model (str | Path | None): Path to model checkpoint file @@ -110,12 +105,9 @@ def __init__( (not currently supported) model_name (str | None): Name of pretrained model to load cpu (bool): Whether to use CPU instead of GPU for computation - seed (int): Random seed for reproducibility dtype (torch.dtype | None): Data type to use for computation compute_stress (bool): Whether to compute stress tensor task_name (UMATask | str | None): Task type for the model - inference_settings (InferenceSettings | str): Inference configuration - overrides (dict | None): Configuration overrides Raises: RuntimeError: If both model_name and model are specified @@ -123,8 +115,8 @@ def __init__( ValueError: If neither model nor model_name is provided Notes: - This uses the new fairchem-core-2.2.0+ API which is much simpler than - the previous versions. + This uses the efficient fairchem-core-2.2.0+ predictor API for + optimal batch inference performance. """ setup_imports() setup_logging() @@ -146,8 +138,6 @@ def __init__( "model_name and checkpoint_path were both specified, " "please use only one at a time" ) - # For fairchem-core-2.2.0+, model_name can be used directly - # as it supports pretrained model names from available_models model = model_name if model is None: @@ -157,21 +147,15 @@ def __init__( if isinstance(task_name, str): task_name = UMATask(task_name) - # Use the new simplified API + # Use the efficient predictor API for optimal performance device_str = "cpu" if cpu else "cuda" if torch.cuda.is_available() else "cpu" - - self.calculator = FAIRChemCalculator.from_model_checkpoint( - name_or_path=str(model), - task_name=task_name, - inference_settings=inference_settings, - overrides=overrides, - device=device_str, - seed=seed, - ) - self._device = torch.device(device_str) + self.task_name = task_name + + # Create efficient batch predictor for fast inference + self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str) - # Determine implemented properties from the calculator + # Determine implemented properties # This is a simplified approach - in practice you might want to # inspect the model configuration more carefully self.implemented_properties = ["energy", "forces"] @@ -191,8 +175,8 @@ def device(self) -> torch.device: def forward(self, state: ts.SimState | StateDict) -> dict: """Perform forward pass to compute energies, forces, and other properties. - Takes a simulation state and computes the properties implemented by the model, - such as energy, forces, and stresses. + Uses efficient batch inference with FairChem's native tensor interface for + optimal performance on both single systems and large batches. Args: state (SimState | StateDict): State object containing positions, cells, @@ -206,8 +190,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict: - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] Notes: - This implementation uses the FAIRChemCalculator which expects ASE Atoms - objects. The conversion is handled internally. + This implementation uses FairChem's efficient batch predictor interface + for optimal performance on both single systems and large batches. 
""" if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) @@ -215,18 +199,19 @@ def forward(self, state: ts.SimState | StateDict) -> dict: if state.device != self._device: state = state.to(self._device) - # Convert torch_sim SimState to ASE Atoms objects for FAIRChemCalculator - from ase import Atoms - if state.batch is None: state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int) + # Convert SimState to AtomicData objects for efficient batch processing + from ase import Atoms + natoms = torch.bincount(state.batch) - atoms_list = [] + atomic_data_list = [] for i, (n, c) in enumerate( zip(natoms, torch.cumsum(natoms, dim=0), strict=False) ): + # Extract system data positions = state.positions[c - n : c].cpu().numpy() atomic_numbers = state.atomic_numbers[c - n : c].cpu().numpy() cell = ( @@ -235,51 +220,36 @@ def forward(self, state: ts.SimState | StateDict) -> dict: else None ) + # Create ASE Atoms object first atoms = Atoms( numbers=atomic_numbers, positions=positions, cell=cell, pbc=state.pbc if cell is not None else False, ) - atoms_list.append(atoms) - # Use FAIRChemCalculator to compute properties - results = {} - energies = [] - forces_list = [] - stress_list = [] + # Convert ASE Atoms to AtomicData with task_name + atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name) + atomic_data_list.append(atomic_data) - for atoms in atoms_list: - atoms.calc = self.calculator + # Create batch for efficient inference + batch = atomicdata_list_to_batch(atomic_data_list) + batch = batch.to(self._device) - # Get energy - energy = atoms.get_potential_energy() - energies.append(energy) + # Run efficient batch prediction + predictions = self.predictor.predict(batch) - # Get forces - forces = atoms.get_forces() - forces_list.append( - torch.from_numpy(forces).to(self._device, dtype=self._dtype) - ) - - # Get stress if requested - if self._compute_stress: - try: - stress = atoms.get_stress(voigt=False) # 3x3 tensor - stress_list.append( - torch.from_numpy(stress).to(self._device, dtype=self._dtype) - ) - except (RuntimeError, AttributeError, NotImplementedError): - # If stress computation fails, fill with zeros - stress_list.append( - torch.zeros(3, 3, device=self._device, dtype=self._dtype) - ) - - # Combine results - results["energy"] = torch.tensor(energies, device=self._device, dtype=self._dtype) - results["forces"] = torch.cat(forces_list, dim=0) - - if self._compute_stress and stress_list: - results["stress"] = torch.stack(stress_list, dim=0) + # Convert predictions to torch_sim format + results = {} + results["energy"] = predictions["energy"].to(dtype=self._dtype) + results["forces"] = predictions["forces"].to(dtype=self._dtype) + + # Handle stress if requested and available + if self._compute_stress and "stress" in predictions: + stress = predictions["stress"].to(dtype=self._dtype) + # Ensure stress has correct shape [batch_size, 3, 3] + if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list): + stress = stress.view(-1, 3, 3) + results["stress"] = stress return results From ed735ecf9fa9e6d51a209bbe005a3187f346e631 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 11 Jun 2025 19:14:02 -0400 Subject: [PATCH 08/10] address https://github.com/Radical-AI/torch-sim/pull/211#discussion_r2141240214 --- torch_sim/models/fairchem.py | 63 +++++++++++------------------------- 1 file changed, 19 insertions(+), 44 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 
869931a3..989b433b 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -1,21 +1,11 @@ -"""Wrapper for FairChem ecosystem models in TorchSim. +"""FairChem model wrapper for torch_sim. -This module provides a TorchSim wrapper of the FairChem models for computing -energies, forces, and stresses of atomistic systems. It serves as a wrapper around -the FairChem library, integrating it with the torch_sim framework to enable seamless -simulation of atomistic systems with machine learning potentials. +Provides a TorchSim-compatible interface to FairChem models for computing +energies, forces, and stresses of atomistic systems. -The FairChemModel class adapts FairChem models to the ModelInterface protocol, -allowing them to be used within the broader torch_sim simulation framework. - -Notes: - This implementation requires FairChem to be installed and accessible. - It supports various model configurations through configuration files or - pretrained model checkpoints. +Requires fairchem-core to be installed. """ -# ruff: noqa: T201 - from __future__ import annotations import traceback @@ -57,19 +47,13 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from torch_sim.typing import StateDict -class FairChemModel(ModelInterface): - """Computes atomistic energies, forces and stresses using a FairChem model. +class FairChemModel(torch.nn.Module, ModelInterface): + """FairChem model wrapper for computing atomistic properties. - This class wraps a FairChem model to compute energies, forces, and stresses for - atomistic systems. It handles model initialization, checkpoint loading, and - provides a forward pass that accepts a SimState object and returns model - predictions. + Wraps FairChem models to compute energies, forces, and stresses. Can be + initialized with a model checkpoint path or pretrained model name. - The model can be initialized either with a configuration file or a pretrained - checkpoint. It supports various model architectures and configurations supported by - FairChem. - - This version uses the efficient fairchem-core-2.2.0+ predictor API. + Uses the fairchem-core-2.2.0+ predictor API for batch inference. Attributes: predictor: The FairChem predictor for batch inference @@ -95,9 +79,7 @@ def __init__( compute_stress: bool = False, task_name: UMATask | str | None = None, ) -> None: - """Initialize the FairChemModel with specified configuration. - - Uses the efficient FairChem predictor interface for optimal performance. + """Initialize the FairChem model. Args: model (str | Path | None): Path to model checkpoint file @@ -107,16 +89,13 @@ def __init__( cpu (bool): Whether to use CPU instead of GPU for computation dtype (torch.dtype | None): Data type to use for computation compute_stress (bool): Whether to compute stress tensor - task_name (UMATask | str | None): Task type for the model + task_name (UMATask | str | None): Task type for UMA models (optional, + only needed for UMA models) Raises: RuntimeError: If both model_name and model are specified NotImplementedError: If custom neighbor list function is provided ValueError: If neither model nor model_name is provided - - Notes: - This uses the efficient fairchem-core-2.2.0+ predictor API for - optimal batch inference performance. 
""" setup_imports() setup_logging() @@ -143,7 +122,7 @@ def __init__( if model is None: raise ValueError("Either model or model_name must be provided") - # Convert task_name to UMATask if it's a string + # Convert task_name to UMATask if it's a string (only for UMA models) if isinstance(task_name, str): task_name = UMATask(task_name) @@ -173,10 +152,7 @@ def device(self) -> torch.device: return self._device def forward(self, state: ts.SimState | StateDict) -> dict: - """Perform forward pass to compute energies, forces, and other properties. - - Uses efficient batch inference with FairChem's native tensor interface for - optimal performance on both single systems and large batches. + """Compute energies, forces, and other properties. Args: state (SimState | StateDict): State object containing positions, cells, @@ -188,10 +164,6 @@ def forward(self, state: ts.SimState | StateDict) -> dict: - energy (torch.Tensor): Energy with shape [batch_size] - forces (torch.Tensor): Forces with shape [n_atoms, 3] - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] - - Notes: - This implementation uses FairChem's efficient batch predictor interface - for optimal performance on both single systems and large batches. """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) @@ -228,8 +200,11 @@ def forward(self, state: ts.SimState | StateDict) -> dict: pbc=state.pbc if cell is not None else False, ) - # Convert ASE Atoms to AtomicData with task_name - atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name) + # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) + if self.task_name is None: + atomic_data = AtomicData.from_ase(atoms) + else: + atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name) atomic_data_list.append(atomic_data) # Create batch for efficient inference From b04ca55d6470bde852eb9c2f34c41e657665c812 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 11 Jun 2025 19:14:36 -0400 Subject: [PATCH 09/10] more FairChem tests with UMA model and batch processing - parameterized tests for different UMA task names and system configs - tests for homogeneous and heterogeneous batching, ensuring correct energy and force outputs - stress tensor computation tests with conditional checks - test error handling for empty batches --- tests/models/test_fairchem.py | 200 +++++++++++++++++++++++++++++++++- 1 file changed, 194 insertions(+), 6 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index d0b4bb14..f73b6a97 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -5,8 +5,12 @@ try: + from collections.abc import Callable + + from ase.build import bulk, fcc100, molecule from huggingface_hub.utils._auth import get_token + import torch_sim as ts from torch_sim.models.fairchem import FairChemModel except ImportError: @@ -15,19 +19,203 @@ @pytest.fixture def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel: - """Use the UMA model which is available in fairchem-core-2.2.0+.""" + """UMA model for periodic boundary condition systems.""" cpu = device.type == "cpu" return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) -@pytest.fixture -def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel: - """Use the UMA model for non-PBC systems.""" +# Removed calculator consistency tests since we're using predictor interface only + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace 
authentication for UMA model access" +) +@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"]) +def test_task_initialization(task_name: str) -> None: + """Test that different UMA task names work correctly.""" + model = FairChemModel(model=None, model_name="uma-s-1", task_name=task_name, cpu=True) + assert model.task_name.value == task_name + assert hasattr(model, "predictor") + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +@pytest.mark.parametrize( + ("task_name", "systems_func"), + [ + ( + "omat", + lambda: [ + bulk("Si", "diamond", a=5.43), + bulk("Al", "fcc", a=4.05), + bulk("Fe", "bcc", a=2.87), + bulk("Cu", "fcc", a=3.61), + ], + ), + ( + "omol", + lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")], + ), + ], +) +def test_homogeneous_batching( + task_name: str, systems_func: Callable, device: torch.device, dtype: torch.dtype +) -> None: + """Test batching multiple systems with the same task.""" + systems = systems_func() + + # Add molecular properties for molecules + if task_name == "omol": + for mol in systems: + mol.info.update({"charge": 0, "spin": 1}) + + model = FairChemModel( + model=None, model_name="uma-s-1", task_name=task_name, cpu=device.type == "cpu" + ) + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) + results = model(state) + + # Check batch dimensions + assert results["energy"].shape == (4,) + assert results["forces"].shape[0] == sum(len(s) for s in systems) + assert results["forces"].shape[1] == 3 + + # Check that different systems have different energies + energies = results["energy"] + unique_energies = torch.unique(energies, dim=0) + assert len(unique_energies) > 1, "Different systems should have different energies" + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +def test_heterogeneous_tasks(device: torch.device, dtype: torch.dtype) -> None: + """Test different task types work with appropriate systems.""" + # Test molecule, material, and catalysis systems separately + test_cases = [ + ("omol", [molecule("H2O")]), + ("omat", [bulk("Pt", cubic=True)]), + ("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]), + ] + + for task_name, systems in test_cases: + if task_name == "omol": + systems[0].info.update({"charge": 0, "spin": 1}) + + model = FairChemModel( + model=None, + model_name="uma-s-1", + task_name=task_name, + cpu=device.type == "cpu", + ) + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) + results = model(state) + + assert "energy" in results + assert "forces" in results + assert results["energy"].shape[0] == 1 + assert results["forces"].dim() == 2 + assert results["forces"].shape[1] == 3 + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +@pytest.mark.parametrize( + ("systems_func", "expected_count"), + [ + (lambda: [bulk("Si", "diamond", a=5.43)], 1), # Single system + ( + lambda: [ + bulk("H", "bcc", a=2.0), + bulk("Li", "bcc", a=3.0), + bulk("Si", "diamond", a=5.43), + bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)), + ], + 4, + ), # Mixed sizes + ( + lambda: [ + bulk(element, "fcc", a=4.0) + for element in ["Al", "Cu", "Ni", "Pd", "Pt"] * 3 + ], + 15, + ), # Large batch + ], +) +def test_batch_size_variations( + systems_func: Callable, expected_count: int, device: torch.device, dtype: torch.dtype +) -> None: + """Test batching with different numbers and sizes of systems.""" 
+ systems = systems_func() + + model = FairChemModel( + model=None, model_name="uma-s-1", task_name="omat", cpu=device.type == "cpu" + ) + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) + results = model(state) + + assert results["energy"].shape == (expected_count,) + assert results["forces"].shape[0] == sum(len(s) for s in systems) + assert results["forces"].shape[1] == 3 + assert torch.isfinite(results["energy"]).all() + assert torch.isfinite(results["forces"]).all() + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +@pytest.mark.parametrize("compute_stress", [True, False]) +def test_stress_computation( + *, compute_stress: bool, device: torch.device, dtype: torch.dtype +) -> None: + """Test stress tensor computation.""" + systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)] + + model = FairChemModel( + model=None, + model_name="uma-s-1", + task_name="omat", + cpu=device.type == "cpu", + compute_stress=compute_stress, + ) + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) + results = model(state) + + if compute_stress: + assert "stress" in results + assert results["stress"].shape == (2, 3, 3) + assert torch.isfinite(results["stress"]).all() + else: + assert "stress" not in results + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +def test_device_consistency(dtype: torch.dtype) -> None: + """Test device consistency between model and data.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") cpu = device.type == "cpu" - return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) + model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) + system = bulk("Si", "diamond", a=5.43) + state = ts.io.atoms_to_state([system], device=device, dtype=dtype) -# Removed calculator consistency tests since we're using predictor interface only + results = model(state) + assert results["energy"].device == device + assert results["forces"].device == device + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +def test_empty_batch_error() -> None: + """Test that empty batches raise appropriate errors.""" + model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=True) + with pytest.raises((ValueError, RuntimeError, IndexError)): + model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32)) test_fairchem_uma_model_outputs = pytest.mark.skipif( From d0c91cdbb18ccc37c2b677ab6a36ea5dde8ea176 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 13 Aug 2025 21:51:28 -0400 Subject: [PATCH 10/10] Fairchem v2 patch (#238) --- torch_sim/models/fairchem.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 989b433b..3aae3f3a 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -47,7 +47,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from torch_sim.typing import StateDict -class FairChemModel(torch.nn.Module, ModelInterface): +class FairChemModel(ModelInterface): """FairChem model wrapper for computing atomistic properties. Wraps FairChem models to compute energies, forces, and stresses. 
Can be @@ -171,13 +171,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict: if state.device != self._device: state = state.to(self._device) - if state.batch is None: - state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int) + if state.system_idx is None: + state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int) # Convert SimState to AtomicData objects for efficient batch processing from ase import Atoms - natoms = torch.bincount(state.batch) + natoms = torch.bincount(state.system_idx) atomic_data_list = [] for i, (n, c) in enumerate(