|
5 | 5 |
|
6 | 6 |
|
7 | 7 | try:
|
| 8 | + from collections.abc import Callable |
| 9 | + |
| 10 | + from ase.build import bulk, fcc100, molecule |
8 | 11 | from huggingface_hub.utils._auth import get_token
|
9 | 12 |
|
| 13 | + import torch_sim as ts |
10 | 14 | from torch_sim.models.fairchem import FairChemModel
|
11 | 15 |
|
12 | 16 | except ImportError:
|
|
15 | 19 |
|
16 | 20 | @pytest.fixture
|
17 | 21 | def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel:
|
18 |
| - """Use the UMA model which is available in fairchem-core-2.2.0+.""" |
| 22 | + """UMA model for periodic boundary condition systems.""" |
19 | 23 | cpu = device.type == "cpu"
|
20 | 24 | return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
|
21 | 25 |
|
22 | 26 |
|
23 |
| -@pytest.fixture |
24 |
| -def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel: |
25 |
| - """Use the UMA model for non-PBC systems.""" |
| 27 | +# Removed calculator consistency tests since we're using predictor interface only |
| 28 | + |
| 29 | + |
| 30 | +@pytest.mark.skipif( |
| 31 | + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" |
| 32 | +) |
| 33 | +@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"]) |
| 34 | +def test_task_initialization(task_name: str) -> None: |
| 35 | + """Test that different UMA task names work correctly.""" |
| 36 | + model = FairChemModel(model=None, model_name="uma-s-1", task_name=task_name, cpu=True) |
| 37 | + assert model.task_name.value == task_name |
| 38 | + assert hasattr(model, "predictor") |
| 39 | + |
| 40 | + |
| 41 | +@pytest.mark.skipif( |
| 42 | + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" |
| 43 | +) |
| 44 | +@pytest.mark.parametrize( |
| 45 | + ("task_name", "systems_func"), |
| 46 | + [ |
| 47 | + ( |
| 48 | + "omat", |
| 49 | + lambda: [ |
| 50 | + bulk("Si", "diamond", a=5.43), |
| 51 | + bulk("Al", "fcc", a=4.05), |
| 52 | + bulk("Fe", "bcc", a=2.87), |
| 53 | + bulk("Cu", "fcc", a=3.61), |
| 54 | + ], |
| 55 | + ), |
| 56 | + ( |
| 57 | + "omol", |
| 58 | + lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")], |
| 59 | + ), |
| 60 | + ], |
| 61 | +) |
| 62 | +def test_homogeneous_batching( |
| 63 | + task_name: str, systems_func: Callable, device: torch.device, dtype: torch.dtype |
| 64 | +) -> None: |
| 65 | + """Test batching multiple systems with the same task.""" |
| 66 | + systems = systems_func() |
| 67 | + |
| 68 | + # Add molecular properties for molecules |
| 69 | + if task_name == "omol": |
| 70 | + for mol in systems: |
| 71 | + mol.info.update({"charge": 0, "spin": 1}) |
| 72 | + |
| 73 | + model = FairChemModel( |
| 74 | + model=None, model_name="uma-s-1", task_name=task_name, cpu=device.type == "cpu" |
| 75 | + ) |
| 76 | + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) |
| 77 | + results = model(state) |
| 78 | + |
| 79 | + # Check batch dimensions |
| 80 | + assert results["energy"].shape == (4,) |
| 81 | + assert results["forces"].shape[0] == sum(len(s) for s in systems) |
| 82 | + assert results["forces"].shape[1] == 3 |
| 83 | + |
| 84 | + # Check that different systems have different energies |
| 85 | + energies = results["energy"] |
| 86 | + unique_energies = torch.unique(energies, dim=0) |
| 87 | + assert len(unique_energies) > 1, "Different systems should have different energies" |
| 88 | + |
| 89 | + |
| 90 | +@pytest.mark.skipif( |
| 91 | + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" |
| 92 | +) |
| 93 | +def test_heterogeneous_tasks(device: torch.device, dtype: torch.dtype) -> None: |
| 94 | + """Test different task types work with appropriate systems.""" |
| 95 | + # Test molecule, material, and catalysis systems separately |
| 96 | + test_cases = [ |
| 97 | + ("omol", [molecule("H2O")]), |
| 98 | + ("omat", [bulk("Pt", cubic=True)]), |
| 99 | + ("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]), |
| 100 | + ] |
| 101 | + |
| 102 | + for task_name, systems in test_cases: |
| 103 | + if task_name == "omol": |
| 104 | + systems[0].info.update({"charge": 0, "spin": 1}) |
| 105 | + |
| 106 | + model = FairChemModel( |
| 107 | + model=None, |
| 108 | + model_name="uma-s-1", |
| 109 | + task_name=task_name, |
| 110 | + cpu=device.type == "cpu", |
| 111 | + ) |
| 112 | + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) |
| 113 | + results = model(state) |
| 114 | + |
| 115 | + assert "energy" in results |
| 116 | + assert "forces" in results |
| 117 | + assert results["energy"].shape[0] == 1 |
| 118 | + assert results["forces"].dim() == 2 |
| 119 | + assert results["forces"].shape[1] == 3 |
| 120 | + |
| 121 | + |
| 122 | +@pytest.mark.skipif( |
| 123 | + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" |
| 124 | +) |
| 125 | +@pytest.mark.parametrize( |
| 126 | + ("systems_func", "expected_count"), |
| 127 | + [ |
| 128 | + (lambda: [bulk("Si", "diamond", a=5.43)], 1), # Single system |
| 129 | + ( |
| 130 | + lambda: [ |
| 131 | + bulk("H", "bcc", a=2.0), |
| 132 | + bulk("Li", "bcc", a=3.0), |
| 133 | + bulk("Si", "diamond", a=5.43), |
| 134 | + bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)), |
| 135 | + ], |
| 136 | + 4, |
| 137 | + ), # Mixed sizes |
| 138 | + ( |
| 139 | + lambda: [ |
| 140 | + bulk(element, "fcc", a=4.0) |
| 141 | + for element in ["Al", "Cu", "Ni", "Pd", "Pt"] * 3 |
| 142 | + ], |
| 143 | + 15, |
| 144 | + ), # Large batch |
| 145 | + ], |
| 146 | +) |
| 147 | +def test_batch_size_variations( |
| 148 | + systems_func: Callable, expected_count: int, device: torch.device, dtype: torch.dtype |
| 149 | +) -> None: |
| 150 | + """Test batching with different numbers and sizes of systems.""" |
| 151 | + systems = systems_func() |
| 152 | + |
| 153 | + model = FairChemModel( |
| 154 | + model=None, model_name="uma-s-1", task_name="omat", cpu=device.type == "cpu" |
| 155 | + ) |
| 156 | + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) |
| 157 | + results = model(state) |
| 158 | + |
| 159 | + assert results["energy"].shape == (expected_count,) |
| 160 | + assert results["forces"].shape[0] == sum(len(s) for s in systems) |
| 161 | + assert results["forces"].shape[1] == 3 |
| 162 | + assert torch.isfinite(results["energy"]).all() |
| 163 | + assert torch.isfinite(results["forces"]).all() |
| 164 | + |
| 165 | + |
| 166 | +@pytest.mark.skipif( |
| 167 | + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" |
| 168 | +) |
| 169 | +@pytest.mark.parametrize("compute_stress", [True, False]) |
| 170 | +def test_stress_computation( |
| 171 | + *, compute_stress: bool, device: torch.device, dtype: torch.dtype |
| 172 | +) -> None: |
| 173 | + """Test stress tensor computation.""" |
| 174 | + systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)] |
| 175 | + |
| 176 | + model = FairChemModel( |
| 177 | + model=None, |
| 178 | + model_name="uma-s-1", |
| 179 | + task_name="omat", |
| 180 | + cpu=device.type == "cpu", |
| 181 | + compute_stress=compute_stress, |
| 182 | + ) |
| 183 | + state = ts.io.atoms_to_state(systems, device=device, dtype=dtype) |
| 184 | + results = model(state) |
| 185 | + |
| 186 | + if compute_stress: |
| 187 | + assert "stress" in results |
| 188 | + assert results["stress"].shape == (2, 3, 3) |
| 189 | + assert torch.isfinite(results["stress"]).all() |
| 190 | + else: |
| 191 | + assert "stress" not in results |
| 192 | + |
| 193 | + |
| 194 | +@pytest.mark.skipif( |
| 195 | + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" |
| 196 | +) |
| 197 | +def test_device_consistency(dtype: torch.dtype) -> None: |
| 198 | + """Test device consistency between model and data.""" |
| 199 | + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
26 | 200 | cpu = device.type == "cpu"
|
27 |
| - return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) |
28 | 201 |
|
| 202 | + model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) |
| 203 | + system = bulk("Si", "diamond", a=5.43) |
| 204 | + state = ts.io.atoms_to_state([system], device=device, dtype=dtype) |
29 | 205 |
|
30 |
| -# Removed calculator consistency tests since we're using predictor interface only |
| 206 | + results = model(state) |
| 207 | + assert results["energy"].device == device |
| 208 | + assert results["forces"].device == device |
| 209 | + |
| 210 | + |
| 211 | +@pytest.mark.skipif( |
| 212 | + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" |
| 213 | +) |
| 214 | +def test_empty_batch_error() -> None: |
| 215 | + """Test that empty batches raise appropriate errors.""" |
| 216 | + model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=True) |
| 217 | + with pytest.raises((ValueError, RuntimeError, IndexError)): |
| 218 | + model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32)) |
31 | 219 |
|
32 | 220 |
|
33 | 221 | test_fairchem_uma_model_outputs = pytest.mark.skipif(
|
|
0 commit comments