From 5356d0bb018aa9ede7622d5a787d109a184aeb17 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Thu, 22 May 2025 16:44:19 -0400
Subject: [PATCH 01/22] feat: use batched vdot

---------

Co-authored-by: Janosh Riebesell
---
 torch_sim/math.py       | 25 +++++++++++++++++++++
 torch_sim/optimizers.py | 48 ++++++++++++++++++++++++++++++-----------
 2 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/torch_sim/math.py b/torch_sim/math.py
index fd9182d2..0adfbe0a 100644
--- a/torch_sim/math.py
+++ b/torch_sim/math.py
@@ -987,3 +987,28 @@ def matrix_log_33(
         print(msg)
         # Fall back to scipy implementation
         return matrix_log_scipy(matrix).to(sim_dtype)
+
+
+def batched_vdot(
+    x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor
+) -> torch.Tensor:
+    """Computes batched vdot (sum of element-wise product) for groups of vectors.
+    Passing y = x gives the batched squared norm, i.e. sum of x_i * x_i.
+
+    Args:
+        x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities).
+        y: Tensor of shape [N_total_entities, D]. Ignored if is_sum_sq is True.
+        batch_indices: Tensor of shape [N_total_entities] indicating batch membership.
+
+    Returns:
+        Tensor: shape [n_batches] where each element is the sum(x_i * y_i)
+            (or sum(x_i * x_i) when y is x) for entities belonging to that batch,
+            summed over all components D and all entities in the batch.
+    """
+    if x.ndim != 2 or batch_indices.ndim != 1 or x.shape[0] != batch_indices.shape[0]:
+        raise ValueError(f"Invalid input shapes: {x.shape=}, {batch_indices.shape=}")
+
+    output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)
+    output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
+
+    return output
diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index 0d94b4a5..1a58b4f6 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -1481,14 +1481,28 @@ def _ase_fire_step(  # noqa: C901, PLR0915
     # 3. 
Velocity mixing BEFORE acceleration (ASE ordering) # Atoms - v_norm_atom = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm_atom = torch.norm(state.forces, dim=1, keepdim=True) - f_unit_atom = state.forces / (f_norm_atom + eps) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) - pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) - v_new_atom = ( - 1.0 - alpha_atom - ) * state.velocities + alpha_atom * f_unit_atom * v_norm_atom + # print(f"{state.velocities.shape=}") + v_sum_sq_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) + # sum_sq per batch, shape [n_batches] + f_sum_sq_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) + # sum_sq per batch, shape [n_batches] + + # Expand to per-atom for applying to vectors + # These are sqrt(sum ||v_i||^2)_batch and sqrt(sum ||f_i||^2)_batch + # Effectively |V|_batch and |F|_batch for the mixing formula + sqrt_v_sum_sq_batch_expanded = torch.sqrt(v_sum_sq_batch[state.batch].unsqueeze(-1)) + sqrt_f_sum_sq_batch_expanded = torch.sqrt(f_sum_sq_batch[state.batch].unsqueeze(-1)) + + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) # per-atom mask + + # ASE formula: v_new = (1-a)*v + a * (f / |F|_batch) * |V|_batch + # = (1-a)*v + a * f * (|V|_batch / |F|_batch) + mixing_term_atom = state.forces * ( + sqrt_v_sum_sq_batch_expanded / (sqrt_f_sum_sq_batch_expanded + eps) + ) + + v_new_atom = (1.0 - alpha_atom) * state.velocities + alpha_atom * mixing_term_atom state.velocities = torch.where( pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities) ) @@ -1524,12 +1538,22 @@ def _ase_fire_step( # noqa: C901, PLR0915 dr_cell = cell_dt * state.cell_velocities # 6. Clamp to max_step - dr_norm_atom = torch.norm(dr_atom, dim=1, keepdim=True) - mask_atom_max_step = dr_norm_atom > max_step - dr_atom = torch.where( - mask_atom_max_step, max_step * dr_atom / (dr_norm_atom + eps), dr_atom + dr_atom_sum_sq_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) + norm_dr_atom_per_batch = torch.sqrt(dr_atom_sum_sq_batch) # shape [n_batches] + + mask_clamp_batch = norm_dr_atom_per_batch > max_step # shape [n_batches] + + scaling_factor_batch = torch.ones_like(norm_dr_atom_per_batch) + safe_norm_for_clamped_batches = norm_dr_atom_per_batch[mask_clamp_batch] + scaling_factor_batch[mask_clamp_batch] = max_step / ( + safe_norm_for_clamped_batches + eps ) + # shape [N_atoms, 1] + atom_wise_scaling_factor = scaling_factor_batch[state.batch].unsqueeze(-1) + + dr_atom = dr_atom * atom_wise_scaling_factor + old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) From fa0830aaef516fdbc7ddc383116272386de751c0 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 22 May 2025 20:27:51 -0400 Subject: [PATCH 02/22] clean: remove ai slop --- torch_sim/optimizers.py | 221 ++++++++++++++++------------------------ 1 file changed, 90 insertions(+), 131 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 1a58b4f6..a2fa19e6 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1447,85 +1447,81 @@ def _ase_fire_step( # noqa: C901, PLR0915 device, dtype = state.positions.device, state.positions.dtype n_batches = state.n_batches - # Setup batch-wise alpha_start for potential reset - # alpha_start is a 0-dim tensor from the factory - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, 
dtype=dtype - ) + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + if is_cell_optimization: + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + else: + alpha_start_batch = torch.full( + (n_batches,), alpha_start.item(), device=device, dtype=dtype + ) - # 1. Current power (F·v) per batch (atoms + cell) - atomic_power = (state.forces * state.velocities).sum(dim=1) - batch_power = torch.zeros(n_batches, device=device, dtype=dtype) - batch_power.scatter_add_(0, state.batch, atomic_power) + # 1. Current power (F·v) per batch (atoms + cell) + batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) - if is_cell_optimization: - valid_states = (UnitCellFireState, FrechetCellFIREState) - assert isinstance(state, valid_states), ( - f"Cell optimization requires one of {valid_states}." + if is_cell_optimization: + valid_states = (UnitCellFireState, FrechetCellFIREState) + assert isinstance(state, valid_states), ( + f"Cell optimization requires one of {valid_states}." + ) + batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch + + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 + + # 3. Velocity mixing BEFORE acceleration (ASE ordering) + v_scaling_batch = tsm.batched_vdot( + state.velocities, state.velocities, state.batch ) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power + f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) - # 2. Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch - - state.n_pos[pos_mask_batch] += 1 - inc_mask = (state.n_pos > n_min) & pos_mask_batch - state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) - state.alpha[inc_mask] *= f_alpha - - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 - - # 3. 
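# [Editor's note] The masked tensor updates above replace FIRE's per-batch Python
# loop for dt/alpha/n_pos. A minimal, self-contained sketch of the same pattern
# (values illustrative, not from torch-sim):
import torch

dt = torch.tensor([0.10, 0.10, 0.10])
n_pos = torch.tensor([3, 0, 5])
pos_mask = torch.tensor([True, False, True])  # batches with positive F.v power
inc_mask = (n_pos > 2) & pos_mask  # n_pos exceeded n_min: grow the time step
dt[inc_mask] = torch.minimum(dt[inc_mask] * 1.1, torch.tensor(0.40))
dt[~pos_mask] *= 0.5  # power <= 0: shrink dt (alpha and n_pos reset elsewhere)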
Velocity mixing BEFORE acceleration (ASE ordering) - # Atoms - # print(f"{state.velocities.shape=}") - v_sum_sq_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) - # sum_sq per batch, shape [n_batches] - f_sum_sq_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) - # sum_sq per batch, shape [n_batches] - - # Expand to per-atom for applying to vectors - # These are sqrt(sum ||v_i||^2)_batch and sqrt(sum ||f_i||^2)_batch - # Effectively |V|_batch and |F|_batch for the mixing formula - sqrt_v_sum_sq_batch_expanded = torch.sqrt(v_sum_sq_batch[state.batch].unsqueeze(-1)) - sqrt_f_sum_sq_batch_expanded = torch.sqrt(f_sum_sq_batch[state.batch].unsqueeze(-1)) - - alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha - pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) # per-atom mask - - # ASE formula: v_new = (1-a)*v + a * (f / |F|_batch) * |V|_batch - # = (1-a)*v + a * f * (|V|_batch / |F|_batch) - mixing_term_atom = state.forces * ( - sqrt_v_sum_sq_batch_expanded / (sqrt_f_sum_sq_batch_expanded + eps) - ) + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + v_scaling_batch += ( + state.cell_velocities.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + ) + f_scaling_batch += ( + state.cell_forces.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + ) - v_new_atom = (1.0 - alpha_atom) * state.velocities + alpha_atom * mixing_term_atom - state.velocities = torch.where( - pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities) - ) + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell velocity mixing - cv_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cf_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cf_unit = state.cell_forces / (cf_norm + eps) - alpha_cell_bc = state.alpha.view(-1, 1, 1) - pos_mask_cell_bc = pos_mask_batch.view(-1, 1, 1) - v_new_cell = ( - 1.0 - alpha_cell_bc - ) * state.cell_velocities + alpha_cell_bc * cf_unit * cv_norm - state.cell_velocities = torch.where( - pos_mask_cell_bc, v_new_cell, torch.zeros_like(state.cell_velocities) + alpha_cell_bc = state.alpha.view(-1, 1, 1) + state.cell_velocities = torch.where( + pos_mask_batch.view(-1, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), + ) + + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), ) # 4. 
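# [Editor's note] The mixing above uses per-batch *global* norms |V| and |F|
# (square roots of summed squares over all atoms in a batch), matching ASE's
# FIRE rather than per-atom norms. A toy check of
# v_new = (1 - a) * v + a * f * (|V| / |F|) for one batch of two atoms:
import torch

a = 0.1
v = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
f = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
v_new = (1 - a) * v + a * f * (v.norm() / (f.norm() + 1e-12))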
Acceleration (single forward-Euler, no mass for ASE FIRE) atom_dt = state.dt[state.batch].unsqueeze(-1) state.velocities += atom_dt * state.forces - if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_dt = state.dt.view(-1, 1, 1) @@ -1537,103 +1533,71 @@ def _ase_fire_step( # noqa: C901, PLR0915 assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) dr_cell = cell_dt * state.cell_velocities - # 6. Clamp to max_step - dr_atom_sum_sq_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) - norm_dr_atom_per_batch = torch.sqrt(dr_atom_sum_sq_batch) # shape [n_batches] - - mask_clamp_batch = norm_dr_atom_per_batch > max_step # shape [n_batches] - - scaling_factor_batch = torch.ones_like(norm_dr_atom_per_batch) - safe_norm_for_clamped_batches = norm_dr_atom_per_batch[mask_clamp_batch] - scaling_factor_batch[mask_clamp_batch] = max_step / ( - safe_norm_for_clamped_batches + eps - ) - - # shape [N_atoms, 1] - atom_wise_scaling_factor = scaling_factor_batch[state.batch].unsqueeze(-1) - - dr_atom = dr_atom * atom_wise_scaling_factor + # 6. Position / cell update + dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell clamp to max_step (Frobenius norm) - dr_cell_norm_fro = torch.norm(dr_cell.view(n_batches, -1), dim=1, keepdim=True) - mask_cell_max_step = dr_cell_norm_fro.view(n_batches, 1, 1) > max_step + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + dr_scaling_cell = torch.sqrt(dr_scaling_batch.view(n_batches, 1, 1)) dr_cell = torch.where( - mask_cell_max_step, - max_step * dr_cell / (dr_cell_norm_fro.view(n_batches, 1, 1) + eps), + dr_scaling_cell > max_step, + max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - # 7. 
Position / cell update - # Store old cell for scaling atoms - # Ensure old_row_vector_cell is cloned before any modification to state.cell or - # state.row_vector_cell + # save the old cell to allow rescaling of the positions after cell update old_row_vector_cell = state.row_vector_cell.clone() + dr_scaling_atom = torch.sqrt(dr_scaling_batch[state.batch]) + dr_atom = torch.where( + dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom + ) state.positions = state.positions + dr_atom - # F_new stores F_new for Frechet's ucf_cell_grad if needed - F_new: torch.Tensor | None = None - # logm_F_new stores logm_F_new for Frechet's cell_forces recalc if needed - logm_F_new: torch.Tensor | None = None - if is_cell_optimization: + F_new: torch.Tensor | None = None + logm_F_new: torch.Tensor | None = None + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) if is_frechet: assert isinstance(state, FrechetCellFIREState) - # Frechet cell update logic new_logm_F_scaled = state.cell_positions + dr_cell state.cell_positions = new_logm_F_scaled - # cell_factor is (N,1,1) logm_F_new = new_logm_F_scaled / (state.cell_factor + eps) F_new = torch.matrix_exp(logm_F_new) new_row_vector_cell = torch.bmm( state.reference_row_vector_cell, F_new.transpose(-2, -1) ) state.row_vector_cell = new_row_vector_cell - else: # UnitCellFire + else: assert isinstance(state, UnitCellFireState) - # Unit cell update logic F_current = state.deform_grad() - # state.cell_factor is (N,1,1), F_current is (N,3,3) - # cell_factor_exp for element-wise F_current * cell_factor_exp should be - # (N,3,3) or broadcast from (N,1,1) or (N,3,1) cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) current_F_scaled = F_current * cell_factor_exp_mult F_new_scaled = current_F_scaled + dr_cell - state.cell_positions = F_new_scaled # track the scaled deformation gradient - F_new = F_new_scaled / (cell_factor_exp_mult + eps) # Division by (N,3,1) - # When state.cell is set, state.row_vector_cell is auto-updated + state.cell_positions = F_new_scaled + F_new = F_new_scaled / (cell_factor_exp_mult + eps) new_cell_column_vectors = torch.bmm( state.reference_cell, F_new.transpose(-2, -1) ) state.cell = new_cell_column_vectors - # Scale atomic positions according to cell change (mimicking scale_atoms=True) - if is_cell_optimization and old_row_vector_cell is not None: - current_new_row_vector_cell = state.row_vector_cell # This is A_new after update - + # rescale the positions after cell update + current_new_row_vector_cell = state.row_vector_cell inv_old_cell_batch = torch.linalg.inv(old_row_vector_cell) - # Transform matrix T such that A_new = A_old @ T (for row vectors A) - # This means cartesian positions P_new_row = P_old_row @ T transform_matrix_batch = torch.bmm( inv_old_cell_batch, current_new_row_vector_cell - ) # Shape [N_batch, 3, 3] - - # Shape: [N_atoms, 3, 3] + ) atom_specific_transform = transform_matrix_batch[state.batch] - - # state.positions is [N_atoms, 3]. Unsqueeze to [N_atoms, 1, 3] for bmm - # Result of bmm will be [N_atoms, 1, 3], then squeeze scaled_positions = torch.bmm( state.positions.unsqueeze(1), atom_specific_transform ).squeeze(1) state.positions = scaled_positions - # 8. Force / stress refresh & new cell forces + # 7. 
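# [Editor's note] In the Frechet branch above, the cell degrees of freedom live
# in matrix-log space: the optimizer steps logm(F) and maps back with
# torch.matrix_exp. A minimal round-trip sketch, assuming a well-conditioned cell:
import torch

ref_cell = torch.eye(3).unsqueeze(0)  # reference row-vector cell, shape [1, 3, 3]
logm_F = 0.01 * torch.randn(1, 3, 3)  # small step in log space
F = torch.matrix_exp(logm_F)  # deformation gradient
new_cell = torch.bmm(ref_cell, F.transpose(-2, -1))  # row vectors transform by F^T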
Force / stress refresh & new cell forces results = model(state) state.forces = results["forces"] state.energy = results["energy"] @@ -1648,7 +1612,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})" ) - # volumes = torch.clamp(volumes, min=eps) # Optional: for stability virial = -volumes * (state.stress + state.pressure) @@ -1657,7 +1620,8 @@ def _ase_fire_step( # noqa: C901, PLR0915 virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype ).unsqueeze(0).expand(n_batches, -1, -1) - if state.constant_volume: # Can be true even if hydrostatic_strain is false + + if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype @@ -1671,7 +1635,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 assert logm_F_new is not None, ( "logm_F_new should be defined for Frechet cell force calculation" ) - # Frechet cell force recalculation ucf_cell_grad = torch.bmm( virial, torch.linalg.inv(torch.transpose(F_new, 1, 2)) ) @@ -1683,7 +1646,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 new_cell_forces_log_space = torch.zeros_like(state.cell_forces) for b_idx in range(n_batches): - # logm_F_new[b_idx] is the current point in log-space expm_derivs = torch.stack( [ tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False) @@ -1694,12 +1656,9 @@ def _ase_fire_step( # noqa: C901, PLR0915 expm_derivs * ucf_cell_grad[b_idx].unsqueeze(0), dim=(1, 2) ) new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3) - state.cell_forces = new_cell_forces_log_space / ( - state.cell_factor + eps - ) # cell_factor is (N,1,1) - else: # UnitCellFire + state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps) + else: assert isinstance(state, UnitCellFireState) - # Unit cell force recalculation - state.cell_forces = virial / state.cell_factor # cell_factor is (N,1,1) + state.cell_forces = virial / state.cell_factor return state From 53f6839bad850f6ce08da8c14978f70b77ca1daf Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 22 May 2025 21:59:06 -0400 Subject: [PATCH 03/22] clean: further attempts to clean but still not matching PR --- torch_sim/optimizers.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index a2fa19e6..1bb3f1bb 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1164,13 +1164,13 @@ def fire_init( batch=state.batch, pbc=state.pbc, # New attributes - velocities=torch.zeros_like(state.positions), + velocities=None, forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=torch.zeros((n_batches, 3, 3), device=device, dtype=dtype), + cell_velocities=None, cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1449,6 +1449,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if state.velocities is None: state.velocities = torch.zeros_like(state.positions) + forces = state.forces if is_cell_optimization: state.cell_velocities = torch.zeros( (n_batches, 3, 3), device=device, dtype=dtype @@ -1458,8 +1459,15 @@ def _ase_fire_step( # noqa: C901, PLR0915 (n_batches,), alpha_start.item(), device=device, dtype=dtype ) + if is_cell_optimization: + forces = torch.bmm( + state.forces.unsqueeze(1), state.deform_grad()[state.batch] 
+ ).squeeze(1) + else: + forces = state.forces + # 1. Current power (F·v) per batch (atoms + cell) - batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) + batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) if is_cell_optimization: valid_states = (UnitCellFireState, FrechetCellFIREState) @@ -1520,27 +1528,17 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - atom_dt = state.dt[state.batch].unsqueeze(-1) - state.velocities += atom_dt * state.forces - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_dt = state.dt.view(-1, 1, 1) - state.cell_velocities += cell_dt * state.cell_forces - - # 5. Displacements - dr_atom = atom_dt * state.velocities - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - dr_cell = cell_dt * state.cell_velocities - - # 6. Position / cell update + state.velocities += state.forces * state.dt[state.batch].unsqueeze(-1) + dr_atom = state.forces * state.dt[state.batch].unsqueeze(-1) dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) - old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + state.cell_velocities += state.cell_forces * state.dt.view(-1, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(-1, 1, 1) + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) - dr_scaling_cell = torch.sqrt(dr_scaling_batch.view(n_batches, 1, 1)) + dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) dr_cell = torch.where( dr_scaling_cell > max_step, max_step * dr_cell / (dr_scaling_cell + eps), @@ -1550,16 +1548,13 @@ def _ase_fire_step( # noqa: C901, PLR0915 # save the old cell to allow rescaling of the positions after cell update old_row_vector_cell = state.row_vector_cell.clone() - dr_scaling_atom = torch.sqrt(dr_scaling_batch[state.batch]) + dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch] dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom ) state.positions = state.positions + dr_atom if is_cell_optimization: - F_new: torch.Tensor | None = None - logm_F_new: torch.Tensor | None = None - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) if is_frechet: assert isinstance(state, FrechetCellFIREState) From 75ca9a55e47a78d3123407e25ea1d65eae2caa7a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 23 May 2025 10:45:42 -0400 Subject: [PATCH 04/22] fix: dr is vdt rather than fdt --- torch_sim/optimizers.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 1bb3f1bb..22d10ecd 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1406,8 +1406,11 @@ def _vv_fire_step( # noqa: C901, PLR0915 return state +VALID_FIRE_CELL_STATES = (UnitCellFireState, FrechetCellFIREState) + + def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1451,6 +1454,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.velocities = torch.zeros_like(state.positions) forces = state.forces if is_cell_optimization: + if not isinstance(state, VALID_FIRE_CELL_STATES): + raise ValueError( + "Cell optimization requires one of 
{VALID_FIRE_CELL_STATES}." + ) state.cell_velocities = torch.zeros( (n_batches, 3, 3), device=device, dtype=dtype ) @@ -1470,10 +1477,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) if is_cell_optimization: - valid_states = (UnitCellFireState, FrechetCellFIREState) - assert isinstance(state, valid_states), ( - f"Cell optimization requires one of {valid_states}." - ) batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # 2. Update dt, alpha, n_pos @@ -1496,7 +1499,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) v_scaling_batch += ( state.cell_velocities.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) ) @@ -1529,11 +1531,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) state.velocities += state.forces * state.dt[state.batch].unsqueeze(-1) - dr_atom = state.forces * state.dt[state.batch].unsqueeze(-1) + dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.cell_velocities += state.cell_forces * state.dt.view(-1, 1, 1) dr_cell = state.cell_velocities * state.dt.view(-1, 1, 1) @@ -1555,7 +1556,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + dr_atom if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) if is_frechet: assert isinstance(state, FrechetCellFIREState) new_logm_F_scaled = state.cell_positions + dr_cell @@ -1598,7 +1598,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] volumes = torch.linalg.det(state.cell).view(-1, 1, 1) if torch.any(volumes <= 0): From 971a4f5737100da6adcca92618324110f579acfd Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 23 May 2025 10:57:10 -0400 Subject: [PATCH 05/22] typing: fix typing issue --- torch_sim/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 22d10ecd..c24ea85b 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1410,7 +1410,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | VALID_FIRE_CELL_STATES, + state: FireState | UnitCellFireState | FrechetCellFIREState, model: torch.nn.Module, *, dt_max: torch.Tensor, From cc94facb1c164c8def7823a464cc7711fa956d97 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 23 May 2025 13:45:30 -0400 Subject: [PATCH 06/22] wip: still not sure where the difference is now --- torch_sim/math.py | 2 +- torch_sim/optimizers.py | 150 +++++++++++++++++++++------------------- 2 files changed, 78 insertions(+), 74 deletions(-) diff --git a/torch_sim/math.py b/torch_sim/math.py index 0adfbe0a..ff78757d 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -997,7 +997,7 @@ def batched_vdot( Args: x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities). - y: Tensor of shape [N_total_entities, D]. Ignored if is_sum_sq is True. + y: Tensor of shape [N_total_entities, D]. 
batch_indices: Tensor of shape [N_total_entities] indicating batch membership. Returns: diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index c24ea85b..bdfef729 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -270,6 +270,8 @@ def gd_init( if not isinstance(state, SimState): state = SimState(**state) + n_batches = state.n_batches + # Setup cell_factor if cell_factor is None: # Count atoms per batch @@ -283,7 +285,7 @@ def gd_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) scalar_pressure = torch.full( (state.n_batches, 1, 1), scalar_pressure, device=device, dtype=dtype @@ -316,7 +318,7 @@ def gd_init( ) # shape: (n_batches, 3, 3) # Calculate virial - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) if hydrostatic_strain: @@ -391,7 +393,7 @@ def gd_step( # Get per-atom and per-cell learning rates atom_wise_lr = positions_lr[state.batch].unsqueeze(-1) - cell_wise_lr = cell_lr.view(-1, 1, 1) # shape: (n_batches, 1, 1) + cell_wise_lr = cell_lr.view(n_batches, 1, 1) # shape: (n_batches, 1, 1) # Update atomic and cell positions atomic_positions_new = state.positions + atom_wise_lr * state.forces @@ -415,7 +417,7 @@ def gd_step( state.stress = model_output["stress"] # Calculate virial for cell forces - volumes = torch.linalg.det(new_row_vector_cell).view(-1, 1, 1) + volumes = torch.linalg.det(new_row_vector_cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) @@ -811,7 +813,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -824,7 +826,7 @@ def fire_init( forces = model_output["forces"] # [n_total_atoms, 3] stress = model_output["stress"] # [n_batches, 3, 3] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -1097,7 +1099,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -1121,7 +1123,7 @@ def fire_init( cell_positions = torch.zeros((n_batches, 3, 3), device=device, dtype=dtype) # Calculate virial for cell forces - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -1202,8 +1204,11 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) +VALID_FIRE_CELL_STATES = UnitCellFireState | FrechetCellFIREState + + def _vv_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1215,7 +1220,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | 
FrechetCellFIREState: +) -> FireState | VALID_FIRE_CELL_STATES: """Perform one Velocity-Verlet based FIRE optimization step. Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for @@ -1244,6 +1249,17 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + if is_cell_optimization: + if not isinstance(state, VALID_FIRE_CELL_STATES): + raise ValueError( + "Cell optimization requires one of {VALID_FIRE_CELL_STATES}." + ) + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + alpha_start_batch = torch.full( (n_batches,), alpha_start.item(), device=device, dtype=dtype ) @@ -1252,7 +1268,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) @@ -1261,7 +1276,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + atom_wise_dt * state.velocities if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) if is_frechet: assert isinstance(state, FrechetCellFIREState) @@ -1284,7 +1298,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 else: assert isinstance(state, UnitCellFireState) cur_deform_grad = state.deform_grad() - # cell_factor is (N,1,1) cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) current_cell_positions_scaled = ( cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded @@ -1305,9 +1318,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: @@ -1351,66 +1363,62 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) - atomic_power = (state.forces * state.velocities).sum(dim=1) - atomic_power_per_batch = torch.zeros( - n_batches, device=device, dtype=atomic_power.dtype - ) - atomic_power_per_batch.scatter_add_(dim=0, index=state.batch, src=atomic_power) - batch_power = atomic_power_per_batch + batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power - - for batch_idx in range(n_batches): - if batch_power[batch_idx] > 0: - state.n_pos[batch_idx] += 1 - if state.n_pos[batch_idx] > n_min: - state.dt[batch_idx] = torch.minimum(state.dt[batch_idx] * f_inc, dt_max) - state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha - else: - state.n_pos[batch_idx] = 0 - state.dt[batch_idx] = state.dt[batch_idx] * f_dec 
- state.alpha[batch_idx] = alpha_start_batch[batch_idx] - state.velocities[state.batch == batch_idx] = 0 - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - state.cell_velocities[batch_idx] = 0 - - v_norm = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm = torch.norm(state.forces, dim=1, keepdim=True) - atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1) - state.velocities = (1.0 - atom_wise_alpha) * state.velocities + ( - atom_wise_alpha * state.forces * v_norm / (f_norm + eps) - ) + batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch + + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 + + v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) + f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_v_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cell_f_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cell_wise_alpha = state.alpha.unsqueeze(-1).unsqueeze(-1) - cell_mask = (cell_f_norm > eps).expand_as(state.cell_velocities) + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) state.cell_velocities = torch.where( - cell_mask, - (1.0 - cell_wise_alpha) * state.cell_velocities - + cell_wise_alpha * state.cell_forces * cell_v_norm / (cell_f_norm + eps), - state.cell_velocities, + pos_mask_batch.view(n_batches, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), ) - return state + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), + ) -VALID_FIRE_CELL_STATES = (UnitCellFireState, FrechetCellFIREState) + return state def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1423,7 +1431,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | FrechetCellFIREState: +) -> FireState | VALID_FIRE_CELL_STATES: """Perform one ASE-style FIRE optimization step. 
Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm @@ -1499,20 +1507,16 @@ def _ase_fire_step( # noqa: C901, PLR0915 f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - v_scaling_batch += ( - state.cell_velocities.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) - ) - f_scaling_batch += ( - state.cell_forces.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) - ) + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(-1, 1, 1) + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) state.cell_velocities = torch.where( - pos_mask_batch.view(-1, 1, 1), + pos_mask_batch.view(n_batches, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), @@ -1535,10 +1539,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) if is_cell_optimization: - state.cell_velocities += state.cell_forces * state.dt.view(-1, 1, 1) - dr_cell = state.cell_velocities * state.dt.view(-1, 1, 1) + state.cell_velocities += state.cell_forces * state.dt.view(n_batches, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(n_batches, 1, 1) - dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2)) dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) dr_cell = torch.where( dr_scaling_cell > max_step, @@ -1549,7 +1553,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 # save the old cell to allow rescaling of the positions after cell update old_row_vector_cell = state.row_vector_cell.clone() - dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch] + dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom ) @@ -1599,7 +1603,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) if torch.any(volumes <= 0): bad_indices = torch.where(volumes <= 0)[0].tolist() print( From 4a494e6de2f7f52bed22de386635bc1e17e9667b Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 26 May 2025 20:15:53 -0400 Subject: [PATCH 07/22] update forces per comment --- torch_sim/optimizers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index bdfef729..108b87a8 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1504,7 +1504,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 v_scaling_batch = tsm.batched_vdot( state.velocities, state.velocities, state.batch ) - f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) + f_scaling_batch = tsm.batched_vdot(forces, forces, state.batch) if is_cell_optimization: v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) @@ -1524,7 +1524,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) f_scaling_atom = 
torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) - v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha state.velocities = torch.where( @@ -1534,7 +1534,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - state.velocities += state.forces * state.dt[state.batch].unsqueeze(-1) + state.velocities += forces * state.dt[state.batch].unsqueeze(-1) dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) From d1de14a1fc59ad3ca0bded7dfff3f3c311d7a49b Mon Sep 17 00:00:00 2001 From: Timo Reents Date: Wed, 28 May 2025 22:30:55 +0200 Subject: [PATCH 08/22] Fix ASE pos only implementation * Initialize velocities to None in the pos-only case, see previous changes to the optimizers. (ensures that the correct `dt` is used) * Change the order of the increase of `n_pos`. Again, this ensures the usage of the correct `dt` compared to ASE --- torch_sim/optimizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 108b87a8..c1e564d1 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -592,7 +592,7 @@ def fire_init( batch=state.batch.clone(), pbc=state.pbc, # New attributes - velocities=torch.zeros_like(state.positions), + velocities=None, forces=forces, energy=energy, # Optimization attributes @@ -1491,10 +1491,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 pos_mask_batch = batch_power > 0.0 neg_mask_batch = ~pos_mask_batch - state.n_pos[pos_mask_batch] += 1 inc_mask = (state.n_pos > n_min) & pos_mask_batch state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) state.alpha[inc_mask] *= f_alpha + state.n_pos[pos_mask_batch] += 1 state.dt[neg_mask_batch] *= f_dec state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] From beeb4df6b508dee76f3b57176ea2765637d5ce50 Mon Sep 17 00:00:00 2001 From: Timo Reents Date: Wed, 28 May 2025 22:33:18 +0200 Subject: [PATCH 09/22] Fix torch-sim ASE-FIRE (Frechet Cell) * Remove rescaling of positions when updating cell, it's not relevant * Correctly rescale the positions with respect to the deformation gradient * Consider the `cell_forces` in the convergence when doing cell optimizations --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 4 ++-- torch_sim/optimizers.py | 22 +++++++++++++------ torch_sim/runners.py | 12 ++++++++-- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 2fc90f3b..76dd6b00 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -171,7 +171,7 @@ def run_optimization_ts( # noqa: PLR0915 convergence_steps = torch.full( (total_structures,), -1, dtype=torch.long, device=device ) - convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol, include_cell_forces=ts_use_frechet) converged_tensor_global = torch.zeros( total_structures, dtype=torch.bool, device=device ) @@ -194,7 +194,7 @@ def run_optimization_ts( # noqa: PLR0915 current_indices_list, dtype=torch.long, device=device ) - steps_this_round = 10 + steps_this_round = 1 for _ in range(steps_this_round): 
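# [Editor's note] With steps_this_round = 1, each pass through this loop advances
# the batched FIRE state by a single step before the convergence mask below is
# re-evaluated, mirroring ASE's per-step fmax check at the cost of more host syncs.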
opt_state = update_fn_opt(opt_state) global_step += steps_this_round diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index c1e564d1..198d4f8d 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1469,14 +1469,16 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.cell_velocities = torch.zeros( (n_batches, 3, 3), device=device, dtype=dtype ) + cur_deform_grad = state.deform_grad() else: alpha_start_batch = torch.full( (n_batches,), alpha_start.item(), device=device, dtype=dtype ) if is_cell_optimization: + cur_deform_grad = state.deform_grad() forces = torch.bmm( - state.forces.unsqueeze(1), state.deform_grad()[state.batch] + state.forces.unsqueeze(1), cur_deform_grad[state.batch] ).squeeze(1) else: forces = state.forces @@ -1554,10 +1556,17 @@ def _ase_fire_step( # noqa: C901, PLR0915 old_row_vector_cell = state.row_vector_cell.clone() dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) + dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom ) - state.positions = state.positions + dr_atom + + if is_cell_optimization: + state.positions = torch.linalg.solve( + cur_deform_grad[state.batch], state.positions.unsqueeze(-1) + ).squeeze(-1) + dr_atom + else: + state.positions = state.positions + dr_atom if is_cell_optimization: if is_frechet: @@ -1590,11 +1599,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 transform_matrix_batch = torch.bmm( inv_old_cell_batch, current_new_row_vector_cell ) - atom_specific_transform = transform_matrix_batch[state.batch] - scaled_positions = torch.bmm( - state.positions.unsqueeze(1), atom_specific_transform - ).squeeze(1) - state.positions = scaled_positions + + state.positions = torch.bmm( + state.positions.unsqueeze(1), F_new[state.batch].transpose(-2, -1) + ).squeeze(1) # 7. Force / stress refresh & new cell forces results = model(state) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 737312b6..2913d41e 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -278,12 +278,13 @@ def _chunked_apply( return concatenate_states(ordered_states) -def generate_force_convergence_fn(force_tol: float = 1e-1) -> Callable: +def generate_force_convergence_fn(force_tol: float = 1e-1, include_cell_forces: bool = False) -> Callable: """Generate a force-based convergence function for the convergence_fn argument of the optimize function. Args: force_tol (float): Force tolerance for convergence + include_cell_forces (bool): Whether to include the `cell_forces` in the convergence check. 
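        Example (editor's sketch; assumes a `state` carrying `forces` and, for
        cell relaxations, `cell_forces`):

            fn = generate_force_convergence_fn(force_tol=0.02, include_cell_forces=True)
            done = fn(state)  # per-batch True once atom and cell force norms < 0.02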
Returns: Convergence function that takes a state and last energy and @@ -295,7 +296,14 @@ def convergence_fn( last_energy: torch.Tensor | None = None, # noqa: ARG001 ) -> bool: """Check if the system has converged.""" - return batchwise_max_force(state) < force_tol + force_conv = batchwise_max_force(state) < force_tol + if not include_cell_forces: + return force_conv + + cell_forces_norm, _ = state.cell_forces.norm(dim=2).max(dim=1) + cell_force_conv = cell_forces_norm < force_tol + + return force_conv & cell_force_conv return convergence_fn From 0ba531cb73bafa38dc7b82b101bf7393251e9970 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 28 May 2025 18:30:06 -0400 Subject: [PATCH 10/22] linting --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 4 +- tests/test_optimizers_vs_ase.py | 88 +++++++++++++------ torch_sim/optimizers.py | 28 +++--- torch_sim/runners.py | 9 +- 4 files changed, 79 insertions(+), 50 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 76dd6b00..399c5a5e 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -171,7 +171,9 @@ def run_optimization_ts( # noqa: PLR0915 convergence_steps = torch.full( (total_structures,), -1, dtype=torch.long, device=device ) - convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol, include_cell_forces=ts_use_frechet) + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=ts_use_frechet + ) converged_tensor_global = torch.zeros( total_structures, dtype=torch.bool, device=device ) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index e52a82d1..188101ed 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -5,9 +5,10 @@ import torch from ase.filters import FrechetCellFilter, UnitCellFilter from ase.optimize import FIRE +from pymatgen.analysis.structure_matcher import StructureMatcher import torch_sim as ts -from torch_sim.io import state_to_atoms +from torch_sim.io import atoms_to_state, state_to_atoms, state_to_structures from torch_sim.models.mace import MaceModel from torch_sim.optimizers import frechet_cell_fire, unit_cell_fire @@ -63,6 +64,13 @@ def _run_and_compare_optimizers( last_checkpoint_step_count = 0 convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + structure_matcher = StructureMatcher( + ltol=tolerances["lattice_tol"], + stol=tolerances["site_tol"], + angle_tol=tolerances["angle_tol"], + scale=False, + ) + for checkpoint_step in checkpoints: steps_for_current_segment = checkpoint_step - last_checkpoint_step_count @@ -95,22 +103,22 @@ def _run_and_compare_optimizers( final_custom_forces_max = ( torch.norm(ts_current_system_state.forces, dim=-1).max().item() ) - final_custom_positions = ts_current_system_state.positions.detach() - final_custom_cell = ts_current_system_state.row_vector_cell.squeeze(0).detach() + # Convert torch-sim state to pymatgen Structure + ts_structure = state_to_structures(ts_current_system_state)[0] + + # Convert ASE atoms to pymatgen Structure final_ase_atoms = filtered_ase_atoms_for_run.atoms final_ase_energy = final_ase_atoms.get_potential_energy() ase_forces_raw = final_ase_atoms.get_forces() final_ase_forces_max = torch.norm( torch.tensor(ase_forces_raw, device=device, dtype=dtype), dim=-1 ).max() - final_ase_positions = torch.tensor( - final_ase_atoms.get_positions(), device=device, 
dtype=dtype - ) - final_ase_cell = torch.tensor( - final_ase_atoms.get_cell(), device=device, dtype=dtype - ) + ase_structure = state_to_structures( + atoms_to_state(final_ase_atoms, device, dtype) + )[0] + # Compare energies energy_diff = abs(final_custom_energy - final_ase_energy) assert energy_diff < tolerances["energy"], ( f"{current_test_id}: Final energies differ significantly: " @@ -118,20 +126,7 @@ def _run_and_compare_optimizers( f"Diff={energy_diff:.2e}" ) - avg_displacement = ( - torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() - ) - assert avg_displacement < tolerances["pos"], ( - f"{current_test_id}: Final positions differ ({avg_displacement=:.4f})" - ) - - cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() - assert cell_diff < tolerances["cell"], ( - f"{current_test_id}: Final cell matrices differ (Frobenius norm: " - f"{cell_diff:.4f})\nTorch-sim Cell:\n{final_custom_cell}" - f"\nASE Cell:\n{final_ase_cell}" - ) - + # Compare forces force_max_diff = abs(final_custom_forces_max - final_ase_forces_max) assert force_max_diff < tolerances["force_max"], ( f"{current_test_id}: Max forces differ significantly: " @@ -139,6 +134,13 @@ def _run_and_compare_optimizers( f"Diff={force_max_diff:.2e}" ) + # Compare structures using StructureMatcher + assert structure_matcher.fit(ts_structure, ase_structure), ( + f"{current_test_id}: Structures do not match according to StructureMatcher, " + f"{ts_structure=}" + f"{ase_structure=}" + ) + last_checkpoint_step_count = checkpoint_step @@ -159,7 +161,13 @@ def _run_and_compare_optimizers( FrechetCellFilter, [33, 66, 100], 0.02, - {"energy": 1e-2, "pos": 1.5e-2, "cell": 1.8e-2, "force_max": 1.5e-1}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, "SiO2 (Frechet)", ), ( @@ -168,7 +176,13 @@ def _run_and_compare_optimizers( FrechetCellFilter, [16, 33, 50], 0.02, - {"energy": 1e-4, "pos": 1e-3, "cell": 1.8e-3, "force_max": 5e-2}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, "OsN2 (Frechet)", ), ( @@ -177,7 +191,13 @@ def _run_and_compare_optimizers( FrechetCellFilter, [33, 66, 100], 0.01, - {"energy": 1e-2, "pos": 5e-3, "cell": 2e-2, "force_max": 5e-2}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, "Triclinic Al (Frechet)", ), ( @@ -186,7 +206,13 @@ def _run_and_compare_optimizers( UnitCellFilter, [33, 66, 100], 0.01, - {"energy": 1e-2, "pos": 3e-2, "cell": 1e-1, "force_max": 5e-2}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, "Triclinic Al (UnitCell)", ), ( @@ -195,7 +221,13 @@ def _run_and_compare_optimizers( UnitCellFilter, [33, 66, 100], 0.02, - {"energy": 1.5e-2, "pos": 2.5e-2, "cell": 5e-2, "force_max": 0.25}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, "SiO2 (UnitCell)", ), ], diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 198d4f8d..8d1f025e 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1552,9 +1552,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 dr_cell, ) - # save the old cell to allow rescaling of the positions after cell update - old_row_vector_cell = state.row_vector_cell.clone() - dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) dr_atom = torch.where( @@ -1562,13 +1559,13 @@ def _ase_fire_step( # 
noqa: C901, PLR0915 ) if is_cell_optimization: - state.positions = torch.linalg.solve( - cur_deform_grad[state.batch], state.positions.unsqueeze(-1) - ).squeeze(-1) + dr_atom - else: - state.positions = state.positions + dr_atom + state.positions = ( + torch.linalg.solve( + cur_deform_grad[state.batch], state.positions.unsqueeze(-1) + ).squeeze(-1) + + dr_atom + ) - if is_cell_optimization: if is_frechet: assert isinstance(state, FrechetCellFIREState) new_logm_F_scaled = state.cell_positions + dr_cell @@ -1593,16 +1590,11 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) state.cell = new_cell_column_vectors - # rescale the positions after cell update - current_new_row_vector_cell = state.row_vector_cell - inv_old_cell_batch = torch.linalg.inv(old_row_vector_cell) - transform_matrix_batch = torch.bmm( - inv_old_cell_batch, current_new_row_vector_cell - ) - state.positions = torch.bmm( - state.positions.unsqueeze(1), F_new[state.batch].transpose(-2, -1) - ).squeeze(1) + state.positions.unsqueeze(1), F_new[state.batch].transpose(-2, -1) + ).squeeze(1) + else: + state.positions = state.positions + dr_atom # 7. Force / stress refresh & new cell forces results = model(state) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 2913d41e..72d39fcc 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -278,13 +278,16 @@ def _chunked_apply( return concatenate_states(ordered_states) -def generate_force_convergence_fn(force_tol: float = 1e-1, include_cell_forces: bool = False) -> Callable: +def generate_force_convergence_fn( + force_tol: float = 1e-1, *, include_cell_forces: bool = False +) -> Callable: """Generate a force-based convergence function for the convergence_fn argument of the optimize function. Args: force_tol (float): Force tolerance for convergence - include_cell_forces (bool): Whether to include the `cell_forces` in the convergence check. + include_cell_forces (bool): Whether to include the `cell_forces` in + the convergence check. 
Returns: Convergence function that takes a state and last energy and @@ -302,7 +305,7 @@ def convergence_fn( cell_forces_norm, _ = state.cell_forces.norm(dim=2).max(dim=1) cell_force_conv = cell_forces_norm < force_tol - + return force_conv & cell_force_conv return convergence_fn From 183e19ddc64b1a35651335af0802da24eb862cf5 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 28 May 2025 19:31:22 -0400 Subject: [PATCH 11/22] test: still differ significantly after step 1 for distorted structures --- tests/test_optimizers_vs_ase.py | 172 ++++++++++++++++++-------------- 1 file changed, 99 insertions(+), 73 deletions(-) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index 188101ed..b4fd464b 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -1,4 +1,3 @@ -import copy from typing import TYPE_CHECKING, Any import pytest @@ -17,6 +16,65 @@ from mace.calculators import MACECalculator +def _compare_ase_and_ts_states( + ts_current_system_state: ts.state.SimState, + filtered_ase_atoms_for_run: Any, + tolerances: dict[str, float], + current_test_id: str, +) -> None: + structure_matcher = StructureMatcher( + ltol=tolerances["lattice_tol"], + stol=tolerances["site_tol"], + angle_tol=tolerances["angle_tol"], + scale=False, + ) + + tkwargs = { + "device": ts_current_system_state.device, + "dtype": ts_current_system_state.dtype, + } + + final_custom_energy = ts_current_system_state.energy.item() + final_custom_forces_max = ( + torch.norm(ts_current_system_state.forces, dim=-1).max().item() + ) + + # Convert torch-sim state to pymatgen Structure + ts_structure = state_to_structures(ts_current_system_state)[0] + + # Convert ASE atoms to pymatgen Structure + final_ase_atoms = filtered_ase_atoms_for_run.atoms + final_ase_energy = final_ase_atoms.get_potential_energy() + ase_forces_raw = final_ase_atoms.get_forces() + final_ase_forces_max = torch.norm( + torch.tensor(ase_forces_raw, **tkwargs), dim=-1 + ).max() + ase_structure = state_to_structures(atoms_to_state(final_ase_atoms, **tkwargs))[0] + + # Compare energies + energy_diff = abs(final_custom_energy - final_ase_energy) + assert energy_diff < tolerances["energy"], ( + f"{current_test_id}: Final energies differ significantly: " + f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " + f"Diff={energy_diff:.2e}" + ) + + # Compare forces + force_max_diff = abs(final_custom_forces_max - final_ase_forces_max) + assert force_max_diff < tolerances["force_max"], ( + f"{current_test_id}: Max forces differ significantly: " + f"torch-sim={final_custom_forces_max:.4f}, ASE={final_ase_forces_max:.4f}, " + f"Diff={force_max_diff:.2e}" + ) + + # Compare structures using StructureMatcher + assert structure_matcher.fit(ts_structure, ase_structure), ( + f"{current_test_id}: Structures do not match according to StructureMatcher, " + f"{ts_structure=}" + f"{ase_structure=}" + ) + + def _run_and_compare_optimizers( initial_sim_state_fixture: ts.state.SimState, torchsim_mace_mpa: MaceModel, @@ -33,14 +91,7 @@ def _run_and_compare_optimizers( dtype = torch.float64 device = torchsim_mace_mpa.device - ts_current_system_state = copy.deepcopy(initial_sim_state_fixture).to( - dtype=dtype, device=device - ) - ts_current_system_state.positions = ( - ts_current_system_state.positions.detach().requires_grad_() - ) - ts_current_system_state.cell = ts_current_system_state.cell.detach().requires_grad_() - ts_optimizer_state = None + ts_current_system_state = initial_sim_state_fixture.clone() 
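# [Editor's note] SimState.clone() above replaces the earlier deepcopy plus
# requires_grad_ bookkeeping with an independent tensor copy; the mapping below
# then selects the torch-sim FIRE variant matching the ASE filter under test.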
     optimizer_builders = {
         "frechet": frechet_cell_fire,
@@ -55,7 +106,7 @@ def _run_and_compare_optimizers(
     )

     ase_atoms_for_run = state_to_atoms(
-        copy.deepcopy(initial_sim_state_fixture).to(dtype=dtype, device=device)
+        initial_sim_state_fixture.clone().to(dtype=dtype, device=device)
     )[0]
     ase_atoms_for_run.calc = ase_mace_mpa
     filtered_ase_atoms_for_run = ase_filter_class(ase_atoms_for_run)
@@ -64,81 +115,41 @@ def _run_and_compare_optimizers(
     last_checkpoint_step_count = 0
     convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol)

-    structure_matcher = StructureMatcher(
-        ltol=tolerances["lattice_tol"],
-        stol=tolerances["site_tol"],
-        angle_tol=tolerances["angle_tol"],
-        scale=False,
+    results = torchsim_mace_mpa(ts_current_system_state)
+    ts_initial_system_state = ts_current_system_state.clone()
+    ts_initial_system_state.forces = results["forces"]
+    ts_initial_system_state.energy = results["energy"]
+    ase_atoms_for_run.calc.calculate(ase_atoms_for_run)
+
+    _compare_ase_and_ts_states(
+        ts_initial_system_state,
+        filtered_ase_atoms_for_run,
+        tolerances,
+        f"{test_id_prefix} (Initial)",
     )

     for checkpoint_step in checkpoints:
         steps_for_current_segment = checkpoint_step - last_checkpoint_step_count

         if steps_for_current_segment > 0:
-            # Ensure requires_grad is set for the input to ts.optimize
-            # ts.optimize is expected to return a state suitable for further optimization
-            # if optimizer_state is passed.
-            ts_current_system_state.positions = (
-                ts_current_system_state.positions.detach().requires_grad_()
-            )
-            ts_current_system_state.cell = (
-                ts_current_system_state.cell.detach().requires_grad_()
-            )
-            new_ts_state_and_optimizer_state = ts.optimize(
+            updated_ts_state = ts.optimize(
                 system=ts_current_system_state,
                 model=torchsim_mace_mpa,
                 optimizer=optimizer_callable_for_ts_optimize,
                 max_steps=steps_for_current_segment,
                 convergence_fn=convergence_fn,
-                optimizer_state=ts_optimizer_state,
             )
-            ts_current_system_state = new_ts_state_and_optimizer_state
-            ts_optimizer_state = new_ts_state_and_optimizer_state
+            ts_current_system_state = updated_ts_state.clone()

             ase_optimizer.run(fmax=force_tol, steps=steps_for_current_segment)

         current_test_id = f"{test_id_prefix} (Step {checkpoint_step})"

-        final_custom_energy = ts_current_system_state.energy.item()
-        final_custom_forces_max = (
-            torch.norm(ts_current_system_state.forces, dim=-1).max().item()
-        )
-
-        # Convert torch-sim state to pymatgen Structure
-        ts_structure = state_to_structures(ts_current_system_state)[0]
-
-        # Convert ASE atoms to pymatgen Structure
-        final_ase_atoms = filtered_ase_atoms_for_run.atoms
-        final_ase_energy = final_ase_atoms.get_potential_energy()
-        ase_forces_raw = final_ase_atoms.get_forces()
-        final_ase_forces_max = torch.norm(
-            torch.tensor(ase_forces_raw, device=device, dtype=dtype), dim=-1
-        ).max()
-        ase_structure = state_to_structures(
-            atoms_to_state(final_ase_atoms, device, dtype)
-        )[0]
-
-        # Compare energies
-        energy_diff = abs(final_custom_energy - final_ase_energy)
-        assert energy_diff < tolerances["energy"], (
-            f"{current_test_id}: Final energies differ significantly: "
-            f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, "
-            f"Diff={energy_diff:.2e}"
-        )
-
-        # Compare forces
-        force_max_diff = abs(final_custom_forces_max - final_ase_forces_max)
-        assert force_max_diff < tolerances["force_max"], (
-            f"{current_test_id}: Max forces differ significantly: "
-            f"torch-sim={final_custom_forces_max:.4f}, ASE={final_ase_forces_max:.4f}, "
-            f"Diff={force_max_diff:.2e}"
-        )
-
-        # Compare structures using StructureMatcher
-        assert structure_matcher.fit(ts_structure, ase_structure), (
-            f"{current_test_id}: Structures do not match according to StructureMatcher, "
-            f"{ts_structure=}"
-            f"{ase_structure=}"
+        _compare_ase_and_ts_states(
+            ts_current_system_state,
+            filtered_ase_atoms_for_run,
+            tolerances,
+            current_test_id,
         )

         last_checkpoint_step_count = checkpoint_step
@@ -159,7 +170,7 @@ def _run_and_compare_optimizers(
         "rattled_sio2_sim_state",
         "frechet",
         FrechetCellFilter,
-        [33, 66, 100],
+        [1, 33, 66, 100],
         0.02,
         {
             "energy": 1e-2,
@@ -174,7 +185,7 @@
         "osn2_sim_state",
         "frechet",
         FrechetCellFilter,
-        [16, 33, 50],
+        [1, 16, 33, 50],
         0.02,
         {
             "energy": 1e-2,
@@ -189,7 +200,7 @@
         "distorted_fcc_al_conventional_sim_state",
         "frechet",
         FrechetCellFilter,
-        [33, 66, 100],
+        [1, 33, 66, 100],
         0.01,
         {
             "energy": 1e-2,
@@ -204,7 +215,7 @@
         "distorted_fcc_al_conventional_sim_state",
         "unit_cell",
         UnitCellFilter,
-        [33, 66, 100],
+        [1, 33, 66, 100],
         0.01,
         {
             "energy": 1e-2,
@@ -219,7 +230,7 @@
         "rattled_sio2_sim_state",
         "unit_cell",
         UnitCellFilter,
-        [33, 66, 100],
+        [1, 33, 66, 100],
         0.02,
         {
             "energy": 1e-2,
@@ -230,6 +241,21 @@
         },
         "SiO2 (UnitCell)",
     ),
+    (
+        "osn2_sim_state",
+        "unit_cell",
+        UnitCellFilter,
+        [1, 16, 33, 50],
+        0.02,
+        {
+            "energy": 1e-2,
+            "force_max": 5e-2,
+            "lattice_tol": 3e-2,
+            "site_tol": 3e-2,
+            "angle_tol": 1e-1,
+        },
+        "OsN2 (UnitCell)",
+    ),
 ],
 )
 def test_optimizer_vs_ase_parametrized(

From 25cf6e2f779a4da0f2d44aed9bb619b41c2d395f Mon Sep 17 00:00:00 2001
From: Timo Reents
Date: Mon, 2 Jun 2025 12:16:46 +0200
Subject: [PATCH 12/22] Fix test comparing ASE and torch-sim optimization

* Include the `cell_forces` in the convergence check
* Fix the number of iterations that are performed. `steps_between_swaps` is
  set to 1, so the number of iterations is equal to the number of swaps. In
  the previous version, fewer iterations would have been performed when
  reaching the maximum number of swaps. For example, when trying to run 32
  steps with `steps_between_swaps=5`, the optimization would have stopped
  after 30 iterations, i.e., 6 swaps.
* Fix `autobatching.py`. The if statement would have been triggered for
  `max_attempts=0`, which was the case when running one iteration and
  `steps_between_swaps=5`.
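
[Note: a small worked example of the step accounting described in the commit message; the variable names here are illustrative, not from the codebase:]

    max_steps, steps_between_swaps = 32, 5

    # Old behaviour: only whole swap cycles ran, so the trailing partial
    # cycle was dropped: (32 // 5) * 5 = 30 iterations, i.e. 6 swaps.
    executed_steps = (max_steps // steps_between_swaps) * steps_between_swaps
    assert executed_steps == 30

    # With steps_between_swaps=1, iterations == swaps, so all 32 steps run.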
---
 tests/test_optimizers_vs_ase.py |  5 ++++-
 torch_sim/autobatching.py       |  4 +++-
 torch_sim/runners.py            | 27 ++++++++++++++++++++-------
 3 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py
index b4fd464b..19dca7d8 100644
--- a/tests/test_optimizers_vs_ase.py
+++ b/tests/test_optimizers_vs_ase.py
@@ -113,7 +113,9 @@ def _run_and_compare_optimizers(
     ase_optimizer = FIRE(filtered_ase_atoms_for_run, logfile=None)

     last_checkpoint_step_count = 0
-    convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol)
+    convergence_fn = ts.generate_force_convergence_fn(
+        force_tol=force_tol, include_cell_forces=True
+    )

     results = torchsim_mace_mpa(ts_current_system_state)
     ts_initial_system_state = ts_current_system_state.clone()
@@ -138,6 +140,7 @@ def _run_and_compare_optimizers(
                 optimizer=optimizer_callable_for_ts_optimize,
                 max_steps=steps_for_current_segment,
                 convergence_fn=convergence_fn,
+                steps_between_swaps=1,
             )

             ts_current_system_state = updated_ts_state.clone()
diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index 3b87d489..e7781d26 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -1030,7 +1030,9 @@ def next_batch(
         # Increment attempt counters and check for max attempts in a single loop
         for cur_idx, abs_idx in enumerate(self.current_idx):
             self.swap_attempts[abs_idx] += 1
-            if self.max_attempts and (self.swap_attempts[abs_idx] >= self.max_attempts):
+            if self.max_attempts is not None and (
+                self.swap_attempts[abs_idx] >= self.max_attempts
+            ):
                 # Force convergence for states that have reached max attempts
                 convergence_tensor[cur_idx] = torch.tensor(True)  # noqa: FBT003
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 72d39fcc..b550f865 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -16,6 +16,12 @@
 from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher
 from torch_sim.models.interface import ModelInterface
+from torch_sim.optimizers import (
+    FireState,
+    FrechetCellFIREState,
+    UnitCellFireState,
+    UnitCellGDState,
+)
 from torch_sim.quantities import batchwise_max_force, calc_kinetic_energy, calc_kT
 from torch_sim.state import SimState, concatenate_states, initialize_state
 from torch_sim.trajectory import TrajectoryReporter
@@ -389,13 +395,20 @@ def optimize(  # noqa: C901
     autobatcher = _configure_in_flight_autobatcher(
         model, state, autobatcher, max_attempts
     )
-    state = _chunked_apply(
-        init_fn,
-        state,
-        model,
-        max_memory_scaler=autobatcher.max_memory_scaler,
-        memory_scales_with=autobatcher.memory_scales_with,
-    )
+
+    if type(state) not in [
+        FireState,
+        UnitCellFireState,
+        UnitCellGDState,
+        FrechetCellFIREState,
+    ]:
+        state = _chunked_apply(
+            init_fn,
+            state,
+            model,
+            max_memory_scaler=autobatcher.max_memory_scaler,
+            memory_scales_with=autobatcher.memory_scales_with,
+        )
     autobatcher.load_states(state)
     trajectory_reporter = _configure_reporter(
         trajectory_reporter,
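
[Note: the `type(state) not in [...]` guard above is an exact-class test, so a subclass of one of the listed FIRE states would still be re-initialized; patch 17 below relaxes this to `isinstance`. A sketch of the difference, with stand-in classes:]

    class FireState: ...
    class CustomFireState(FireState): ...

    state = CustomFireState()
    print(type(state) in [FireState])    # False: exact-class check misses subclasses
    print(isinstance(state, FireState))  # True: isinstance also accepts subclasses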
From c0afdd29a5b30de9bd50722effea5ffc2713a73c Mon Sep 17 00:00:00 2001
From: Timo Reents
Date: Mon, 2 Jun 2025 12:25:57 +0200
Subject: [PATCH 13/22] Fix `optimizers` when using `UnitCellFilter`

* Fix the `None` initialization
* Fix the cell update when using `UnitCellFilter`

---
 torch_sim/optimizers.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index 8d1f025e..6a962b92 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -866,13 +866,13 @@ def fire_init(
         batch=state.batch.clone(),
         pbc=state.pbc,
         # New attributes
-        velocities=torch.zeros_like(state.positions),
+        velocities=None,
         forces=forces,
         energy=energy,
         stress=stress,
         # Cell attributes
         cell_positions=torch.zeros(n_batches, 3, 3, device=device, dtype=dtype),
-        cell_velocities=torch.zeros(n_batches, 3, 3, device=device, dtype=dtype),
+        cell_velocities=None,
         cell_forces=cell_forces,
         cell_masses=cell_masses,
         # Optimization attributes
@@ -1585,10 +1585,10 @@ def _ase_fire_step(  # noqa: C901, PLR0915
             F_new_scaled = current_F_scaled + dr_cell
             state.cell_positions = F_new_scaled
             F_new = F_new_scaled / (cell_factor_exp_mult + eps)
-        new_cell_column_vectors = torch.bmm(
-            state.reference_cell, F_new.transpose(-2, -1)
+        new_row_vector_cell = torch.bmm(
+            state.reference_row_vector_cell, F_new.transpose(-2, -1)
         )
-        state.cell = new_cell_column_vectors
+        state.row_vector_cell = new_row_vector_cell

         state.positions = torch.bmm(
             state.positions.unsqueeze(1), F_new[state.batch].transpose(-2, -1)

From e49c1786a7a0c5c42380c84f0c2bced3cc30b426 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 10:59:14 -0400
Subject: [PATCH 14/22] fix test_optimize_fire

---
 .github/workflows/link-check.yml | 2 --
 tests/test_runners.py            | 5 +++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/link-check.yml b/.github/workflows/link-check.yml
index 2b57b818..25a0aa68 100644
--- a/.github/workflows/link-check.yml
+++ b/.github/workflows/link-check.yml
@@ -18,5 +18,3 @@ jobs:
         with:
           # ignore ipynb links since they're generated on the fly
           args: --exclude-path dist --exclude '\.ipynb$' --accept 100..=103,200..=299,403,429,500 -- ./**/*.{md,py,yml,json}
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/tests/test_runners.py b/tests/test_runners.py
index f90cfee7..1b7c6260 100644
--- a/tests/test_runners.py
+++ b/tests/test_runners.py
@@ -297,7 +297,7 @@ def test_optimize_fire(
     # Check force convergence
     assert torch.all(final_state.forces < 3e-1)

-    assert energies.shape[0] > 10
+    assert energies.shape[0] >= 10
     assert energies[0] > energies[-1]

     assert not torch.allclose(original_state.positions, final_state.positions)
@@ -327,7 +327,8 @@ def test_default_converged_fn(
     with TorchSimTrajectory(traj_file) as traj:
         energies = traj.get_array("energy")

-    assert energies[-3] > energies[-1]
+    # Check that overall energy decreases (first to last)
+    assert energies[0] > energies[-1]

     assert not torch.allclose(original_state.positions, final_state.positions)

From 66871d56be32ee6f8b3aa1f2e08761d130ea8c78 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 11:03:06 -0400
Subject: [PATCH 15/22] allow FireState.velocities = None since it's being set
 to None in multiple places

---
 torch_sim/optimizers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index 6a962b92..38f60206 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -476,7 +476,7 @@ class FireState(SimState):
     # Required attributes not in SimState
     forces: torch.Tensor
     energy: torch.Tensor
-    velocities: torch.Tensor
+    velocities: torch.Tensor | None = None

     # FIRE algorithm parameters
     dt: torch.Tensor
From 85380ad54be96b5f7d88b6f18af168db1208c2e1 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 11:04:35 -0400
Subject: [PATCH 16/22] safer `batched_vdot`: check dimensionality of input
 tensors `y` and `batch_indices` - fix stale docstring mentioning is_sum_sq
 kwarg

---
 torch_sim/math.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/torch_sim/math.py b/torch_sim/math.py
index ff78757d..07037e0e 100644
--- a/torch_sim/math.py
+++ b/torch_sim/math.py
@@ -993,7 +993,6 @@ def batched_vdot(
     x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor
 ) -> torch.Tensor:
     """Computes batched vdot (sum of element-wise product) for groups of vectors.
-    If is_sum_sq is True, computes sum of x_i * x_i (squared norm components).

     Args:
         x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities).
@@ -1002,12 +1001,21 @@ def batched_vdot(

     Returns:
         Tensor: shape [n_batches] where each element is the sum(x_i * y_i)
-        (or sum(x_i * x_i) if is_sum_sq) for entities belonging to that batch,
+        for entities belonging to that batch,
         summed over all components D and all entities in the batch.
     """
-    if x.ndim != 2 or batch_indices.ndim != 1 or x.shape[0] != batch_indices.shape[0]:
+    if (
+        x.ndim != 2
+        or y.ndim != 2
+        or batch_indices.ndim != 1
+        or x.shape != y.shape
+        or x.shape[0] != batch_indices.shape[0]
+    ):
         raise ValueError(f"Invalid input shapes: {x.shape=}, {batch_indices.shape=}")

+    if batch_indices.min() < 0:
+        raise ValueError("batch_indices must be non-negative")
+
     output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)
     output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
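
[Note: a quick sanity check of what `batched_vdot` computes and what the widened validation rejects; the tensors below are made up for illustration:]

    import torch
    from torch_sim.math import batched_vdot

    x = torch.tensor([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])
    y = torch.tensor([[2.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
    batch_indices = torch.tensor([0, 0, 1])  # rows 0-1 -> batch 0, row 2 -> batch 1

    # batch 0: (1*2 + 0*0) + (0*0 + 2*1) = 4; batch 1: 3*1 + 0*0 = 3
    out = batched_vdot(x, y, batch_indices)
    assert torch.allclose(out, torch.tensor([4.0, 3.0]))

    # The new shape checks now also reject, e.g., a 1-D y:
    # batched_vdot(x, y[:, 0], batch_indices)  # raises ValueError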
From 2e54afe08d6f2f17ae8d2da6f13bf92aa877c003 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 11:08:06 -0400
Subject: [PATCH 17/22] generate_force_convergence_fn raise informative error
 on needed but missing cell_forces

---
 torch_sim/runners.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index b550f865..90c9aad2 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -306,13 +306,15 @@ def convergence_fn(
     ) -> bool:
         """Check if the system has converged."""
         force_conv = batchwise_max_force(state) < force_tol
-        if not include_cell_forces:
-            return force_conv

-        cell_forces_norm, _ = state.cell_forces.norm(dim=2).max(dim=1)
-        cell_force_conv = cell_forces_norm < force_tol
+        if include_cell_forces:
+            if (cell_forces := getattr(state, "cell_forces", None)) is None:
+                raise ValueError("cell_forces not found in state")
+            cell_forces_norm, _ = cell_forces.norm(dim=2).max(dim=1)
+            cell_force_conv = cell_forces_norm < force_tol
+            return force_conv & cell_force_conv

-        return force_conv & cell_force_conv
+        return force_conv

     return convergence_fn
@@ -396,12 +398,9 @@ def optimize(  # noqa: C901
         model, state, autobatcher, max_attempts
     )

-    if type(state) not in [
-        FireState,
-        UnitCellFireState,
-        UnitCellGDState,
-        FrechetCellFIREState,
-    ]:
+    if not isinstance(
+        state, (FireState, UnitCellFireState, UnitCellGDState, FrechetCellFIREState)
+    ):
         state = _chunked_apply(
             init_fn,
             state,

From 4484465443444636d3019b9a27c3ce1977a3a8d1 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 11:11:58 -0400
Subject: [PATCH 18/22] pascal case VALID_FIRE_CELL_STATES->AnyFireCellState
 and fix non-f-string error messages

---
 torch_sim/optimizers.py | 18 +++++++++---------
 torch_sim/runners.py    | 14 +++++++-------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index 38f60206..14d625db 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -1204,11 +1204,11 @@ def fire_init(
     return fire_init, functools.partial(step_func, **step_func_kwargs)


-VALID_FIRE_CELL_STATES = UnitCellFireState | FrechetCellFIREState
+AnyFireCellState = UnitCellFireState | FrechetCellFIREState


 def _vv_fire_step(  # noqa: C901, PLR0915
-    state: FireState | VALID_FIRE_CELL_STATES,
+    state: FireState | AnyFireCellState,
     model: torch.nn.Module,
     *,
     dt_max: torch.Tensor,
@@ -1220,7 +1220,7 @@ def _vv_fire_step(  # noqa: C901, PLR0915
     eps: float,
     is_cell_optimization: bool = False,
     is_frechet: bool = False,
-) -> FireState | VALID_FIRE_CELL_STATES:
+) -> FireState | AnyFireCellState:
     """Perform one Velocity-Verlet based FIRE optimization step.

     Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for
@@ -1252,9 +1252,9 @@ def _vv_fire_step(  # noqa: C901, PLR0915
     if state.velocities is None:
         state.velocities = torch.zeros_like(state.positions)
     if is_cell_optimization:
-        if not isinstance(state, VALID_FIRE_CELL_STATES):
+        if not isinstance(state, AnyFireCellState):
             raise ValueError(
-                "Cell optimization requires one of {VALID_FIRE_CELL_STATES}."
+                f"Cell optimization requires one of {get_args(AnyFireCellState)}."
             )
         state.cell_velocities = torch.zeros(
             (n_batches, 3, 3), device=device, dtype=dtype
@@ -1418,7 +1418,7 @@ def _vv_fire_step(  # noqa: C901, PLR0915


 def _ase_fire_step(  # noqa: C901, PLR0915
-    state: FireState | VALID_FIRE_CELL_STATES,
+    state: FireState | AnyFireCellState,
     model: torch.nn.Module,
     *,
     dt_max: torch.Tensor,
@@ -1431,7 +1431,7 @@ def _ase_fire_step(  # noqa: C901, PLR0915
     eps: float,
     is_cell_optimization: bool = False,
     is_frechet: bool = False,
-) -> FireState | VALID_FIRE_CELL_STATES:
+) -> FireState | AnyFireCellState:
     """Perform one ASE-style FIRE optimization step.

     Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm
@@ -1462,9 +1462,9 @@ def _ase_fire_step(  # noqa: C901, PLR0915
         state.velocities = torch.zeros_like(state.positions)
     forces = state.forces
     if is_cell_optimization:
-        if not isinstance(state, VALID_FIRE_CELL_STATES):
+        if not isinstance(state, AnyFireCellState):
             raise ValueError(
-                "Cell optimization requires one of {VALID_FIRE_CELL_STATES}."
+                f"Cell optimization requires one of {get_args(AnyFireCellState)}."
             )
         state.cell_velocities = torch.zeros(
             (n_batches, 3, 3), device=device, dtype=dtype
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 90c9aad2..5925872b 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -537,8 +537,8 @@ class StaticState(type(state)):
     pbar_kwargs.setdefault("disable", None)
     tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs)

-    for substate, batch_indices in batch_iterator:
-        print(substate.atomic_numbers)
+    for sub_state, batch_indices in batch_iterator:
+        print(sub_state.atomic_numbers)
         # set up trajectory reporters
         if autobatcher and trajectory_reporter and og_filenames is not None:
             # we must remake the trajectory reporter for each batch
@@ -546,20 +546,20 @@ class StaticState(type(state)):
                 filenames=[og_filenames[idx] for idx in batch_indices]
             )

-        model_outputs = model(substate)
+        model_outputs = model(sub_state)

-        substate = StaticState(
-            **vars(substate),
+        sub_state = StaticState(
+            **vars(sub_state),
             energy=model_outputs["energy"],
             forces=model_outputs["forces"] if model.compute_forces else None,
             stress=model_outputs["stress"] if model.compute_stress else None,
         )

-        props = trajectory_reporter.report(substate, 0, model=model)
+        props = trajectory_reporter.report(sub_state, 0, model=model)
         all_props.extend(props)

         if tqdm_pbar:
-            tqdm_pbar.update(substate.n_batches)
+            tqdm_pbar.update(sub_state.n_batches)

     trajectory_reporter.finish()
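
[Note: the f-string fix above leans on `typing.get_args` to unpack the union alias into its member classes; the hunk assumes `get_args` is already imported in optimizers.py, which it does not show. A sketch with stand-in classes:]

    from typing import get_args

    class UnitCellFireState: ...
    class FrechetCellFIREState: ...

    AnyFireCellState = UnitCellFireState | FrechetCellFIREState

    # get_args recovers the union's members as a tuple of classes, so the
    # ValueError message lists the concrete acceptable state types.
    print(get_args(AnyFireCellState))
    # e.g. (<class 'UnitCellFireState'>, <class 'FrechetCellFIREState'>)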
From ff03101e07bc6ac1caff05b91e742d0b313be45e Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 11:17:00 -0400
Subject: [PATCH 19/22] fix FireState TypeError: non-default argument 'dt'
 follows default argument

---
 torch_sim/optimizers.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index 14d625db..7274d44a 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -476,13 +476,14 @@ class FireState(SimState):
     # Required attributes not in SimState
     forces: torch.Tensor
     energy: torch.Tensor
-    velocities: torch.Tensor | None = None

     # FIRE algorithm parameters
     dt: torch.Tensor
     alpha: torch.Tensor
     n_pos: torch.Tensor

+    velocities: torch.Tensor | None = None
+

 def fire(
     model: torch.nn.Module,
@@ -591,8 +592,6 @@ def fire_init(
         atomic_numbers=state.atomic_numbers.clone(),
         batch=state.batch.clone(),
         pbc=state.pbc,
-        # New attributes
-        velocities=None,
         forces=forces,
         energy=energy,
         # Optimization attributes
@@ -865,8 +864,6 @@ def fire_init(
         atomic_numbers=state.atomic_numbers.clone(),
         batch=state.batch.clone(),
         pbc=state.pbc,
-        # New attributes
-        velocities=None,
         forces=forces,
         energy=energy,
         stress=stress,
@@ -1165,8 +1162,6 @@ def fire_init(
         atomic_numbers=state.atomic_numbers,
         batch=state.batch,
         pbc=state.pbc,
-        # New attributes
-        velocities=None,
         forces=forces,
         energy=energy,
         stress=stress,
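
[Note: the TypeError in the subject line is standard dataclass behaviour, independent of torch-sim; a minimal reproduction of the failure and of the field reordering applied above:]

    from dataclasses import dataclass

    try:

        @dataclass
        class Broken:
            velocities: int | None = None
            dt: int  # a required field after a defaulted one

    except TypeError as exc:
        print(exc)  # non-default argument 'dt' follows default argument

    # Moving the defaulted field after all required fields resolves it:
    @dataclass
    class Fixed:
        dt: int
        velocities: int | None = None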
From 345828819960ae9260b62c1d02f7eb2efafbf335 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 11:35:01 -0400
Subject: [PATCH 20/22] allow None but don't set default for state.velocities

---
 torch_sim/optimizers.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index 7274d44a..43262825 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -338,6 +338,7 @@ def gd_init(
         forces=forces,
         energy=energy,
         stress=stress,
+        velocities=None,
         masses=state.masses,
         cell=state.cell,
         pbc=state.pbc,
@@ -476,6 +477,7 @@ class FireState(SimState):
     # Required attributes not in SimState
     forces: torch.Tensor
     energy: torch.Tensor
+    velocities: torch.Tensor | None

     # FIRE algorithm parameters
     dt: torch.Tensor
@@ -592,6 +594,7 @@ def fire_init(
         atomic_numbers=state.atomic_numbers.clone(),
         batch=state.batch.clone(),
         pbc=state.pbc,
+        velocities=None,
         forces=forces,
         energy=energy,
         # Optimization attributes
@@ -864,6 +867,7 @@ def fire_init(
         atomic_numbers=state.atomic_numbers.clone(),
         batch=state.batch.clone(),
         pbc=state.pbc,
+        velocities=None,
         forces=forces,
         energy=energy,
         stress=stress,
@@ -966,7 +970,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin):

     # Cell attributes
     cell_positions: torch.Tensor
-    cell_velocities: torch.Tensor
+    cell_velocities: torch.Tensor | None
     cell_forces: torch.Tensor
     cell_masses: torch.Tensor

@@ -1162,6 +1166,7 @@ def fire_init(
         atomic_numbers=state.atomic_numbers,
         batch=state.batch,
         pbc=state.pbc,
+        velocities=None,
         forces=forces,
         energy=energy,
         stress=stress,

From d441bf391cf1ab6bc13d78b5a8be9154fcdfcbad Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 11:44:35 -0400
Subject: [PATCH 21/22] fix bad merge conflict resolution

---
 torch_sim/optimizers.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index 43262825..186df609 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -338,7 +338,6 @@ def gd_init(
         forces=forces,
         energy=energy,
         stress=stress,
-        velocities=None,
         masses=state.masses,
         cell=state.cell,
         pbc=state.pbc,
@@ -484,8 +483,6 @@ class FireState(SimState):
     alpha: torch.Tensor
     n_pos: torch.Tensor

-    velocities: torch.Tensor | None = None
-

 def fire(
     model: torch.nn.Module,

From 8deb1c82ec835f3363b4927c2deaa0a664fb28ca Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Tue, 3 Jun 2025 12:07:43 -0400
Subject: [PATCH 22/22] tweaks

---
 tests/test_optimizers_vs_ase.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py
index 19dca7d8..acf5e0b9 100644
--- a/tests/test_optimizers_vs_ase.py
+++ b/tests/test_optimizers_vs_ase.py
@@ -29,7 +29,7 @@ def _compare_ase_and_ts_states(
         scale=False,
     )

-    tkwargs = {
+    tensor_kwargs = {
         "device": ts_current_system_state.device,
         "dtype": ts_current_system_state.dtype,
     }
@@ -47,9 +47,10 @@ def _compare_ase_and_ts_states(
     final_ase_energy = final_ase_atoms.get_potential_energy()
     ase_forces_raw = final_ase_atoms.get_forces()
     final_ase_forces_max = torch.norm(
-        torch.tensor(ase_forces_raw, **tkwargs), dim=-1
+        torch.tensor(ase_forces_raw, **tensor_kwargs), dim=-1
     ).max()
-    ase_structure = state_to_structures(atoms_to_state(final_ase_atoms, **tkwargs))[0]
+    ts_state = atoms_to_state(final_ase_atoms, **tensor_kwargs)
+    ase_structure = state_to_structures(ts_state)[0]

     # Compare energies
@@ -69,9 +70,8 @@ def _compare_ase_and_ts_states(

     # Compare structures using StructureMatcher
     assert structure_matcher.fit(ts_structure, ase_structure), (
-        f"{current_test_id}: Structures do not match according to StructureMatcher, "
-        f"{ts_structure=}"
-        f"{ase_structure=}"
+        f"{current_test_id}: Structures do not match according to StructureMatcher\n"
+        f"{ts_structure=}\n{ase_structure=}"
     )
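
[Note: patches 20 and 21 settle on annotating `velocities: torch.Tensor | None` with no default. A field allowed to be None but given no default remains a required field, so it can sit among the other required attributes without re-triggering the ordering TypeError fixed in patch 19. A sketch:]

    from dataclasses import dataclass

    @dataclass
    class State:
        forces: float
        velocities: float | None  # no default: required, but None is a legal value
        dt: float

    state = State(forces=1.0, velocities=None, dt=0.1)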