Skip to content

Commit a2ced4f

Browse files
committed
add batch dim to a2c functions
1 parent e2330c0 commit a2ced4f

File tree

4 files changed

+118
-26
lines changed

4 files changed

+118
-26
lines changed

examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
compute_forces=True,
5151
compute_stress=False,
5252
dtype=dtype,
53-
enable_cueq=True,
53+
enable_cueq=torch.cuda.is_available(),
5454
)
5555
state = ts.io.atoms_to_state(si_dc, device=device, dtype=dtype)
5656

torch_sim/models/mace.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,24 @@ def forward( # noqa: C901
304304
# TODO (AG): Currently doesn't work for batched neighbor lists
305305
for b in range(self.n_systems):
306306
batch_mask = state.batch == b
307-
# Calculate neighbor list for this system
307+
# Extract cell for this batch and transpose to row vector format
308+
cell_for_nl = state.row_vector_cell[b]
309+
310+
# Ensure cell has correct shape [3, 3] for neighbor list
311+
if cell_for_nl.ndim == 3:
312+
cell_for_nl = cell_for_nl.squeeze(0)
313+
314+
# Safety checks to prevent tensor confusion
315+
if state.positions[batch_mask].shape[1] != 3:
316+
raise ValueError(
317+
f"positions should have shape [n_atoms, 3], got "
318+
f"{state.positions[batch_mask].shape}"
319+
)
320+
if cell_for_nl.shape != (3, 3):
321+
raise ValueError(
322+
f"cell should have shape [3, 3], got {cell_for_nl.shape}"
323+
)
324+
308325
edge_idx, shifts_idx = self.neighbor_list_fn(
309326
positions=state.positions[batch_mask],
310327
cell=state.row_vector_cell[b],

torch_sim/neighbors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,12 @@ def vesin_nl_ts(
537537
References:
538538
https://github.com/Luthaf/vesin
539539
"""
540+
# Defensive checks to catch parameter confusion
541+
if positions.ndim != 2 or positions.shape[1] != 3:
542+
raise ValueError(f"positions must have shape [n_atoms, 3], got {positions.shape}")
543+
if cell.ndim != 2 or cell.shape[0] != 3 or cell.shape[1] != 3:
544+
raise ValueError(f"cell must have shape [3, 3], got {cell.shape}")
545+
540546
device = positions.device
541547
dtype = positions.dtype
542548

torch_sim/workflows/a2c.py

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,17 @@ def random_packed_structure(
280280
diameter = get_diameter(composition)
281281
print(f"Using random pack diameter of {diameter}")
282282

283+
# Ensure cell has batch dimension [1, 3, 3] if it doesn't already
284+
if cell.ndim == 2:
285+
cell = cell.unsqueeze(0) # Add batch dimension
286+
283287
# Perform overlap minimization if diameter is specified
284288
if diameter is not None:
285289
print("Reduce atom overlap using the soft_sphere potential")
286290
# Convert fractional to cartesian coordinates
287-
positions_cart = torch.matmul(positions, cell)
291+
positions_cart = torch.matmul(
292+
positions, cell.squeeze(0)
293+
) # Use first (and only) batch
288294

289295
# Initialize soft sphere potential calculator
290296
model = SoftSphereModel(
@@ -296,23 +302,30 @@ def random_packed_structure(
296302
)
297303

298304
# Dummy atomic numbers
299-
atomic_numbers = torch.ones_like(positions_cart, device=device, dtype=torch.int)
305+
atomic_numbers = torch.ones(N_atoms, device=device, dtype=torch.int)
306+
307+
# Create batch tensor for single system
308+
batch = torch.zeros(N_atoms, device=device, dtype=torch.long)
300309

301310
# Set up FIRE optimizer with unit masses
302311
state = ts.SimState(
303312
positions=positions_cart,
304313
masses=torch.ones(N_atoms, device=device, dtype=dtype),
305314
atomic_numbers=atomic_numbers,
306-
cell=cell,
315+
cell=cell, # Keep batch dimension
307316
pbc=True,
317+
batch=batch,
308318
)
309319
fire_init, fire_update = fire(model=model)
310320
state = fire_init(state)
311321
print(f"Initial energy: {state.energy.item():.4f}")
312322
# Run FIRE optimization until convergence or max iterations
313323
for _step in range(max_iter):
314324
# Check if minimum distance criterion is met (95% of target diameter)
315-
if min_distance(state.positions, cell, distance_tolerance) > diameter * 0.95:
325+
if (
326+
min_distance(state.positions, cell.squeeze(0), distance_tolerance)
327+
> diameter * 0.95
328+
):
316329
break
317330

318331
if log is not None:
@@ -321,6 +334,20 @@ def random_packed_structure(
321334
state = fire_update(state)
322335

323336
print(f"Final energy: {state.energy.item():.4f}")
337+
else:
338+
# If no optimization, still create a proper state with batch dimensions
339+
positions_cart = torch.matmul(positions, cell.squeeze(0))
340+
atomic_numbers = torch.ones(N_atoms, device=device, dtype=torch.int)
341+
batch = torch.zeros(N_atoms, device=device, dtype=torch.long)
342+
343+
state = ts.SimState(
344+
positions=positions_cart,
345+
masses=torch.ones(N_atoms, device=device, dtype=dtype),
346+
atomic_numbers=atomic_numbers,
347+
cell=cell, # Keep batch dimension
348+
pbc=True,
349+
batch=batch,
350+
)
324351

325352
if log is not None:
326353
return state, log
@@ -408,11 +435,17 @@ def random_packed_structure_multi(
408435
diameter_matrix = get_diameter_matrix(composition, device=device, dtype=dtype)
409436
print(f"Using random pack diameter matrix:\n{diameter_matrix.cpu().numpy()}")
410437

438+
# Ensure cell has batch dimension [1, 3, 3] if it doesn't already
439+
if cell.ndim == 2:
440+
cell = cell.unsqueeze(0) # Add batch dimension
441+
411442
# Perform overlap minimization if diameter matrix is specified
412443
if diameter_matrix is not None:
413444
print("Reduce atom overlap using the soft_sphere potential")
414445
# Convert fractional to cartesian coordinates
415-
positions_cart = torch.matmul(positions, cell)
446+
positions_cart = torch.matmul(
447+
positions, cell.squeeze(0)
448+
) # Use first (and only) batch
416449

417450
# Initialize multi-species soft sphere potential calculator
418451
model = SoftSphereMultiModel(
@@ -425,14 +458,18 @@ def random_packed_structure_multi(
425458
)
426459

427460
# Dummy atomic numbers
428-
atomic_numbers = torch.ones_like(positions_cart, device=device, dtype=torch.int)
461+
atomic_numbers = torch.ones(N_atoms, device=device, dtype=torch.int)
462+
463+
# Create batch tensor for single system
464+
batch = torch.zeros(N_atoms, device=device, dtype=torch.long)
429465

430466
state_dict = ts.SimState(
431467
positions=positions_cart,
432468
masses=torch.ones(N_atoms, device=device, dtype=dtype),
433469
atomic_numbers=atomic_numbers,
434-
cell=cell,
470+
cell=cell, # Keep batch dimension
435471
pbc=True,
472+
batch=batch,
436473
)
437474
# Set up FIRE optimizer with unit masses for all atoms
438475
fire_init, fire_update = fire(model=model)
@@ -441,11 +478,25 @@ def random_packed_structure_multi(
441478
# Run FIRE optimization until convergence or max iterations
442479
for _step in range(max_iter):
443480
# Check if minimum distance criterion is met (95% of smallest target diameter)
444-
min_dist = min_distance(state.positions, cell, distance_tolerance)
481+
min_dist = min_distance(state.positions, cell.squeeze(0), distance_tolerance)
445482
if min_dist > diameter_matrix.min() * 0.95:
446483
break
447484
state = fire_update(state)
448485
print(f"Final energy: {state.energy.item():.4f}")
486+
else:
487+
# If no optimization, still create a proper state with batch dimensions
488+
positions_cart = torch.matmul(positions, cell.squeeze(0))
489+
atomic_numbers = torch.ones(N_atoms, device=device, dtype=torch.int)
490+
batch = torch.zeros(N_atoms, device=device, dtype=torch.long)
491+
492+
state = ts.SimState(
493+
positions=positions_cart,
494+
masses=torch.ones(N_atoms, device=device, dtype=dtype),
495+
atomic_numbers=atomic_numbers,
496+
cell=cell, # Keep batch dimension
497+
pbc=True,
498+
batch=batch,
499+
)
449500

450501
return state
451502

@@ -472,8 +523,8 @@ def valid_subcell(
472523
Args:
473524
positions: Atomic positions tensor of shape [n_atoms, 3], where each row contains
474525
the (x,y,z) coordinates of an atom.
475-
cell: Unit cell tensor of shape [3, 3] containing the three lattice vectors that
476-
define the periodic boundary conditions.
526+
cell: Unit cell tensor of shape [3, 3] or [1, 3, 3] containing the three lattice
527+
vectors that define the periodic boundary conditions.
477528
initial_energy: Total energy of the structure before relaxation, in eV.
478529
final_energy: Total energy of the structure after relaxation, in eV.
479530
e_tol: Energy tolerance for comparing initial and final energies, in eV.
@@ -510,8 +561,9 @@ def valid_subcell(
510561
return False
511562

512563
# Check minimum interatomic distances to detect atomic fusion
513-
# Uses periodic boundary conditions via min_distance function
514-
min_dist = min_distance(positions, cell, distance_tolerance)
564+
# Handle both batched and unbatched cell tensors
565+
cell_for_min_dist = cell.squeeze(0) if cell.ndim == 3 else cell
566+
min_dist = min_distance(positions, cell_for_min_dist, distance_tolerance)
515567
if min_dist < fusion_distance:
516568
print("Bad structure! Fusion found.")
517569
return False
@@ -645,15 +697,18 @@ def subcells_to_structures(
645697
candidates: List of (ids, lower_bound, upper_bound)
646698
tuples from get_subcells_to_crystallize
647699
fractional_positions: Fractional coordinates of atoms
648-
cell: Unit cell tensor
700+
cell: Unit cell tensor of shape [3, 3] or [1, 3, 3]
649701
species: List of atomic species symbols
650702
651703
Returns:
652704
list[tuple[torch.Tensor, torch.Tensor, list[str]]]: Each tuple contains:
653705
- fractional_positions: Fractional coordinates of atoms
654-
- cell: Unit cell tensor
706+
- cell: Unit cell tensor with proper batch dimensions
655707
- species: atomic species symbols
656708
"""
709+
# Handle both batched and unbatched cell tensors
710+
cell_2d = cell.squeeze(0) if cell.ndim == 3 else cell
711+
657712
list_subcells = []
658713
for ids, lower_bound, upper_bound in candidates:
659714
# Get positions of atoms in this subcell
@@ -666,7 +721,11 @@ def subcells_to_structures(
666721
new_frac_pos = new_frac_pos / (upper_bound - lower_bound)
667722

668723
# Calculate new cell parameters
669-
new_cell = cell * (upper_bound - lower_bound).unsqueeze(0)
724+
new_cell = cell_2d * (upper_bound - lower_bound).unsqueeze(0)
725+
726+
# Add batch dimension to maintain consistency
727+
if cell.ndim == 3: # Original cell had batch dimension
728+
new_cell = new_cell.unsqueeze(0)
670729

671730
# Get species for these atoms and convert tensor indices to list/numpy array
672731
# before indexing species list
@@ -721,12 +780,18 @@ def get_unit_cell_relaxed_structure(
721780
tuple containing:
722781
- UnitCellFIREState: Final state containing relaxed positions, cell and more
723782
- dict: Logger with energy and stress trajectories
724-
- float: Final energy in eV
725-
- float: Final pressure in eV/ų
726783
"""
727784
# Get device and dtype from model
728785
device, dtype = model.device, model.dtype
729786

787+
# Ensure state has proper batch dimensions
788+
if state.cell.ndim == 2:
789+
state.cell = state.cell.unsqueeze(0) # Add batch dimension
790+
791+
# Ensure batch tensor exists and has correct shape
792+
if state.batch is None:
793+
state.batch = torch.zeros(len(state.positions), device=device, dtype=torch.long)
794+
730795
logger = {
731796
"energy": torch.zeros((max_iter, state.n_batches), device=device, dtype=dtype),
732797
"stress": torch.zeros(
@@ -769,7 +834,7 @@ def step_fn(
769834
f"Final energy: {[f'{e:.4f}' for e in final_energy]} eV, "
770835
f"Final pressure: {[f'{p:.4f}' for p in final_pressure]} eV/A^3"
771836
)
772-
return state, logger, final_energy, final_pressure
837+
return state, logger
773838

774839

775840
def get_relaxed_structure(
@@ -791,12 +856,18 @@ def get_relaxed_structure(
791856
tuple containing:
792857
- FIREState: Final state containing relaxed positions and other quantities
793858
- dict: Logger with energy trajectory
794-
- float: Final energy in eV
795-
- float: Final pressure in eV/ų
796859
"""
797860
# Get device and dtype from model
798861
device, dtype = model.device, model.dtype
799862

863+
# Ensure state has proper batch dimensions
864+
if state.cell.ndim == 2:
865+
state.cell = state.cell.unsqueeze(0) # Add batch dimension
866+
867+
# Ensure batch tensor exists and has correct shape
868+
if state.batch is None:
869+
state.batch = torch.zeros(len(state.positions), device=device, dtype=torch.long)
870+
800871
logger = {"energy": torch.zeros((max_iter, 1), device=device, dtype=dtype)}
801872

802873
results = model(state)
@@ -816,9 +887,7 @@ def step_fn(idx: int, state: FireState, logger: dict) -> tuple[FireState, dict]:
816887

817888
# Get final results
818889
model.compute_stress = True
819-
final_results = model(
820-
positions=state.positions, cell=state.cell, atomic_numbers=state.atomic_numbers
821-
)
890+
final_results = model(state)
822891

823892
final_energy = final_results["energy"].item()
824893
final_stress = final_results["stress"]
@@ -827,4 +896,4 @@ def step_fn(idx: int, state: FireState, logger: dict) -> tuple[FireState, dict]:
827896
f"Final energy: {final_energy:.4f} eV, "
828897
f"Final pressure: {final_pressure:.4f} eV/A^3"
829898
)
830-
return state, logger, final_energy, final_pressure
899+
return state, logger

0 commit comments

Comments
 (0)