Skip to content

Commit 74d6e5d

Browse files
committed
perf: use new fairchem batch inference API to avoid going through ASE calculator interface
see https://fair-chem.github.io/core/common_tasks/batch_inference.html
1 parent 21ecd1c commit 74d6e5d

File tree

3 files changed

+49
-124
lines changed

3 files changed

+49
-124
lines changed

examples/scripts/1_Introduction/1.3_fairchem.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
model_name=MODEL_NAME,
2929
task_name="omat", # Open Materials task for crystalline systems
3030
cpu=False,
31-
seed=0,
3231
)
3332
atoms_list = [si_dc, si_dc]
3433
state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)

tests/models/test_fairchem.py

Lines changed: 4 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
import pytest
22
import torch
33

4-
from tests.models.conftest import (
5-
consistency_test_simstate_fixtures,
6-
make_model_calculator_consistency_test,
7-
make_validate_model_outputs_test,
8-
)
4+
from tests.models.conftest import make_validate_model_outputs_test
95

106

117
try:
12-
from fairchem.core.calculate.ase_calculator import FAIRChemCalculator
138
from huggingface_hub.utils._auth import get_token
149

1510
from torch_sim.models.fairchem import FairChemModel
@@ -22,56 +17,17 @@
2217
def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel:
2318
"""Use the UMA model which is available in fairchem-core-2.2.0+."""
2419
cpu = device.type == "cpu"
25-
return FairChemModel(
26-
model=None, model_name="uma-s-1", task_name="omat", cpu=cpu, seed=0
27-
)
20+
return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
2821

2922

3023
@pytest.fixture
3124
def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel:
3225
"""Use the UMA model for non-PBC systems."""
3326
cpu = device.type == "cpu"
34-
return FairChemModel(
35-
model=None, model_name="uma-s-1", task_name="omat", cpu=cpu, seed=0
36-
)
27+
return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
3728

3829

39-
@pytest.fixture
40-
def fairchem_calculator() -> FAIRChemCalculator:
41-
"""FAIRChemCalculator using the UMA model."""
42-
return FAIRChemCalculator.from_model_checkpoint(
43-
name_or_path="uma-s-1",
44-
task_name="omat",
45-
device="cpu",
46-
seed=0,
47-
)
48-
49-
50-
test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test(
51-
test_name="fairchem_uma",
52-
model_fixture_name="eqv2_uma_model_pbc",
53-
calculator_fixture_name="fairchem_calculator",
54-
sim_state_names=consistency_test_simstate_fixtures[:-1],
55-
energy_rtol=5e-4, # NOTE: UMA model tolerances
56-
energy_atol=5e-4,
57-
force_rtol=5e-4,
58-
force_atol=5e-4,
59-
stress_rtol=5e-4,
60-
stress_atol=5e-4,
61-
)
62-
63-
test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test(
64-
test_name="fairchem_non_pbc_benzene",
65-
model_fixture_name="eqv2_uma_model_non_pbc",
66-
calculator_fixture_name="fairchem_calculator",
67-
sim_state_names=["benzene_sim_state"],
68-
energy_rtol=5e-4, # NOTE: UMA model tolerances
69-
energy_atol=5e-4,
70-
force_rtol=5e-4,
71-
force_atol=5e-4,
72-
stress_rtol=5e-4,
73-
stress_atol=5e-4,
74-
)
30+
# Removed calculator consistency tests since we're using predictor interface only
7531

7632

7733
test_fairchem_uma_model_outputs = pytest.mark.skipif(

torch_sim/models/fairchem.py

Lines changed: 45 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,10 @@
3030

3131

3232
try:
33-
from fairchem.core.calculate.ase_calculator import (
34-
FAIRChemCalculator,
35-
InferenceSettings,
36-
UMATask,
37-
)
33+
from fairchem.core import pretrained_mlip
34+
from fairchem.core.calculate.ase_calculator import UMATask
3835
from fairchem.core.common.utils import setup_imports, setup_logging
36+
from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch
3937

4038
except ImportError as exc:
4139
warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2)
@@ -71,10 +69,11 @@ class FairChemModel(torch.nn.Module, ModelInterface):
7169
checkpoint. It supports various model architectures and configurations supported by
7270
FairChem.
7371
74-
This version uses the modern fairchem-core-2.2.0+ API with FAIRChemCalculator.
72+
This version uses the efficient fairchem-core-2.2.0+ predictor API.
7573
7674
Attributes:
77-
calculator (FAIRChemCalculator): The underlying FairChem calculator
75+
predictor: The FairChem predictor for batch inference
76+
task_name (UMATask): Task type for the model
7877
_device (torch.device): Device where computation is performed
7978
_dtype (torch.dtype): Data type used for computation
8079
_compute_stress (bool): Whether to compute stress tensor
@@ -92,39 +91,32 @@ def __init__(
9291
*, # force remaining arguments to be keyword-only
9392
model_name: str | None = None,
9493
cpu: bool = False,
95-
seed: int = 41,
9694
dtype: torch.dtype | None = None,
9795
compute_stress: bool = False,
9896
task_name: UMATask | str | None = None,
99-
inference_settings: InferenceSettings | str = "default",
100-
overrides: dict | None = None,
10197
) -> None:
10298
"""Initialize the FairChemModel with specified configuration.
10399
104-
Uses the modern FAIRChemCalculator.from_model_checkpoint API for simplified
105-
model loading and configuration.
100+
Uses the efficient FairChem predictor interface for optimal performance.
106101
107102
Args:
108103
model (str | Path | None): Path to model checkpoint file
109104
neighbor_list_fn (Callable | None): Function to compute neighbor lists
110105
(not currently supported)
111106
model_name (str | None): Name of pretrained model to load
112107
cpu (bool): Whether to use CPU instead of GPU for computation
113-
seed (int): Random seed for reproducibility
114108
dtype (torch.dtype | None): Data type to use for computation
115109
compute_stress (bool): Whether to compute stress tensor
116110
task_name (UMATask | str | None): Task type for the model
117-
inference_settings (InferenceSettings | str): Inference configuration
118-
overrides (dict | None): Configuration overrides
119111
120112
Raises:
121113
RuntimeError: If both model_name and model are specified
122114
NotImplementedError: If custom neighbor list function is provided
123115
ValueError: If neither model nor model_name is provided
124116
125117
Notes:
126-
This uses the new fairchem-core-2.2.0+ API which is much simpler than
127-
the previous versions.
118+
This uses the efficient fairchem-core-2.2.0+ predictor API for
119+
optimal batch inference performance.
128120
"""
129121
setup_imports()
130122
setup_logging()
@@ -146,8 +138,6 @@ def __init__(
146138
"model_name and checkpoint_path were both specified, "
147139
"please use only one at a time"
148140
)
149-
# For fairchem-core-2.2.0+, model_name can be used directly
150-
# as it supports pretrained model names from available_models
151141
model = model_name
152142

153143
if model is None:
@@ -157,21 +147,15 @@ def __init__(
157147
if isinstance(task_name, str):
158148
task_name = UMATask(task_name)
159149

160-
# Use the new simplified API
150+
# Use the efficient predictor API for optimal performance
161151
device_str = "cpu" if cpu else "cuda" if torch.cuda.is_available() else "cpu"
162-
163-
self.calculator = FAIRChemCalculator.from_model_checkpoint(
164-
name_or_path=str(model),
165-
task_name=task_name,
166-
inference_settings=inference_settings,
167-
overrides=overrides,
168-
device=device_str,
169-
seed=seed,
170-
)
171-
172152
self._device = torch.device(device_str)
153+
self.task_name = task_name
154+
155+
# Create efficient batch predictor for fast inference
156+
self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str)
173157

174-
# Determine implemented properties from the calculator
158+
# Determine implemented properties
175159
# This is a simplified approach - in practice you might want to
176160
# inspect the model configuration more carefully
177161
self.implemented_properties = ["energy", "forces"]
@@ -191,8 +175,8 @@ def device(self) -> torch.device:
191175
def forward(self, state: ts.SimState | StateDict) -> dict:
192176
"""Perform forward pass to compute energies, forces, and other properties.
193177
194-
Takes a simulation state and computes the properties implemented by the model,
195-
such as energy, forces, and stresses.
178+
Uses efficient batch inference with FairChem's native tensor interface for
179+
optimal performance on both single systems and large batches.
196180
197181
Args:
198182
state (SimState | StateDict): State object containing positions, cells,
@@ -206,27 +190,28 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
206190
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3]
207191
208192
Notes:
209-
This implementation uses the FAIRChemCalculator which expects ASE Atoms
210-
objects. The conversion is handled internally.
193+
This implementation uses FairChem's efficient batch predictor interface
194+
for optimal performance on both single systems and large batches.
211195
"""
212196
if isinstance(state, dict):
213197
state = ts.SimState(**state, masses=torch.ones_like(state["positions"]))
214198

215199
if state.device != self._device:
216200
state = state.to(self._device)
217201

218-
# Convert torch_sim SimState to ASE Atoms objects for FAIRChemCalculator
219-
from ase import Atoms
220-
221202
if state.batch is None:
222203
state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)
223204

205+
# Convert SimState to AtomicData objects for efficient batch processing
206+
from ase import Atoms
207+
224208
natoms = torch.bincount(state.batch)
225-
atoms_list = []
209+
atomic_data_list = []
226210

227211
for i, (n, c) in enumerate(
228212
zip(natoms, torch.cumsum(natoms, dim=0), strict=False)
229213
):
214+
# Extract system data
230215
positions = state.positions[c - n : c].cpu().numpy()
231216
atomic_numbers = state.atomic_numbers[c - n : c].cpu().numpy()
232217
cell = (
@@ -235,51 +220,36 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
235220
else None
236221
)
237222

223+
# Create ASE Atoms object first
238224
atoms = Atoms(
239225
numbers=atomic_numbers,
240226
positions=positions,
241227
cell=cell,
242228
pbc=state.pbc if cell is not None else False,
243229
)
244-
atoms_list.append(atoms)
245230

246-
# Use FAIRChemCalculator to compute properties
247-
results = {}
248-
energies = []
249-
forces_list = []
250-
stress_list = []
231+
# Convert ASE Atoms to AtomicData with task_name
232+
atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name)
233+
atomic_data_list.append(atomic_data)
251234

252-
for atoms in atoms_list:
253-
atoms.calc = self.calculator
235+
# Create batch for efficient inference
236+
batch = atomicdata_list_to_batch(atomic_data_list)
237+
batch = batch.to(self._device)
254238

255-
# Get energy
256-
energy = atoms.get_potential_energy()
257-
energies.append(energy)
239+
# Run efficient batch prediction
240+
predictions = self.predictor.predict(batch)
258241

259-
# Get forces
260-
forces = atoms.get_forces()
261-
forces_list.append(
262-
torch.from_numpy(forces).to(self._device, dtype=self._dtype)
263-
)
264-
265-
# Get stress if requested
266-
if self._compute_stress:
267-
try:
268-
stress = atoms.get_stress(voigt=False) # 3x3 tensor
269-
stress_list.append(
270-
torch.from_numpy(stress).to(self._device, dtype=self._dtype)
271-
)
272-
except (RuntimeError, AttributeError, NotImplementedError):
273-
# If stress computation fails, fill with zeros
274-
stress_list.append(
275-
torch.zeros(3, 3, device=self._device, dtype=self._dtype)
276-
)
277-
278-
# Combine results
279-
results["energy"] = torch.tensor(energies, device=self._device, dtype=self._dtype)
280-
results["forces"] = torch.cat(forces_list, dim=0)
281-
282-
if self._compute_stress and stress_list:
283-
results["stress"] = torch.stack(stress_list, dim=0)
242+
# Convert predictions to torch_sim format
243+
results = {}
244+
results["energy"] = predictions["energy"].to(dtype=self._dtype)
245+
results["forces"] = predictions["forces"].to(dtype=self._dtype)
246+
247+
# Handle stress if requested and available
248+
if self._compute_stress and "stress" in predictions:
249+
stress = predictions["stress"].to(dtype=self._dtype)
250+
# Ensure stress has correct shape [batch_size, 3, 3]
251+
if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list):
252+
stress = stress.view(-1, 3, 3)
253+
results["stress"] = stress
284254

285255
return results

0 commit comments

Comments
 (0)