Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ coverage.xml

# env
uv.lock

# IDE
.vscode/
16 changes: 8 additions & 8 deletions examples/scripts/1_Introduction/1.2_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,38 +63,38 @@
cell = torch.tensor(cell_numpy, device=device, dtype=dtype)
atomic_numbers = torch.tensor(atomic_numbers_numpy, device=device, dtype=torch.int)

# create batch index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms
atoms_per_batch = torch.tensor(
# create graph index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms
atoms_per_graph = torch.tensor(
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
)
batch = torch.repeat_interleave(
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
graph_idx = torch.repeat_interleave(
torch.arange(len(atoms_per_graph), device=device), atoms_per_graph
)

# You can see their shapes are as expected
print(f"Positions: {positions.shape}")
print(f"Cell: {cell.shape}")
print(f"Atomic numbers: {atomic_numbers.shape}")
print(f"Batch: {batch.shape}")
print(f"Graph indices: {graph_idx.shape}")

# Now we can pass them to the model
results = batched_model(
dict(
positions=positions,
cell=cell,
atomic_numbers=atomic_numbers,
batch=batch,
graph_idx=graph_idx,
pbc=True,
)
)

# The energy has shape (n_batches,) as the structures in a batch
# The energy has shape (n_graphs,) as the structures in a batch
print(f"Energy: {results['energy'].shape}")

# The forces have shape (n_atoms, 3) same as positions
print(f"Forces: {results['forces'].shape}")

# The stress has shape (n_batches, 3, 3) same as cell
# The stress has shape (n_graphs, 3, 3) same as cell
print(f"Stress: {results['stress'].shape}")

# Check if the energy, forces, and stress are the same for the Si system across the batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,20 @@
masses_numpy = np.concatenate([atoms.get_masses() for atoms in atoms_list])
masses = torch.tensor(masses_numpy, device=device, dtype=dtype)

# Create batch indices tensor for scatter operations
atoms_per_batch = torch.tensor(
# Create graph indices tensor for scatter operations
atoms_per_graph = torch.tensor(
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
)
batch_indices = torch.repeat_interleave(
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
graph_indices = torch.repeat_interleave(
torch.arange(len(atoms_per_graph), device=device), atoms_per_graph
)
"""

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)

print(f"Positions shape: {state.positions.shape}")
print(f"Cell shape: {state.cell.shape}")
print(f"Batch indices shape: {state.batch.shape}")
print(f"Graph indices shape: {state.graph_idx.shape}")

# Run initial inference
results = batched_model(state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
# Initialize unit cell gradient descent optimizer
gd_init, gd_update = unit_cell_gradient_descent(
model=model,
cell_factor=None, # Will default to atoms per batch
cell_factor=None, # Will default to atoms per graph
hydrostatic_strain=False,
constant_volume=False,
scalar_pressure=0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
# Initialize unit cell gradient descent optimizer
fire_init, fire_update = unit_cell_fire(
model=model,
cell_factor=None, # Will default to atoms per batch
cell_factor=None, # Will default to atoms per graph
hydrostatic_strain=False,
constant_volume=False,
scalar_pressure=0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
# Initialize unit cell gradient descent optimizer
fire_init, fire_update = ts.optimizers.frechet_cell_fire(
model=model,
cell_factor=None, # Will default to atoms per batch
cell_factor=None, # Will default to atoms per graph
hydrostatic_strain=False,
constant_volume=False,
scalar_pressure=0.0,
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class HybridSwapMCState(MDState):
hybrid_state = HybridSwapMCState(
**vars(md_state),
last_permutation=torch.zeros(
md_state.n_batches, device=md_state.device, dtype=torch.bool
md_state.n_graphs, device=md_state.device, dtype=torch.bool
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@
for step in range(N_steps):
if step % 50 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
pressure = get_pressure(
model(state)["stress"],
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
),
torch.linalg.det(state.cell),
)
Expand All @@ -139,15 +139,15 @@
state = npt_update(state, kT=kT, external_pressure=target_pressure)

temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
print(f"Final temperature: {temp.item():.4f}")


stress = model(state)["stress"]
calc_kinetic_energy = calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
)
volume = torch.linalg.det(state.cell)
pressure = get_pressure(stress, calc_kinetic_energy, volume)
Expand Down
10 changes: 5 additions & 5 deletions examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
for step in range(N_steps_nvt):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
invariant = float(nvt_nose_hoover_invariant(state, kT=kT))
Expand All @@ -83,7 +83,7 @@
for step in range(N_steps_npt):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
stress = model(state)["stress"]
Expand All @@ -92,7 +92,7 @@
get_pressure(
stress,
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
),
volume,
).item()
Expand All @@ -107,7 +107,7 @@
state = npt_update(state, kT=kT, external_pressure=target_pressure)

final_temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
print(f"Final temperature: {final_temp.item():.4f} K")
Expand All @@ -117,7 +117,7 @@
get_pressure(
final_stress,
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
),
final_volume,
).item()
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
start_time = time.perf_counter()
for step in range(N_steps):
total_energy = state.energy + calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
)
if step % 10 == 0:
print(f"Step {step}: Total energy: {total_energy.item():.4f} eV")
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/3_Dynamics/3.2_MACE_NVE.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
start_time = time.perf_counter()
for step in range(N_steps):
total_energy = state.energy + calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
)
if step % 10 == 0:
print(f"Step {step}: Total energy: {total_energy.item():.4f} eV")
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
start_time = time.perf_counter()
for step in range(N_steps):
total_energy = state.energy + calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
)
if step % 10 == 0:
print(f"Step {step}: Total energy: {total_energy.item():.4f} eV")
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@
for step in range(N_steps):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
print(f"{step=}: Temperature: {temp.item():.4f}")
state = langevin_update(state=state, kT=kT)

final_temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
print(f"Final temperature: {final_temp.item():.4f}")
4 changes: 2 additions & 2 deletions examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@
for step in range(N_steps):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
invariant = float(nvt_nose_hoover_invariant(state, kT=kT))
print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}")
state = nvt_update(state=state, kT=kT)

final_temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
print(f"Final temperature: {final_temp.item():.4f}")
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def get_kT(

# Calculate current temperature and save data
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
actual_temps[step] = temp
Expand Down
10 changes: 6 additions & 4 deletions examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@
for step in range(N_steps):
if step % 50 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
invariant = float(
npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure)
)
e_kin = calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
)
pressure = get_pressure(
model(state)["stress"], e_kin, torch.det(state.current_cell)
Expand All @@ -145,14 +145,16 @@
state = npt_update(state, kT=kT, external_pressure=target_pressure)

temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
print(f"Final temperature: {temp.item():.4f}")

pressure = get_pressure(
model(state)["stress"],
calc_kinetic_energy(masses=state.masses, momenta=state.momenta, batch=state.batch),
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
),
torch.det(state.current_cell),
)
pressure = pressure.item() / Units.pressure
Expand Down
12 changes: 7 additions & 5 deletions examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
for step in range(N_steps_nvt):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
invariant = float(
Expand All @@ -86,7 +86,7 @@
for step in range(N_steps_npt):
if step % 10 == 0:
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
invariant = float(
Expand All @@ -95,7 +95,7 @@
stress = model(state)["stress"]
volume = torch.det(state.current_cell)
e_kin = calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, batch=state.batch
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
)
pressure = float(get_pressure(stress, e_kin, volume))
xx, yy, zz = torch.diag(state.current_cell[0])
Expand All @@ -107,15 +107,17 @@
state = npt_update(state, kT=kT, external_pressure=target_pressure)

final_temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)
print(f"Final temperature: {final_temp.item():.4f}")
final_stress = model(state)["stress"]
final_volume = torch.det(state.current_cell)
final_pressure = get_pressure(
final_stress,
calc_kinetic_energy(masses=state.masses, momenta=state.momenta, batch=state.batch),
calc_kinetic_energy(
masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx
),
final_volume,
)
print(f"Final pressure: {final_pressure.item():.4f}")
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype)
for step in range(N_steps):
temp = (
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch)
calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx)
/ Units.temperature
)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/4_High_level_api/4.2_auto_batching_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
all_completed_states, convergence_tensor, state = [], None, None
while (result := batcher.next_batch(state, convergence_tensor))[0] is not None:
state, completed_states = result
print(f"Starting new batch of {state.n_batches} states.")
print(f"Starting new batch of {state.n_graphs} states.")

all_completed_states.extend(completed_states)
print("Total number of completed states", len(all_completed_states))
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/5_Workflow/5.2_In_Flight_WBM.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
all_completed_states, convergence_tensor, state = [], None, None
while (result := batcher.next_batch(state, convergence_tensor))[0] is not None:
state, completed_states = result
print(f"Starting new batch of {state.n_batches} states.")
print(f"Starting new batch of {state.n_graphs} states.")

all_completed_states.extend(completed_states)
print("Total number of completed states", len(all_completed_states))
Expand Down
Loading
Loading