Skip to content

Commit 74cb780

Browse files
committed
fix: Handle batched cell tensors in get_fractional_coordinates
- Replace deprecated .T with .mT for matrix transpose on 3D tensors - Add support for batched cell tensors with shape [n_batches, 3, 3] - Extract first batch cell matrix when cell.ndim == 3 - Maintains backward compatibility with 2D cell matrices Fixes batched silicon workflow script that was failing with: - UserWarning about deprecated .T usage on >2D tensors - RuntimeError: linalg.solve: A must be batches of square matrices The get_fractional_coordinates function now properly handles both single [3,3] and batched [n_batches,3,3] cell tensors, enabling a2c_silicon_batched.py workflow to run
1 parent 4859c8d commit 74cb780

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
).item()
9999
/ Units.pressure
100100
)
101-
xx, yy, zz = torch.diag(state.cell)
101+
xx, yy, zz = torch.diag(state.cell[0])
102102
print(
103103
f"{step=}: Temperature: {temp.item():.4f}, "
104104
f"pressure: {pressure:.4f}, "

torch_sim/integrators/npt.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ def _compute_cell_force(
9595
Returns:
9696
torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim]
9797
"""
98+
# Convert external_pressure to tensor if it's not already one
99+
if not isinstance(external_pressure, torch.Tensor):
100+
external_pressure = torch.tensor(
101+
external_pressure, device=state.device, dtype=state.dtype
102+
)
103+
104+
# Convert kT to tensor if it's not already one
105+
if not isinstance(kT, torch.Tensor):
106+
kT = torch.tensor(kT, device=state.device, dtype=state.dtype)
107+
98108
# Get current volumes for each batch
99109
volumes = torch.linalg.det(state.cell) # shape: (n_batches,)
100110

torch_sim/transforms.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def get_fractional_coordinates(
4040
tensor([[0.25, 0.25, 0.25],
4141
[0.50, 0.00, 0.00]])
4242
"""
43-
return torch.linalg.solve(cell.T, positions.T).T
43+
if cell.ndim == 3: # Handle batched cell tensors
44+
# For batched systems, we assume single batch for now
45+
# Extract the first batch's cell matrix
46+
cell_2d = cell[0] # Shape: [3, 3]
47+
return torch.linalg.solve(cell_2d.mT, positions.mT).mT
48+
# Original case for 2D cell matrix
49+
return torch.linalg.solve(cell.mT, positions.mT).mT
4450

4551

4652
def inverse_box(box: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)