Skip to content

Commit 317985c

Browse files
authored
Rename batch to system (#217)
1 parent 2bb3bd5 commit 317985c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1432
-1281
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,6 @@ coverage.xml
3434

3535
# env
3636
uv.lock
37+
38+
# IDE
39+
.vscode/

examples/scripts/1_Introduction/1.2_MACE.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,38 +63,38 @@
6363
cell = torch.tensor(cell_numpy, device=device, dtype=dtype)
6464
atomic_numbers = torch.tensor(atomic_numbers_numpy, device=device, dtype=torch.int)
6565

66-
# create batch index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms
67-
atoms_per_batch = torch.tensor(
66+
# create system idx array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms
67+
atoms_per_system = torch.tensor(
6868
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
6969
)
70-
batch = torch.repeat_interleave(
71-
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
70+
system_idx = torch.repeat_interleave(
71+
torch.arange(len(atoms_per_system), device=device), atoms_per_system
7272
)
7373

7474
# You can see their shapes are as expected
7575
print(f"Positions: {positions.shape}")
7676
print(f"Cell: {cell.shape}")
7777
print(f"Atomic numbers: {atomic_numbers.shape}")
78-
print(f"Batch: {batch.shape}")
78+
print(f"System indices: {system_idx.shape}")
7979

8080
# Now we can pass them to the model
8181
results = batched_model(
8282
dict(
8383
positions=positions,
8484
cell=cell,
8585
atomic_numbers=atomic_numbers,
86-
batch=batch,
86+
system_idx=system_idx,
8787
pbc=True,
8888
)
8989
)
9090

91-
# The energy has shape (n_batches,) as the structures in a batch
91+
# The energy has shape (n_systems,) as the structures in a batch
9292
print(f"Energy: {results['energy'].shape}")
9393

9494
# The forces have shape (n_atoms, 3) same as positions
9595
print(f"Forces: {results['forces'].shape}")
9696

97-
# The stress has shape (n_batches, 3, 3) same as cell
97+
# The stress has shape (n_systems, 3, 3) same as cell
9898
print(f"Stress: {results['stress'].shape}")
9999

100100
# Check if the energy, forces, and stress are the same for the Si system across the batch

examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,20 @@
9393
masses_numpy = np.concatenate([atoms.get_masses() for atoms in atoms_list])
9494
masses = torch.tensor(masses_numpy, device=device, dtype=dtype)
9595
96-
# Create batch indices tensor for scatter operations
97-
atoms_per_batch = torch.tensor(
96+
# Create system indices tensor for scatter operations
97+
atoms_per_system = torch.tensor(
9898
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
9999
)
100-
batch_indices = torch.repeat_interleave(
101-
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
100+
system_indices = torch.repeat_interleave(
101+
torch.arange(len(atoms_per_system), device=device), atoms_per_system
102102
)
103103
"""
104104

105105
state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
106106

107107
print(f"Positions shape: {state.positions.shape}")
108108
print(f"Cell shape: {state.cell.shape}")
109-
print(f"Batch indices shape: {state.batch.shape}")
109+
print(f"System indices shape: {state.system_idx.shape}")
110110

111111
# Run initial inference
112112
results = batched_model(state)

examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
# Initialize unit cell gradient descent optimizer
8686
gd_init, gd_update = unit_cell_gradient_descent(
8787
model=model,
88-
cell_factor=None, # Will default to atoms per batch
88+
cell_factor=None, # Will default to atoms per system
8989
hydrostatic_strain=False,
9090
constant_volume=False,
9191
scalar_pressure=0.0,

examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
# Initialize unit cell gradient descent optimizer
8282
fire_init, fire_update = unit_cell_fire(
8383
model=model,
84-
cell_factor=None, # Will default to atoms per batch
84+
cell_factor=None, # Will default to atoms per system
8585
hydrostatic_strain=False,
8686
constant_volume=False,
8787
scalar_pressure=0.0,

examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
# Initialize unit cell gradient descent optimizer
8181
fire_init, fire_update = ts.optimizers.frechet_cell_fire(
8282
model=model,
83-
cell_factor=None, # Will default to atoms per batch
83+
cell_factor=None, # Will default to atoms per system
8484
hydrostatic_strain=False,
8585
constant_volume=False,
8686
scalar_pressure=0.0,

examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class HybridSwapMCState(MDState):
8686
hybrid_state = HybridSwapMCState(
8787
**vars(md_state),
8888
last_permutation=torch.zeros(
89-
md_state.n_batches, device=md_state.device, dtype=torch.bool
89+
md_state.n_systems, device=md_state.device, dtype=torch.bool
9090
),
9191
)
9292

examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,15 @@
119119
for step in range(N_steps):
120120
if step % 50 == 0:
121121
temp = (
122-
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
122+
calc_kT(
123+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
124+
)
123125
/ Units.temperature
124126
)
125127
pressure = get_pressure(
126128
model(state)["stress"],
127129
calc_kinetic_energy(
128-
masses=state.masses, momenta=state.momenta, batch=state.batch
130+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
129131
),
130132
torch.linalg.det(state.cell),
131133
)
@@ -139,15 +141,15 @@
139141
state = npt_update(state, kT=kT, external_pressure=target_pressure)
140142

141143
temp = (
142-
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
144+
calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx)
143145
/ Units.temperature
144146
)
145147
print(f"Final temperature: {temp.item():.4f}")
146148

147149

148150
stress = model(state)["stress"]
149151
calc_kinetic_energy = calc_kinetic_energy(
150-
masses=state.masses, momenta=state.momenta, batch=state.batch
152+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
151153
)
152154
volume = torch.linalg.det(state.cell)
153155
pressure = get_pressure(stress, calc_kinetic_energy, volume)

examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@
6868
for step in range(N_steps_nvt):
6969
if step % 10 == 0:
7070
temp = (
71-
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
71+
calc_kT(
72+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
73+
)
7274
/ Units.temperature
7375
)
7476
invariant = float(nvt_nose_hoover_invariant(state, kT=kT))
@@ -83,7 +85,9 @@
8385
for step in range(N_steps_npt):
8486
if step % 10 == 0:
8587
temp = (
86-
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
88+
calc_kT(
89+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
90+
)
8791
/ Units.temperature
8892
)
8993
stress = model(state)["stress"]
@@ -92,7 +96,9 @@
9296
get_pressure(
9397
stress,
9498
calc_kinetic_energy(
95-
masses=state.masses, momenta=state.momenta, batch=state.batch
99+
masses=state.masses,
100+
momenta=state.momenta,
101+
system_idx=state.system_idx,
96102
),
97103
volume,
98104
).item()
@@ -107,7 +113,7 @@
107113
state = npt_update(state, kT=kT, external_pressure=target_pressure)
108114

109115
final_temp = (
110-
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
116+
calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx)
111117
/ Units.temperature
112118
)
113119
print(f"Final temperature: {final_temp.item():.4f} K")
@@ -117,7 +123,7 @@
117123
get_pressure(
118124
final_stress,
119125
calc_kinetic_energy(
120-
masses=state.masses, momenta=state.momenta, batch=state.batch
126+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
121127
),
122128
final_volume,
123129
).item()

examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
start_time = time.perf_counter()
7878
for step in range(N_steps):
7979
total_energy = state.energy + calc_kinetic_energy(
80-
masses=state.masses, momenta=state.momenta, batch=state.batch
80+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
8181
)
8282
if step % 10 == 0:
8383
print(f"Step {step}: Total energy: {total_energy.item():.4f} eV")

0 commit comments

Comments
 (0)