Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5356d0b
fea: use batched vdot
CompRhys May 22, 2025
fa0830a
clean: remove ai slop
CompRhys May 23, 2025
53f6839
clean: further attempts to clean but still not matching PR
CompRhys May 23, 2025
75ca9a5
fix: dr is vdt rather than fdt
CompRhys May 23, 2025
971a4f5
typing: fix typing issue
CompRhys May 23, 2025
cc94fac
wip: still not sure where the difference is now
CompRhys May 23, 2025
4a494e6
update forces per comment
CompRhys May 27, 2025
d1de14a
Fix ASE pos only implementation
t-reents May 28, 2025
beeb4df
Fix torch-sim ASE-FIRE (Frechet Cell)
t-reents May 28, 2025
0ba531c
linting
CompRhys May 28, 2025
183e19d
test: still differ significantly after step 1 for distorted structures
CompRhys May 28, 2025
0aae837
Merge remote-tracking branch 'origin/main' into fix/ase-torch-sim
CompRhys May 30, 2025
25cf6e2
Fix test comparing ASE and torch-sim optimization
t-reents Jun 2, 2025
c0afdd2
Fix `optimizers` when using `UnitCellFilter`
t-reents Jun 2, 2025
e49c178
fix test_optimize_fire
janosh Jun 3, 2025
66871d5
allow FireState.velocities = None since it's being set to None in mul…
janosh Jun 3, 2025
85380ad
safer `batched_vdot`: check dimensionality of input tensors `y` and `…
janosh Jun 3, 2025
2e54afe
generate_force_convergence_fn raise informative error on needed but m…
janosh Jun 3, 2025
4484465
pascal case VALID_FIRE_CELL_STATES->AnyFireCellState and fix non-f-st…
janosh Jun 3, 2025
ff03101
fix FireState TypeError: non-default argument 'dt' follows default ar…
janosh Jun 3, 2025
3458288
allow None but don't set default for state.velocities
janosh Jun 3, 2025
d441bf3
fix bad merge conflict resolution
janosh Jun 3, 2025
8deb1c8
tweaks
janosh Jun 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/link-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,3 @@ jobs:
with:
# ignore ipynb links since they're generated on the fly
args: --exclude-path dist --exclude '\.ipynb$' --accept 100..=103,200..=299,403,429,500 -- ./**/*.{md,py,yml,json}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
6 changes: 4 additions & 2 deletions examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def run_optimization_ts( # noqa: PLR0915
convergence_steps = torch.full(
(total_structures,), -1, dtype=torch.long, device=device
)
convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol)
convergence_fn = ts.generate_force_convergence_fn(
force_tol=force_tol, include_cell_forces=ts_use_frechet
)
converged_tensor_global = torch.zeros(
total_structures, dtype=torch.bool, device=device
)
Expand All @@ -194,7 +196,7 @@ def run_optimization_ts( # noqa: PLR0915
current_indices_list, dtype=torch.long, device=device
)

steps_this_round = 10
steps_this_round = 1
for _ in range(steps_this_round):
opt_state = update_fn_opt(opt_state)
global_step += steps_this_round
Expand Down
223 changes: 142 additions & 81 deletions tests/test_optimizers_vs_ase.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import copy
from typing import TYPE_CHECKING, Any

import pytest
import torch
from ase.filters import FrechetCellFilter, UnitCellFilter
from ase.optimize import FIRE
from pymatgen.analysis.structure_matcher import StructureMatcher

import torch_sim as ts
from torch_sim.io import state_to_atoms
from torch_sim.io import atoms_to_state, state_to_atoms, state_to_structures
from torch_sim.models.mace import MaceModel
from torch_sim.optimizers import frechet_cell_fire, unit_cell_fire

Expand All @@ -16,6 +16,65 @@
from mace.calculators import MACECalculator


def _compare_ase_and_ts_states(
ts_current_system_state: ts.state.SimState,
filtered_ase_atoms_for_run: Any,
tolerances: dict[str, float],
current_test_id: str,
) -> None:
structure_matcher = StructureMatcher(
ltol=tolerances["lattice_tol"],
stol=tolerances["site_tol"],
angle_tol=tolerances["angle_tol"],
scale=False,
)

tensor_kwargs = {
"device": ts_current_system_state.device,
"dtype": ts_current_system_state.dtype,
}

final_custom_energy = ts_current_system_state.energy.item()
final_custom_forces_max = (
torch.norm(ts_current_system_state.forces, dim=-1).max().item()
)

# Convert torch-sim state to pymatgen Structure
ts_structure = state_to_structures(ts_current_system_state)[0]

# Convert ASE atoms to pymatgen Structure
final_ase_atoms = filtered_ase_atoms_for_run.atoms
final_ase_energy = final_ase_atoms.get_potential_energy()
ase_forces_raw = final_ase_atoms.get_forces()
final_ase_forces_max = torch.norm(
torch.tensor(ase_forces_raw, **tensor_kwargs), dim=-1
).max()
ts_state = atoms_to_state(final_ase_atoms, **tensor_kwargs)
ase_structure = state_to_structures(ts_state)[0]

# Compare energies
energy_diff = abs(final_custom_energy - final_ase_energy)
assert energy_diff < tolerances["energy"], (
f"{current_test_id}: Final energies differ significantly: "
f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, "
f"Diff={energy_diff:.2e}"
)

# Compare forces
force_max_diff = abs(final_custom_forces_max - final_ase_forces_max)
assert force_max_diff < tolerances["force_max"], (
f"{current_test_id}: Max forces differ significantly: "
f"torch-sim={final_custom_forces_max:.4f}, ASE={final_ase_forces_max:.4f}, "
f"Diff={force_max_diff:.2e}"
)

# Compare structures using StructureMatcher
assert structure_matcher.fit(ts_structure, ase_structure), (
f"{current_test_id}: Structures do not match according to StructureMatcher\n"
f"{ts_structure=}\n{ase_structure=}"
)


def _run_and_compare_optimizers(
initial_sim_state_fixture: ts.state.SimState,
torchsim_mace_mpa: MaceModel,
Expand All @@ -32,14 +91,7 @@ def _run_and_compare_optimizers(
dtype = torch.float64
device = torchsim_mace_mpa.device

ts_current_system_state = copy.deepcopy(initial_sim_state_fixture).to(
dtype=dtype, device=device
)
ts_current_system_state.positions = (
ts_current_system_state.positions.detach().requires_grad_()
)
ts_current_system_state.cell = ts_current_system_state.cell.detach().requires_grad_()
ts_optimizer_state = None
ts_current_system_state = initial_sim_state_fixture.clone()

optimizer_builders = {
"frechet": frechet_cell_fire,
Expand All @@ -54,89 +106,53 @@ def _run_and_compare_optimizers(
)

ase_atoms_for_run = state_to_atoms(
copy.deepcopy(initial_sim_state_fixture).to(dtype=dtype, device=device)
initial_sim_state_fixture.clone().to(dtype=dtype, device=device)
)[0]
ase_atoms_for_run.calc = ase_mace_mpa
filtered_ase_atoms_for_run = ase_filter_class(ase_atoms_for_run)
ase_optimizer = FIRE(filtered_ase_atoms_for_run, logfile=None)

last_checkpoint_step_count = 0
convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol)
convergence_fn = ts.generate_force_convergence_fn(
force_tol=force_tol, include_cell_forces=True
)

results = torchsim_mace_mpa(ts_current_system_state)
ts_initial_system_state = ts_current_system_state.clone()
ts_initial_system_state.forces = results["forces"]
ts_initial_system_state.energy = results["energy"]
ase_atoms_for_run.calc.calculate(ase_atoms_for_run)

_compare_ase_and_ts_states(
ts_initial_system_state,
filtered_ase_atoms_for_run,
tolerances,
f"{test_id_prefix} (Initial)",
)

for checkpoint_step in checkpoints:
steps_for_current_segment = checkpoint_step - last_checkpoint_step_count

if steps_for_current_segment > 0:
# Ensure requires_grad is set for the input to ts.optimize
# ts.optimize is expected to return a state suitable for further optimization
# if optimizer_state is passed.
ts_current_system_state.positions = (
ts_current_system_state.positions.detach().requires_grad_()
)
ts_current_system_state.cell = (
ts_current_system_state.cell.detach().requires_grad_()
)
new_ts_state_and_optimizer_state = ts.optimize(
updated_ts_state = ts.optimize(
system=ts_current_system_state,
model=torchsim_mace_mpa,
optimizer=optimizer_callable_for_ts_optimize,
max_steps=steps_for_current_segment,
convergence_fn=convergence_fn,
optimizer_state=ts_optimizer_state,
steps_between_swaps=1,
)
ts_current_system_state = new_ts_state_and_optimizer_state
ts_optimizer_state = new_ts_state_and_optimizer_state
ts_current_system_state = updated_ts_state.clone()

ase_optimizer.run(fmax=force_tol, steps=steps_for_current_segment)

current_test_id = f"{test_id_prefix} (Step {checkpoint_step})"

final_custom_energy = ts_current_system_state.energy.item()
final_custom_forces_max = (
torch.norm(ts_current_system_state.forces, dim=-1).max().item()
)
final_custom_positions = ts_current_system_state.positions.detach()
final_custom_cell = ts_current_system_state.row_vector_cell.squeeze(0).detach()

final_ase_atoms = filtered_ase_atoms_for_run.atoms
final_ase_energy = final_ase_atoms.get_potential_energy()
ase_forces_raw = final_ase_atoms.get_forces()
final_ase_forces_max = torch.norm(
torch.tensor(ase_forces_raw, device=device, dtype=dtype), dim=-1
).max()
final_ase_positions = torch.tensor(
final_ase_atoms.get_positions(), device=device, dtype=dtype
)
final_ase_cell = torch.tensor(
final_ase_atoms.get_cell(), device=device, dtype=dtype
)

energy_diff = abs(final_custom_energy - final_ase_energy)
assert energy_diff < tolerances["energy"], (
f"{current_test_id}: Final energies differ significantly: "
f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, "
f"Diff={energy_diff:.2e}"
)

avg_displacement = (
torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item()
)
assert avg_displacement < tolerances["pos"], (
f"{current_test_id}: Final positions differ ({avg_displacement=:.4f})"
)

cell_diff = torch.norm(final_custom_cell - final_ase_cell).item()
assert cell_diff < tolerances["cell"], (
f"{current_test_id}: Final cell matrices differ (Frobenius norm: "
f"{cell_diff:.4f})\nTorch-sim Cell:\n{final_custom_cell}"
f"\nASE Cell:\n{final_ase_cell}"
)

force_max_diff = abs(final_custom_forces_max - final_ase_forces_max)
assert force_max_diff < tolerances["force_max"], (
f"{current_test_id}: Max forces differ significantly: "
f"torch-sim={final_custom_forces_max:.4f}, ASE={final_ase_forces_max:.4f}, "
f"Diff={force_max_diff:.2e}"
_compare_ase_and_ts_states(
ts_current_system_state,
filtered_ase_atoms_for_run,
tolerances,
current_test_id,
)

last_checkpoint_step_count = checkpoint_step
Expand All @@ -157,47 +173,92 @@ def _run_and_compare_optimizers(
"rattled_sio2_sim_state",
"frechet",
FrechetCellFilter,
[33, 66, 100],
[1, 33, 66, 100],
0.02,
{"energy": 1e-2, "pos": 1.5e-2, "cell": 1.8e-2, "force_max": 1.5e-1},
{
"energy": 1e-2,
"force_max": 5e-2,
"lattice_tol": 3e-2,
"site_tol": 3e-2,
"angle_tol": 1e-1,
},
"SiO2 (Frechet)",
),
(
"osn2_sim_state",
"frechet",
FrechetCellFilter,
[16, 33, 50],
[1, 16, 33, 50],
0.02,
{"energy": 1e-4, "pos": 1e-3, "cell": 1.8e-3, "force_max": 5e-2},
{
"energy": 1e-2,
"force_max": 5e-2,
"lattice_tol": 3e-2,
"site_tol": 3e-2,
"angle_tol": 1e-1,
},
"OsN2 (Frechet)",
),
(
"distorted_fcc_al_conventional_sim_state",
"frechet",
FrechetCellFilter,
[33, 66, 100],
[1, 33, 66, 100],
0.01,
{"energy": 1e-2, "pos": 5e-3, "cell": 2e-2, "force_max": 5e-2},
{
"energy": 1e-2,
"force_max": 5e-2,
"lattice_tol": 3e-2,
"site_tol": 3e-2,
"angle_tol": 5e-1,
},
"Triclinic Al (Frechet)",
),
(
"distorted_fcc_al_conventional_sim_state",
"unit_cell",
UnitCellFilter,
[33, 66, 100],
[1, 33, 66, 100],
0.01,
{"energy": 1e-2, "pos": 3e-2, "cell": 1e-1, "force_max": 5e-2},
{
"energy": 1e-2,
"force_max": 5e-2,
"lattice_tol": 3e-2,
"site_tol": 3e-2,
"angle_tol": 5e-1,
},
"Triclinic Al (UnitCell)",
),
(
"rattled_sio2_sim_state",
"unit_cell",
UnitCellFilter,
[33, 66, 100],
[1, 33, 66, 100],
0.02,
{"energy": 1.5e-2, "pos": 2.5e-2, "cell": 5e-2, "force_max": 0.25},
{
"energy": 1e-2,
"force_max": 5e-2,
"lattice_tol": 3e-2,
"site_tol": 3e-2,
"angle_tol": 1e-1,
},
"SiO2 (UnitCell)",
),
(
"osn2_sim_state",
"unit_cell",
UnitCellFilter,
[1, 16, 33, 50],
0.02,
{
"energy": 1e-2,
"force_max": 5e-2,
"lattice_tol": 3e-2,
"site_tol": 3e-2,
"angle_tol": 1e-1,
},
"OsN2 (UnitCell)",
),
],
)
def test_optimizer_vs_ase_parametrized(
Expand Down
5 changes: 3 additions & 2 deletions tests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_optimize_fire(

# Check force convergence
assert torch.all(final_state.forces < 3e-1)
assert energies.shape[0] > 10
assert energies.shape[0] >= 10
assert energies[0] > energies[-1]
assert not torch.allclose(original_state.positions, final_state.positions)

Expand Down Expand Up @@ -327,7 +327,8 @@ def test_default_converged_fn(
with TorchSimTrajectory(traj_file) as traj:
energies = traj.get_array("energy")

assert energies[-3] > energies[-1]
# Check that overall energy decreases (first to last)
assert energies[0] > energies[-1]
assert not torch.allclose(original_state.positions, final_state.positions)


Expand Down
4 changes: 3 additions & 1 deletion torch_sim/autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,9 @@ def next_batch(
# Increment attempt counters and check for max attempts in a single loop
for cur_idx, abs_idx in enumerate(self.current_idx):
self.swap_attempts[abs_idx] += 1
if self.max_attempts and (self.swap_attempts[abs_idx] >= self.max_attempts):
if self.max_attempts is not None and (
self.swap_attempts[abs_idx] >= self.max_attempts
):
# Force convergence for states that have reached max attempts
convergence_tensor[cur_idx] = torch.tensor(True) # noqa: FBT003

Expand Down
Loading