1
- """Wrapper for metatensor -based models in TorchSim.
1
+ """Wrapper for metatomic -based models in TorchSim.
2
2
3
- This module provides a TorchSim wrapper of metatensor models for computing
3
+ This module provides a TorchSim wrapper of metatomic models for computing
4
4
energies, forces, and stresses for atomistic systems, including batched computations
5
5
for multiple systems simultaneously.
6
6
7
- The MetatensorModel class adapts metatensor models to the ModelInterface protocol,
7
+ The MetatomicModel class adapts metatomic models to the ModelInterface protocol,
8
8
allowing them to be used within the broader torch_sim simulation framework.
9
9
10
10
Notes:
11
- This module depends on the metatensor -torch package.
11
+ This module depends on the metatomic -torch package.
12
12
"""
13
13
14
14
import traceback
17
17
from typing import Any
18
18
19
19
import torch
20
- import vesin .torch . metatensor
20
+ import vesin .metatomic
21
21
22
22
import torch_sim as ts
23
23
from torch_sim .models .interface import ModelInterface
24
24
from torch_sim .typing import StateDict
25
25
26
26
27
27
try :
28
- from metatensor .torch . atomistic import (
28
+ from metatomic .torch import (
29
29
ModelEvaluationOptions ,
30
30
ModelOutput ,
31
31
System ,
34
34
from metatrain .utils .io import load_model
35
35
36
36
except ImportError as exc :
37
- warnings .warn (f"Metatensor import failed: { traceback .format_exc ()} " , stacklevel = 2 )
37
+ warnings .warn (f"Metatomic import failed: { traceback .format_exc ()} " , stacklevel = 2 )
38
38
39
- class MetatensorModel (torch .nn .Module , ModelInterface ):
40
- """Metatensor model wrapper for torch_sim.
39
+ class MetatomicModel (torch .nn .Module , ModelInterface ):
40
+ """Metatomic model wrapper for torch_sim.
41
41
42
- This class is a placeholder for the MetatensorModel class.
43
- It raises an ImportError if metatensor is not installed.
42
+ This class is a placeholder for the MetatomicModel class.
43
+ It raises an ImportError if metatomic is not installed.
44
44
"""
45
45
46
46
def __init__ (self , err : ImportError = exc , * _args : Any , ** _kwargs : Any ) -> None :
47
47
"""Dummy init for type checking."""
48
48
raise err
49
49
50
50
51
- class MetatensorModel (torch .nn .Module , ModelInterface ):
52
- """Computes energies for a list of systems using a metatensor model.
51
+ class MetatomicModel (torch .nn .Module , ModelInterface ):
52
+ """Computes energies for a list of systems using a metatomic model.
53
53
54
- This class wraps a metatensor model to compute energies, forces, and stresses for
54
+ This class wraps a metatomic model to compute energies, forces, and stresses for
55
55
atomic systems within the TorchSim framework. It supports batched calculations
56
56
for multiple systems and handles the necessary transformations between
57
- TorchSim's data structures and metatensor 's expected inputs.
57
+ TorchSim's data structures and metatomic 's expected inputs.
58
58
59
59
Attributes:
60
60
...
@@ -70,14 +70,14 @@ def __init__(
70
70
compute_forces : bool = True ,
71
71
compute_stress : bool = True ,
72
72
) -> None :
73
- """Initialize the metatensor model for energy, force and stress calculations.
73
+ """Initialize the metatomic model for energy, force and stress calculations.
74
74
75
- Sets up a metatensor model for energy, force, and stress calculations within
75
+ Sets up a metatomic model for energy, force, and stress calculations within
76
76
the TorchSim framework. The model can be initialized with atomic numbers
77
77
and batch indices, or these can be provided during the forward pass.
78
78
79
79
Args:
80
- model (str | Path | None): Path to the metatensor model file or a
80
+ model (str | Path | None): Path to the metatomic model file or a
81
81
pre-defined model name. Currently only "pet-mad"
82
82
(https://arxiv.org/abs/2503.14118) is supported as a pre-defined model.
83
83
If None, defaults to "pet-mad".
@@ -155,7 +155,7 @@ def forward( # noqa: C901, PLR0915
155
155
"""Compute energies, forces, and stresses for the given atomic systems.
156
156
157
157
Processes the provided state information and computes energies, forces, and
158
- stresses using the underlying metatensor model. Handles batched calculations for
158
+ stresses using the underlying metatomic model. Handles batched calculations for
159
159
multiple systems as well as constructing the necessary neighbor lists.
160
160
161
161
Args:
@@ -175,21 +175,21 @@ def forward( # noqa: C901, PLR0915
175
175
state = ts .SimState (** state , masses = torch .ones_like (state ["positions" ]))
176
176
177
177
# Input validation is already done inside the forward method of the
178
- # MetatensorAtomisticModel class, so we don't need to do it again here.
178
+ # AtomisticModel class, so we don't need to do it again here.
179
179
180
180
atomic_numbers = state .atomic_numbers
181
181
cell = state .row_vector_cell
182
182
positions = state .positions
183
183
pbc = state .pbc
184
184
185
- # Check dtype (metatensor models require a specific input dtype)
185
+ # Check dtype (metatomic models require a specific input dtype)
186
186
if positions .dtype != self ._dtype :
187
187
raise TypeError (
188
188
f"Positions dtype { positions .dtype } does not match model dtype "
189
189
f"{ self ._dtype } "
190
190
)
191
191
192
- # Compared to other models, metatensor models have two peculiarities:
192
+ # Compared to other models, metatomic models have two peculiarities:
193
193
# - different structures are fed to the models separately as a list of System
194
194
# objects, and not as a single graph-like batch
195
195
# - the model does not compute forces and stresses itself, but rather the
@@ -232,7 +232,7 @@ def forward( # noqa: C901, PLR0915
232
232
233
233
# move data to CPU because vesin only supports CPU for now
234
234
systems = [system .to (device = "cpu" ) for system in systems ]
235
- vesin .torch . metatensor .compute_requested_neighbors (
235
+ vesin .metatomic .compute_requested_neighbors (
236
236
systems , system_length_unit = "Angstrom" , model = self ._model
237
237
)
238
238
# move back to the proper device
0 commit comments