Skip to content

Commit 8991c89

Browse files
committed
more FairChem tests with UMA model and batch processing
- parameterized tests for different UMA task names and system configs - tests for homogeneous and heterogeneous batching, ensuring correct energy and force outputs - stress tensor computation tests with conditional checks - test error handling for empty batches
1 parent 0db53df commit 8991c89

File tree

1 file changed

+194
-6
lines changed

1 file changed

+194
-6
lines changed

tests/models/test_fairchem.py

Lines changed: 194 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55

66

77
try:
8+
from collections.abc import Callable
9+
10+
from ase.build import bulk, fcc100, molecule
811
from huggingface_hub.utils._auth import get_token
912

13+
import torch_sim as ts
1014
from torch_sim.models.fairchem import FairChemModel
1115

1216
except ImportError:
@@ -15,19 +19,203 @@
1519

1620
@pytest.fixture
1721
def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel:
18-
"""Use the UMA model which is available in fairchem-core-2.2.0+."""
22+
"""UMA model for periodic boundary condition systems."""
1923
cpu = device.type == "cpu"
2024
return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
2125

2226

23-
@pytest.fixture
24-
def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel:
25-
"""Use the UMA model for non-PBC systems."""
27+
# Removed calculator consistency tests since we're using predictor interface only
28+
29+
30+
@pytest.mark.skipif(
31+
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
32+
)
33+
@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"])
34+
def test_task_initialization(task_name: str) -> None:
35+
"""Test that different UMA task names work correctly."""
36+
model = FairChemModel(model=None, model_name="uma-s-1", task_name=task_name, cpu=True)
37+
assert model.task_name.value == task_name
38+
assert hasattr(model, "predictor")
39+
40+
41+
@pytest.mark.skipif(
42+
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
43+
)
44+
@pytest.mark.parametrize(
45+
("task_name", "systems_func"),
46+
[
47+
(
48+
"omat",
49+
lambda: [
50+
bulk("Si", "diamond", a=5.43),
51+
bulk("Al", "fcc", a=4.05),
52+
bulk("Fe", "bcc", a=2.87),
53+
bulk("Cu", "fcc", a=3.61),
54+
],
55+
),
56+
(
57+
"omol",
58+
lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")],
59+
),
60+
],
61+
)
62+
def test_homogeneous_batching(
63+
task_name: str, systems_func: Callable, device: torch.device, dtype: torch.dtype
64+
) -> None:
65+
"""Test batching multiple systems with the same task."""
66+
systems = systems_func()
67+
68+
# Add molecular properties for molecules
69+
if task_name == "omol":
70+
for mol in systems:
71+
mol.info.update({"charge": 0, "spin": 1})
72+
73+
model = FairChemModel(
74+
model=None, model_name="uma-s-1", task_name=task_name, cpu=device.type == "cpu"
75+
)
76+
state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
77+
results = model(state)
78+
79+
# Check batch dimensions
80+
assert results["energy"].shape == (4,)
81+
assert results["forces"].shape[0] == sum(len(s) for s in systems)
82+
assert results["forces"].shape[1] == 3
83+
84+
# Check that different systems have different energies
85+
energies = results["energy"]
86+
unique_energies = torch.unique(energies, dim=0)
87+
assert len(unique_energies) > 1, "Different systems should have different energies"
88+
89+
90+
@pytest.mark.skipif(
91+
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
92+
)
93+
def test_heterogeneous_tasks(device: torch.device, dtype: torch.dtype) -> None:
94+
"""Test different task types work with appropriate systems."""
95+
# Test molecule, material, and catalysis systems separately
96+
test_cases = [
97+
("omol", [molecule("H2O")]),
98+
("omat", [bulk("Pt", cubic=True)]),
99+
("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]),
100+
]
101+
102+
for task_name, systems in test_cases:
103+
if task_name == "omol":
104+
systems[0].info.update({"charge": 0, "spin": 1})
105+
106+
model = FairChemModel(
107+
model=None,
108+
model_name="uma-s-1",
109+
task_name=task_name,
110+
cpu=device.type == "cpu",
111+
)
112+
state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
113+
results = model(state)
114+
115+
assert "energy" in results
116+
assert "forces" in results
117+
assert results["energy"].shape[0] == 1
118+
assert results["forces"].dim() == 2
119+
assert results["forces"].shape[1] == 3
120+
121+
122+
@pytest.mark.skipif(
123+
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
124+
)
125+
@pytest.mark.parametrize(
126+
("systems_func", "expected_count"),
127+
[
128+
(lambda: [bulk("Si", "diamond", a=5.43)], 1), # Single system
129+
(
130+
lambda: [
131+
bulk("H", "bcc", a=2.0),
132+
bulk("Li", "bcc", a=3.0),
133+
bulk("Si", "diamond", a=5.43),
134+
bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)),
135+
],
136+
4,
137+
), # Mixed sizes
138+
(
139+
lambda: [
140+
bulk(element, "fcc", a=4.0)
141+
for element in ["Al", "Cu", "Ni", "Pd", "Pt"] * 3
142+
],
143+
15,
144+
), # Large batch
145+
],
146+
)
147+
def test_batch_size_variations(
148+
systems_func: Callable, expected_count: int, device: torch.device, dtype: torch.dtype
149+
) -> None:
150+
"""Test batching with different numbers and sizes of systems."""
151+
systems = systems_func()
152+
153+
model = FairChemModel(
154+
model=None, model_name="uma-s-1", task_name="omat", cpu=device.type == "cpu"
155+
)
156+
state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
157+
results = model(state)
158+
159+
assert results["energy"].shape == (expected_count,)
160+
assert results["forces"].shape[0] == sum(len(s) for s in systems)
161+
assert results["forces"].shape[1] == 3
162+
assert torch.isfinite(results["energy"]).all()
163+
assert torch.isfinite(results["forces"]).all()
164+
165+
166+
@pytest.mark.skipif(
167+
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
168+
)
169+
@pytest.mark.parametrize("compute_stress", [True, False])
170+
def test_stress_computation(
171+
*, compute_stress: bool, device: torch.device, dtype: torch.dtype
172+
) -> None:
173+
"""Test stress tensor computation."""
174+
systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]
175+
176+
model = FairChemModel(
177+
model=None,
178+
model_name="uma-s-1",
179+
task_name="omat",
180+
cpu=device.type == "cpu",
181+
compute_stress=compute_stress,
182+
)
183+
state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)
184+
results = model(state)
185+
186+
if compute_stress:
187+
assert "stress" in results
188+
assert results["stress"].shape == (2, 3, 3)
189+
assert torch.isfinite(results["stress"]).all()
190+
else:
191+
assert "stress" not in results
192+
193+
194+
@pytest.mark.skipif(
195+
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
196+
)
197+
def test_device_consistency(dtype: torch.dtype) -> None:
198+
"""Test device consistency between model and data."""
199+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26200
cpu = device.type == "cpu"
27-
return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
28201

202+
model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
203+
system = bulk("Si", "diamond", a=5.43)
204+
state = ts.io.atoms_to_state([system], device=device, dtype=dtype)
29205

30-
# Removed calculator consistency tests since we're using predictor interface only
206+
results = model(state)
207+
assert results["energy"].device == device
208+
assert results["forces"].device == device
209+
210+
211+
@pytest.mark.skipif(
212+
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
213+
)
214+
def test_empty_batch_error() -> None:
215+
"""Test that empty batches raise appropriate errors."""
216+
model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=True)
217+
with pytest.raises((ValueError, RuntimeError, IndexError)):
218+
model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32))
31219

32220

33221
test_fairchem_uma_model_outputs = pytest.mark.skipif(

0 commit comments

Comments
 (0)