30
30
31
31
32
32
try :
33
- from fairchem .core .calculate .ase_calculator import (
34
- FAIRChemCalculator ,
35
- InferenceSettings ,
36
- UMATask ,
37
- )
33
+ from fairchem .core import pretrained_mlip
34
+ from fairchem .core .calculate .ase_calculator import UMATask
38
35
from fairchem .core .common .utils import setup_imports , setup_logging
36
+ from fairchem .core .datasets .atomic_data import AtomicData , atomicdata_list_to_batch
39
37
40
38
except ImportError as exc :
41
39
warnings .warn (f"FairChem import failed: { traceback .format_exc ()} " , stacklevel = 2 )
@@ -71,10 +69,11 @@ class FairChemModel(torch.nn.Module, ModelInterface):
71
69
checkpoint. It supports various model architectures and configurations supported by
72
70
FairChem.
73
71
74
- This version uses the modern fairchem-core-2.2.0+ API with FAIRChemCalculator .
72
+ This version uses the efficient fairchem-core-2.2.0+ predictor API .
75
73
76
74
Attributes:
77
- calculator (FAIRChemCalculator): The underlying FairChem calculator
75
+ predictor: The FairChem predictor for batch inference
76
+ task_name (UMATask): Task type for the model
78
77
_device (torch.device): Device where computation is performed
79
78
_dtype (torch.dtype): Data type used for computation
80
79
_compute_stress (bool): Whether to compute stress tensor
@@ -92,39 +91,32 @@ def __init__(
92
91
* , # force remaining arguments to be keyword-only
93
92
model_name : str | None = None ,
94
93
cpu : bool = False ,
95
- seed : int = 41 ,
96
94
dtype : torch .dtype | None = None ,
97
95
compute_stress : bool = False ,
98
96
task_name : UMATask | str | None = None ,
99
- inference_settings : InferenceSettings | str = "default" ,
100
- overrides : dict | None = None ,
101
97
) -> None :
102
98
"""Initialize the FairChemModel with specified configuration.
103
99
104
- Uses the modern FAIRChemCalculator.from_model_checkpoint API for simplified
105
- model loading and configuration.
100
+ Uses the efficient FairChem predictor interface for optimal performance.
106
101
107
102
Args:
108
103
model (str | Path | None): Path to model checkpoint file
109
104
neighbor_list_fn (Callable | None): Function to compute neighbor lists
110
105
(not currently supported)
111
106
model_name (str | None): Name of pretrained model to load
112
107
cpu (bool): Whether to use CPU instead of GPU for computation
113
- seed (int): Random seed for reproducibility
114
108
dtype (torch.dtype | None): Data type to use for computation
115
109
compute_stress (bool): Whether to compute stress tensor
116
110
task_name (UMATask | str | None): Task type for the model
117
- inference_settings (InferenceSettings | str): Inference configuration
118
- overrides (dict | None): Configuration overrides
119
111
120
112
Raises:
121
113
RuntimeError: If both model_name and model are specified
122
114
NotImplementedError: If custom neighbor list function is provided
123
115
ValueError: If neither model nor model_name is provided
124
116
125
117
Notes:
126
- This uses the new fairchem-core-2.2.0+ API which is much simpler than
127
- the previous versions .
118
+ This uses the efficient fairchem-core-2.2.0+ predictor API for
119
+ optimal batch inference performance .
128
120
"""
129
121
setup_imports ()
130
122
setup_logging ()
@@ -146,8 +138,6 @@ def __init__(
146
138
"model_name and checkpoint_path were both specified, "
147
139
"please use only one at a time"
148
140
)
149
- # For fairchem-core-2.2.0+, model_name can be used directly
150
- # as it supports pretrained model names from available_models
151
141
model = model_name
152
142
153
143
if model is None :
@@ -157,21 +147,15 @@ def __init__(
157
147
if isinstance (task_name , str ):
158
148
task_name = UMATask (task_name )
159
149
160
- # Use the new simplified API
150
+ # Use the efficient predictor API for optimal performance
161
151
device_str = "cpu" if cpu else "cuda" if torch .cuda .is_available () else "cpu"
162
-
163
- self .calculator = FAIRChemCalculator .from_model_checkpoint (
164
- name_or_path = str (model ),
165
- task_name = task_name ,
166
- inference_settings = inference_settings ,
167
- overrides = overrides ,
168
- device = device_str ,
169
- seed = seed ,
170
- )
171
-
172
152
self ._device = torch .device (device_str )
153
+ self .task_name = task_name
154
+
155
+ # Create efficient batch predictor for fast inference
156
+ self .predictor = pretrained_mlip .get_predict_unit (str (model ), device = device_str )
173
157
174
- # Determine implemented properties from the calculator
158
+ # Determine implemented properties
175
159
# This is a simplified approach - in practice you might want to
176
160
# inspect the model configuration more carefully
177
161
self .implemented_properties = ["energy" , "forces" ]
@@ -191,8 +175,8 @@ def device(self) -> torch.device:
191
175
def forward (self , state : ts .SimState | StateDict ) -> dict :
192
176
"""Perform forward pass to compute energies, forces, and other properties.
193
177
194
- Takes a simulation state and computes the properties implemented by the model,
195
- such as energy, forces, and stresses .
178
+ Uses efficient batch inference with FairChem's native tensor interface for
179
+ optimal performance on both single systems and large batches .
196
180
197
181
Args:
198
182
state (SimState | StateDict): State object containing positions, cells,
@@ -206,27 +190,28 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
206
190
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3]
207
191
208
192
Notes:
209
- This implementation uses the FAIRChemCalculator which expects ASE Atoms
210
- objects. The conversion is handled internally .
193
+ This implementation uses FairChem's efficient batch predictor interface
194
+ for optimal performance on both single systems and large batches .
211
195
"""
212
196
if isinstance (state , dict ):
213
197
state = ts .SimState (** state , masses = torch .ones_like (state ["positions" ]))
214
198
215
199
if state .device != self ._device :
216
200
state = state .to (self ._device )
217
201
218
- # Convert torch_sim SimState to ASE Atoms objects for FAIRChemCalculator
219
- from ase import Atoms
220
-
221
202
if state .batch is None :
222
203
state .batch = torch .zeros (state .positions .shape [0 ], dtype = torch .int )
223
204
205
+ # Convert SimState to AtomicData objects for efficient batch processing
206
+ from ase import Atoms
207
+
224
208
natoms = torch .bincount (state .batch )
225
- atoms_list = []
209
+ atomic_data_list = []
226
210
227
211
for i , (n , c ) in enumerate (
228
212
zip (natoms , torch .cumsum (natoms , dim = 0 ), strict = False )
229
213
):
214
+ # Extract system data
230
215
positions = state .positions [c - n : c ].cpu ().numpy ()
231
216
atomic_numbers = state .atomic_numbers [c - n : c ].cpu ().numpy ()
232
217
cell = (
@@ -235,51 +220,36 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
235
220
else None
236
221
)
237
222
223
+ # Create ASE Atoms object first
238
224
atoms = Atoms (
239
225
numbers = atomic_numbers ,
240
226
positions = positions ,
241
227
cell = cell ,
242
228
pbc = state .pbc if cell is not None else False ,
243
229
)
244
- atoms_list .append (atoms )
245
230
246
- # Use FAIRChemCalculator to compute properties
247
- results = {}
248
- energies = []
249
- forces_list = []
250
- stress_list = []
231
+ # Convert ASE Atoms to AtomicData with task_name
232
+ atomic_data = AtomicData .from_ase (atoms , task_name = self .task_name )
233
+ atomic_data_list .append (atomic_data )
251
234
252
- for atoms in atoms_list :
253
- atoms .calc = self .calculator
235
+ # Create batch for efficient inference
236
+ batch = atomicdata_list_to_batch (atomic_data_list )
237
+ batch = batch .to (self ._device )
254
238
255
- # Get energy
256
- energy = atoms .get_potential_energy ()
257
- energies .append (energy )
239
+ # Run efficient batch prediction
240
+ predictions = self .predictor .predict (batch )
258
241
259
- # Get forces
260
- forces = atoms .get_forces ()
261
- forces_list .append (
262
- torch .from_numpy (forces ).to (self ._device , dtype = self ._dtype )
263
- )
264
-
265
- # Get stress if requested
266
- if self ._compute_stress :
267
- try :
268
- stress = atoms .get_stress (voigt = False ) # 3x3 tensor
269
- stress_list .append (
270
- torch .from_numpy (stress ).to (self ._device , dtype = self ._dtype )
271
- )
272
- except (RuntimeError , AttributeError , NotImplementedError ):
273
- # If stress computation fails, fill with zeros
274
- stress_list .append (
275
- torch .zeros (3 , 3 , device = self ._device , dtype = self ._dtype )
276
- )
277
-
278
- # Combine results
279
- results ["energy" ] = torch .tensor (energies , device = self ._device , dtype = self ._dtype )
280
- results ["forces" ] = torch .cat (forces_list , dim = 0 )
281
-
282
- if self ._compute_stress and stress_list :
283
- results ["stress" ] = torch .stack (stress_list , dim = 0 )
242
+ # Convert predictions to torch_sim format
243
+ results = {}
244
+ results ["energy" ] = predictions ["energy" ].to (dtype = self ._dtype )
245
+ results ["forces" ] = predictions ["forces" ].to (dtype = self ._dtype )
246
+
247
+ # Handle stress if requested and available
248
+ if self ._compute_stress and "stress" in predictions :
249
+ stress = predictions ["stress" ].to (dtype = self ._dtype )
250
+ # Ensure stress has correct shape [batch_size, 3, 3]
251
+ if stress .dim () == 2 and stress .shape [0 ] == len (atomic_data_list ):
252
+ stress = stress .view (- 1 , 3 , 3 )
253
+ results ["stress" ] = stress
284
254
285
255
return results
0 commit comments