From 0aa799128dceb2a1920d13868923f9b630313dcc Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Mon, 7 Jul 2025 07:37:34 -0700 Subject: [PATCH 1/2] Rename batch to graph add .vscode to gitignore rename per_batch to per_graph wip convert many batches to graphs convert more batches to graphs rename all of batch indices rename all of batch indices wip fix tests fix more tests more test renaming more test renaming again fix test_deform_grad_batched rename more batch to use graphs update mock state to use graphs precommit fix bug rename in calculate_momenta fix more precommit rename more batch to graphs more renames more renaming fix deprecated text in docstring added test for deprecated batch properties in simstate rename batch to graph more renames cleanup minor typo --- .gitignore | 3 + examples/scripts/1_Introduction/1.2_MACE.py | 16 +- .../2.3_MACE_Gradient_Descent.py | 10 +- ....5_MACE_UnitCellFilter_Gradient_Descent.py | 2 +- .../2.6_MACE_UnitCellFilter_FIRE.py | 2 +- .../2.7_MACE_FrechetCellFilter_FIRE.py | 2 +- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 2 +- .../3.11_Lennard_Jones_NPT_Langevin.py | 8 +- .../3_Dynamics/3.12_MACE_NPT_Langevin.py | 10 +- .../3_Dynamics/3.13_MACE_NVE_non_pbc.py | 2 +- examples/scripts/3_Dynamics/3.2_MACE_NVE.py | 2 +- .../scripts/3_Dynamics/3.3_MACE_NVE_cueq.py | 2 +- .../3_Dynamics/3.4_MACE_NVT_Langevin.py | 4 +- .../3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py | 4 +- .../3.6_MACE_NVT_Nose_Hoover_temp_profile.py | 2 +- .../3.7_Lennard_Jones_NPT_Nose_Hoover.py | 10 +- .../3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py | 12 +- .../3.9_MACE_NVT_staggered_stress.py | 2 +- .../4_High_level_api/4.2_auto_batching_api.py | 2 +- .../scripts/5_Workflow/5.2_In_Flight_WBM.py | 2 +- .../7_Others/7.3_Batched_neighbor_list.py | 8 +- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 26 +- examples/tutorials/autobatching_tutorial.py | 6 +- examples/tutorials/high_level_tutorial.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 2 +- examples/tutorials/low_level_tutorial.py | 2 +- examples/tutorials/state_tutorial.py | 22 +- tests/models/conftest.py | 4 +- tests/test_autobatching.py | 28 +- tests/test_correlations.py | 10 +- tests/test_integrators.py | 14 +- tests/test_io.py | 24 +- tests/test_monte_carlo.py | 28 +- tests/test_neighbors.py | 12 +- tests/test_optimizers.py | 28 +- tests/test_runners.py | 48 +- tests/test_state.py | 154 +++--- tests/test_trajectory.py | 50 +- tests/test_transforms.py | 92 ++-- torch_sim/autobatching.py | 36 +- torch_sim/integrators/md.py | 42 +- torch_sim/integrators/npt.py | 438 +++++++++-------- torch_sim/integrators/nve.py | 12 +- torch_sim/integrators/nvt.py | 64 +-- torch_sim/io.py | 102 ++-- torch_sim/math.py | 2 +- torch_sim/models/fairchem.py | 8 +- torch_sim/models/graphpes.py | 4 +- torch_sim/models/interface.py | 18 +- torch_sim/models/lennard_jones.py | 18 +- torch_sim/models/mace.py | 54 +- torch_sim/models/metatomic.py | 4 +- torch_sim/models/morse.py | 10 +- torch_sim/models/orb.py | 6 +- torch_sim/models/particle_life.py | 8 +- torch_sim/models/sevennet.py | 6 +- torch_sim/models/soft_sphere.py | 30 +- torch_sim/monte_carlo.py | 68 +-- torch_sim/neighbors.py | 6 +- torch_sim/optimizers.py | 460 +++++++++--------- torch_sim/quantities.py | 54 +- torch_sim/runners.py | 16 +- torch_sim/state.py | 399 ++++++++------- torch_sim/trajectory.py | 44 +- torch_sim/transforms.py | 82 ++-- torch_sim/typing.py | 2 +- torch_sim/workflows/a2c.py | 4 +- 67 files changed, 1373 insertions(+), 1283 deletions(-) diff --git a/.gitignore b/.gitignore index 9c028c81..2a0bbdf2 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ coverage.xml # env uv.lock + +# IDE +.vscode/ diff --git a/examples/scripts/1_Introduction/1.2_MACE.py b/examples/scripts/1_Introduction/1.2_MACE.py index e417ed52..c7a1223c 100644 --- a/examples/scripts/1_Introduction/1.2_MACE.py +++ b/examples/scripts/1_Introduction/1.2_MACE.py @@ -63,19 +63,19 @@ cell = torch.tensor(cell_numpy, device=device, dtype=dtype) atomic_numbers = torch.tensor(atomic_numbers_numpy, device=device, dtype=torch.int) -# create batch index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms -atoms_per_batch = torch.tensor( +# create graph index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms +atoms_per_graph = torch.tensor( [len(atoms) for atoms in atoms_list], device=device, dtype=torch.int ) -batch = torch.repeat_interleave( - torch.arange(len(atoms_per_batch), device=device), atoms_per_batch +graph_idx = torch.repeat_interleave( + torch.arange(len(atoms_per_graph), device=device), atoms_per_graph ) # You can see their shapes are as expected print(f"Positions: {positions.shape}") print(f"Cell: {cell.shape}") print(f"Atomic numbers: {atomic_numbers.shape}") -print(f"Batch: {batch.shape}") +print(f"Graph indices: {graph_idx.shape}") # Now we can pass them to the model results = batched_model( @@ -83,18 +83,18 @@ positions=positions, cell=cell, atomic_numbers=atomic_numbers, - batch=batch, + graph_idx=graph_idx, pbc=True, ) ) -# The energy has shape (n_batches,) as the structures in a batch +# The energy has shape (n_graphs,) as the structures in a batch print(f"Energy: {results['energy'].shape}") # The forces have shape (n_atoms, 3) same as positions print(f"Forces: {results['forces'].shape}") -# The stress has shape (n_batches, 3, 3) same as cell +# The stress has shape (n_graphs, 3, 3) same as cell print(f"Stress: {results['stress'].shape}") # Check if the energy, forces, and stress are the same for the Si system across the batch diff --git a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py index 0ae9f9e0..35caeb74 100644 --- a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py @@ -93,12 +93,12 @@ masses_numpy = np.concatenate([atoms.get_masses() for atoms in atoms_list]) masses = torch.tensor(masses_numpy, device=device, dtype=dtype) -# Create batch indices tensor for scatter operations -atoms_per_batch = torch.tensor( +# Create graph indices tensor for scatter operations +atoms_per_graph = torch.tensor( [len(atoms) for atoms in atoms_list], device=device, dtype=torch.int ) -batch_indices = torch.repeat_interleave( - torch.arange(len(atoms_per_batch), device=device), atoms_per_batch +graph_indices = torch.repeat_interleave( + torch.arange(len(atoms_per_graph), device=device), atoms_per_graph ) """ @@ -106,7 +106,7 @@ print(f"Positions shape: {state.positions.shape}") print(f"Cell shape: {state.cell.shape}") -print(f"Batch indices shape: {state.batch.shape}") +print(f"Graph indices shape: {state.graph_idx.shape}") # Run initial inference results = batched_model(state) diff --git a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py index f92a7185..cb6aff34 100644 --- a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py @@ -85,7 +85,7 @@ # Initialize unit cell gradient descent optimizer gd_init, gd_update = unit_cell_gradient_descent( model=model, - cell_factor=None, # Will default to atoms per batch + cell_factor=None, # Will default to atoms per graph hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, diff --git a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py index 8125e296..22614313 100644 --- a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py @@ -81,7 +81,7 @@ # Initialize unit cell gradient descent optimizer fire_init, fire_update = unit_cell_fire( model=model, - cell_factor=None, # Will default to atoms per batch + cell_factor=None, # Will default to atoms per graph hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, diff --git a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py index 1799ba13..99ed4e8e 100644 --- a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py @@ -80,7 +80,7 @@ # Initialize unit cell gradient descent optimizer fire_init, fire_update = ts.optimizers.frechet_cell_fire( model=model, - cell_factor=None, # Will default to atoms per batch + cell_factor=None, # Will default to atoms per graph hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 38278198..447a85a5 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -86,7 +86,7 @@ class HybridSwapMCState(MDState): hybrid_state = HybridSwapMCState( **vars(md_state), last_permutation=torch.zeros( - md_state.n_batches, device=md_state.device, dtype=torch.bool + md_state.n_graphs, device=md_state.device, dtype=torch.bool ), ) diff --git a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py index 4932d432..ec4b821b 100644 --- a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -119,13 +119,13 @@ for step in range(N_steps): if step % 50 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) pressure = get_pressure( model(state)["stress"], calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ), torch.linalg.det(state.cell), ) @@ -139,7 +139,7 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) print(f"Final temperature: {temp.item():.4f}") @@ -147,7 +147,7 @@ stress = model(state)["stress"] calc_kinetic_energy = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ) volume = torch.linalg.det(state.cell) pressure = get_pressure(stress, calc_kinetic_energy, volume) diff --git a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py index a02ecaff..2b675c3a 100644 --- a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py @@ -68,7 +68,7 @@ for step in range(N_steps_nvt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) invariant = float(nvt_nose_hoover_invariant(state, kT=kT)) @@ -83,7 +83,7 @@ for step in range(N_steps_npt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) stress = model(state)["stress"] @@ -92,7 +92,7 @@ get_pressure( stress, calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ), volume, ).item() @@ -107,7 +107,7 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f} K") @@ -117,7 +117,7 @@ get_pressure( final_stress, calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ), final_volume, ).item() diff --git a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py index c7ec1273..2171e26f 100644 --- a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py +++ b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py @@ -77,7 +77,7 @@ start_time = time.perf_counter() for step in range(N_steps): total_energy = state.energy + calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index f9e6bf53..99801376 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -88,7 +88,7 @@ start_time = time.perf_counter() for step in range(N_steps): total_energy = state.energy + calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") diff --git a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py index f09b5bc0..240ab7dc 100644 --- a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py +++ b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py @@ -71,7 +71,7 @@ start_time = time.perf_counter() for step in range(N_steps): total_energy = state.energy + calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") diff --git a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py index 3ef1b93f..1a498bbf 100644 --- a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py @@ -83,14 +83,14 @@ for step in range(N_steps): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) print(f"{step=}: Temperature: {temp.item():.4f}") state = langevin_update(state=state, kT=kT) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py index d1c88893..b9d2a2f5 100644 --- a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py @@ -68,7 +68,7 @@ for step in range(N_steps): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) invariant = float(nvt_nose_hoover_invariant(state, kT=kT)) @@ -76,7 +76,7 @@ state = nvt_update(state=state, kT=kT) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py index a94337f1..2739a993 100644 --- a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py +++ b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py @@ -163,7 +163,7 @@ def get_kT( # Calculate current temperature and save data temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) actual_temps[step] = temp diff --git a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py index e85379d9..398b8191 100644 --- a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py @@ -123,14 +123,14 @@ for step in range(N_steps): if step % 50 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) invariant = float( npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) ) e_kin = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ) pressure = get_pressure( model(state)["stress"], e_kin, torch.det(state.current_cell) @@ -145,14 +145,16 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) print(f"Final temperature: {temp.item():.4f}") pressure = get_pressure( model(state)["stress"], - calc_kinetic_energy(masses=state.masses, momenta=state.momenta, batch=state.batch), + calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx + ), torch.det(state.current_cell), ) pressure = pressure.item() / Units.pressure diff --git a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py index bfa45313..b6b814ca 100644 --- a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py @@ -69,7 +69,7 @@ for step in range(N_steps_nvt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) invariant = float( @@ -86,7 +86,7 @@ for step in range(N_steps_npt): if step % 10 == 0: temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) invariant = float( @@ -95,7 +95,7 @@ stress = model(state)["stress"] volume = torch.det(state.current_cell) e_kin = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ) pressure = float(get_pressure(stress, e_kin, volume)) xx, yy, zz = torch.diag(state.current_cell[0]) @@ -107,7 +107,7 @@ state = npt_update(state, kT=kT, external_pressure=target_pressure) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") @@ -115,7 +115,9 @@ final_volume = torch.det(state.current_cell) final_pressure = get_pressure( final_stress, - calc_kinetic_energy(masses=state.masses, momenta=state.momenta, batch=state.batch), + calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx + ), final_volume, ) print(f"Final pressure: {final_pressure.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py index 9d7e02a6..39b974a4 100644 --- a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py +++ b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py @@ -65,7 +65,7 @@ stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype) for step in range(N_steps): temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) + calc_kT(masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx) / Units.temperature ) diff --git a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py index 7a3387ed..aeccaecf 100644 --- a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py @@ -76,7 +76,7 @@ all_completed_states, convergence_tensor, state = [], None, None while (result := batcher.next_batch(state, convergence_tensor))[0] is not None: state, completed_states = result - print(f"Starting new batch of {state.n_batches} states.") + print(f"Starting new batch of {state.n_graphs} states.") all_completed_states.extend(completed_states) print("Total number of completed states", len(all_completed_states)) diff --git a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py index 78a1019f..dd37392b 100644 --- a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py +++ b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py @@ -83,7 +83,7 @@ all_completed_states, convergence_tensor, state = [], None, None while (result := batcher.next_batch(state, convergence_tensor))[0] is not None: state, completed_states = result - print(f"Starting new batch of {state.n_batches} states.") + print(f"Starting new batch of {state.n_graphs} states.") all_completed_states.extend(completed_states) print("Total number of completed states", len(all_completed_states)) diff --git a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index a21ddeaf..ae5064aa 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -18,15 +18,15 @@ atoms_list = [bulk("Si", "diamond", a=5.43), bulk("Ge", "diamond", a=5.65)] state = ts.io.atoms_to_state(atoms_list, device="cpu", dtype=torch.float32) pos, cell, pbc = state.positions, state.cell, state.pbc -batch, n_atoms = state.batch, state.n_atoms +graph_idx, n_atoms = state.graph_idx, state.n_atoms cutoff = 4.0 self_interaction = False -# Fix: Ensure pbc has the correct shape [n_batches, 3] +# Fix: Ensure pbc has the correct shape [n_graphs, 3] pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool) mapping, mapping_batch, shifts_idx = torch_nl_linked_cell( - cutoff, pos, cell, pbc_tensor, batch, self_interaction + cutoff, pos, cell, pbc_tensor, graph_idx, self_interaction ) cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_batch) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) @@ -38,7 +38,7 @@ print(dds.shape) mapping_n2, mapping_batch_n2, shifts_idx_n2 = torch_nl_n2( - cutoff, pos, cell, pbc_tensor, batch, self_interaction + cutoff, pos, cell, pbc_tensor, graph_idx, self_interaction ) cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_batch_n2) dds_n2 = transforms.compute_distances_with_cell_shifts(pos, mapping_n2, cell_shifts_n2) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 82fdf8f0..33b06999 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -141,7 +141,7 @@ def run_optimization_ts( # noqa: PLR0915 start_time = time.perf_counter() print("Initial cell parameters (Torch-Sim):") - for k_idx in range(initial_state.n_batches): + for k_idx in range(initial_state.n_graphs): cell_tensor_k = initial_state.cell[k_idx].cpu().numpy() ase_cell_k = Cell(cell_tensor_k) params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) @@ -168,7 +168,7 @@ def run_optimization_ts( # noqa: PLR0915 ) batcher.load_states(opt_state) - total_structures = opt_state.n_batches + total_structures = opt_state.n_graphs convergence_steps = torch.full( (total_structures,), -1, dtype=torch.long, device=device ) @@ -219,7 +219,7 @@ def run_optimization_ts( # noqa: PLR0915 if global_step % 50 == 0: total_converged_frac = converged_tensor_global.sum().item() / total_structures - active_structures = opt_state.n_batches if opt_state else 0 + active_structures = opt_state.n_graphs if opt_state else 0 print( f"{global_step=}: Active structures: {active_structures}, " f"Total converged: {total_converged_frac:.2%}" @@ -230,7 +230,7 @@ def run_optimization_ts( # noqa: PLR0915 if final_state_concatenated is not None and hasattr(final_state_concatenated, "cell"): print("Final cell parameters (Torch-Sim):") - for k_idx in range(final_state_concatenated.n_batches): + for k_idx in range(final_state_concatenated.n_graphs): cell_tensor_k = final_state_concatenated.cell[k_idx].cpu().numpy() ase_cell_k = Cell(cell_tensor_k) params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) @@ -329,12 +329,12 @@ def run_optimization_ase( # noqa: C901, PLR0915 all_masses = [] all_atomic_numbers = [] all_cells = [] - all_batches_for_gd = [] + all_graphs_for_gd = [] final_energies_ase = [] final_forces_ase_tensors = [] current_atom_offset = 0 - for batch_idx, ats_final in enumerate(final_ase_atoms_list): + for graph_idx, ats_final in enumerate(final_ase_atoms_list): all_positions.append( torch.tensor(ats_final.get_positions(), device=device, dtype=dtype) ) @@ -350,9 +350,9 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) num_atoms_in_current = len(ats_final) - all_batches_for_gd.append( + all_graphs_for_gd.append( torch.full( - (num_atoms_in_current,), batch_idx, device=device, dtype=torch.long + (num_atoms_in_current,), graph_idx, device=device, dtype=torch.long ) ) current_atom_offset += num_atoms_in_current @@ -361,7 +361,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 if ats_final.calc is None: print( "Re-attaching ASE calculator for final energy/forces for " - f"structure {batch_idx}." + f"structure {graph_idx}." ) temp_calc = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, @@ -375,7 +375,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) except Exception as e: # noqa: BLE001 print( - f"Could not get final energy/forces for an ASE structure {batch_idx}: {e}" + f"Could not get final energy/forces for an ASE structure {graph_idx}: {e}" ) final_energies_ase.append(float("nan")) if all_positions and len(all_positions[-1]) > 0: @@ -393,8 +393,8 @@ def run_optimization_ase( # noqa: C901, PLR0915 concatenated_positions = torch.cat(all_positions, dim=0) concatenated_masses = torch.cat(all_masses, dim=0) concatenated_atomic_numbers = torch.cat(all_atomic_numbers, dim=0) - concatenated_cells = torch.stack(all_cells, dim=0) # Cells are (N_batch, 3, 3) - concatenated_batch_indices = torch.cat(all_batches_for_gd, dim=0) + concatenated_cells = torch.stack(all_cells, dim=0) # Cells are (n_graphs, 3, 3) + concatenated_graph_indices = torch.cat(all_graphs_for_gd, dim=0) concatenated_energies = torch.tensor(final_energies_ase, device=device, dtype=dtype) concatenated_forces = torch.cat(final_forces_ase_tensors, dim=0) @@ -413,7 +413,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 cell=concatenated_cells, pbc=initial_state.pbc, atomic_numbers=concatenated_atomic_numbers, - batch=concatenated_batch_indices, + graph_idx=concatenated_graph_indices, energy=concatenated_energies, forces=concatenated_forces, ) diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index c1647190..03e44343 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -249,7 +249,7 @@ def process_batch(batch): fire_state.positions = ( fire_state.positions + torch.randn_like(fire_state.positions) * 0.05 ) -total_states = fire_state.n_batches +total_states = fire_state.n_graphs # Define a convergence function that checks the force on each atom is less than 5e-1 convergence_fn = ts.generate_force_convergence_fn(5e-1) @@ -279,11 +279,11 @@ def process_batch(batch): assert len(final_states) == total_states # Note that the fire_state has been modified in place -assert fire_state.n_batches == 0 +assert fire_state.n_graphs == 0 # %% -fire_state.n_batches +fire_state.n_graphs # %% [markdown] diff --git a/examples/tutorials/high_level_tutorial.py b/examples/tutorials/high_level_tutorial.py index 9d6c4ab4..8fa6990b 100644 --- a/examples/tutorials/high_level_tutorial.py +++ b/examples/tutorials/high_level_tutorial.py @@ -375,7 +375,7 @@ def mock_determine_max_batch_size(*args, **kwargs): convergence function are `state` and `last_energy`. The `state` is a `SimState` object that contains the current state of the system and the `last_energy` is the energy of the previous step. The convergence function should return a boolean tensor of length -`n_batches`. +`n_graphs`. This is how we'd manually define the default `convergence_fn`: """ diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 28c73fc0..0db6f0d6 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -133,7 +133,7 @@ class HybridSwapMCState(ts.integrators.MDState): hybrid_state = HybridSwapMCState( **vars(md_state), last_permutation=torch.zeros( - md_state.n_batches, device=md_state.device, dtype=torch.bool + md_state.n_graphs, device=md_state.device, dtype=torch.bool ), ) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index f12893c3..fc6e6e5d 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -228,7 +228,7 @@ state = nvt_langevin_update_fn(state=state, kT=initial_kT * (1 + step / 30)) if step % 5 == 0: temp_E_units = ts.calc_kT( - masses=state.masses, momenta=state.momenta, batch=state.batch + masses=state.masses, momenta=state.momenta, graph_idx=state.graph_idx ) temp = temp_E_units / MetalUnits.temperature print(f"{step=}: Temperature: {temp}") diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 1abf5d28..59e3b0f6 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -31,7 +31,7 @@ * Unit cell parameters * Periodic boundary conditions * Atomic numbers (elements) -* Batch indices (for processing multiple systems simultaneously) +* Graph indices (for processing multiple systems simultaneously) """ @@ -58,7 +58,7 @@ # Convert to SimState si_state = ts.initialize_state(si_atoms, device=torch.device("cpu"), dtype=torch.float64) -print(f"State has {si_state.n_atoms} atoms and {si_state.n_batches} batches") +print(f"State has {si_state.n_atoms} atoms and {si_state.n_graphs} graphs") # here we print all the attributes of the SimState print(f"Positions shape: {si_state.positions.shape}") @@ -66,7 +66,7 @@ print(f"Atomic numbers shape: {si_state.atomic_numbers.shape}") print(f"Masses shape: {si_state.masses.shape}") print(f"PBC: {si_state.pbc}") -print(f"Batch indices shape: {si_state.batch.shape}") +print(f"Graph indices shape: {si_state.graph_idx.shape}") # %% [markdown] @@ -75,7 +75,7 @@ * Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`, `masses`, `atomic_numbers`, and `batch`. Names are plural. -* Batchwise attributes are tensors with shape (n_batches, ...), this is just `cell` for +* Batchwise attributes are tensors with shape (n_graphs, ...), this is just `cell` for the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. @@ -109,14 +109,14 @@ ) print( - f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_batches} batches" + f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_graphs} graphs" ) # we can see how the shapes of batchwise, atomwise, and global properties change print(f"Positions shape: {multi_state.positions.shape}") print(f"Cell shape: {multi_state.cell.shape}") print(f"PBC: {multi_state.pbc}") -print(f"Batch indices shape: {multi_state.batch.shape}") +print(f"Graph indices shape: {multi_state.graph_idx.shape}") # %% [markdown] @@ -148,18 +148,18 @@ # %% we can copy the state with the clone method multi_state_copy = multi_state.clone() -print(f"This state has {multi_state_copy.n_batches} batches") +print(f"This state has {multi_state_copy.n_graphs} graphs") # we can pop states off while modifying the original state popped_states = multi_state_copy.pop([0, 2]) print( f"We popped {len(popped_states)} states, leaving us with " - f"{multi_state_copy.n_batches} batch in the original state" + f"{multi_state_copy.n_graphs} graphs in the original state" ) # we can put them back together with concatenate multi_state_full = ts.concatenate_states([*popped_states, multi_state_copy]) -print(f"Again we have {multi_state_full.n_batches} batches in the full state") +print(f"Again we have {multi_state_full.n_graphs} graphs in the full state") # or if we don't want to modify the original state, we can instead index into it # negative indexing @@ -253,14 +253,14 @@ **asdict(si_state), # Copy all SimState properties momenta=torch.zeros_like(si_state.positions), # Initial 0 momenta forces=torch.zeros_like(si_state.positions), # Initial 0 forces - energy=torch.zeros((si_state.n_batches,), device=si_state.device), # Initial 0 energy + energy=torch.zeros((si_state.n_graphs,), device=si_state.device), # Initial 0 energy ) print("MDState properties:") scope = infer_property_scope(md_state) print("Global properties:", scope["global"]) print("Per-atom properties:", scope["per_atom"]) -print("Per-batch properties:", scope["per_batch"]) +print("Per-graph properties:", scope["per_graph"]) # %% [markdown] diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 923f34f7..b6dd0658 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -169,7 +169,7 @@ def test_model_output_validation( og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() - og_batch = sim_state.batch.clone() + og_batch = sim_state.graph_idx.clone() og_atomic_numbers = sim_state.atomic_numbers.clone() model_output = model.forward(sim_state) @@ -177,7 +177,7 @@ def test_model_output_validation( # assert model did not mutate the input assert torch.allclose(og_positions, sim_state.positions) assert torch.allclose(og_cell, sim_state.cell) - assert torch.allclose(og_batch, sim_state.batch) + assert torch.allclose(og_batch, sim_state.graph_idx) assert torch.allclose(og_atomic_numbers, sim_state.atomic_numbers) # assert model output has the correct keys diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 5d15d4a1..2481dba0 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -115,10 +115,10 @@ def test_split_state(si_double_sim_state: ts.SimState) -> None: # Check each state has the correct properties for state in enumerate(split_states): - assert state[1].n_batches == 1 + assert state[1].n_graphs == 1 assert torch.all( - state[1].batch == 0 - ) # Each split state should have batch indices reset to 0 + state[1].graph_idx == 0 + ) # Each split state should have graph indices reset to 0 assert state[1].n_atoms == si_double_sim_state.n_atoms // 2 assert state[1].positions.shape[0] == si_double_sim_state.n_atoms // 2 assert state[1].cell.shape[0] == 1 @@ -472,14 +472,14 @@ def test_in_flight_with_fire( batcher.load_states(fire_states) def convergence_fn(state: ts.SimState) -> bool: - batch_wise_max_force = torch.zeros( - state.n_batches, device=state.device, dtype=torch.float64 + graph_wise_max_force = torch.zeros( + state.n_graphs, device=state.device, dtype=torch.float64 ) max_forces = state.forces.norm(dim=1) - batch_wise_max_force = batch_wise_max_force.scatter_reduce( - dim=0, index=state.batch, src=max_forces, reduce="amax" + graph_wise_max_force = graph_wise_max_force.scatter_reduce( + dim=0, index=state.graph_idx, src=max_forces, reduce="amax" ) - return batch_wise_max_force < 5e-1 + return graph_wise_max_force < 5e-1 all_completed_states, convergence_tensor = [], None while True: @@ -514,7 +514,7 @@ def test_binning_auto_batcher_with_fire( batch_lengths = [state.n_atoms for state in fire_states] optimal_batches = to_constant_volume_bins(batch_lengths, 400) - optimal_n_batches = len(optimal_batches) + optimal_n_graphs = len(optimal_batches) batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=400 @@ -522,9 +522,9 @@ def test_binning_auto_batcher_with_fire( batcher.load_states(fire_states) finished_states = [] - n_batches = 0 + n_graphs = 0 for batch in batcher: - n_batches += 1 + n_graphs += 1 for _ in range(5): batch = fire_update(batch) @@ -535,7 +535,7 @@ def test_binning_auto_batcher_with_fire( for restored, original in zip(restored_states, fire_states, strict=True): assert torch.all(restored.atomic_numbers == original.atomic_numbers) # analytically determined to be optimal - assert n_batches == optimal_n_batches + assert n_graphs == optimal_n_graphs def test_in_flight_max_iterations( @@ -561,7 +561,7 @@ def test_in_flight_max_iterations( state, [] = batcher.next_batch(None, None) # Create a convergence tensor that never converges - convergence_tensor = torch.zeros(state.n_batches, dtype=torch.bool) + convergence_tensor = torch.zeros(state.n_graphs, dtype=torch.bool) all_completed_states = [] iteration_count = 0 @@ -574,7 +574,7 @@ def test_in_flight_max_iterations( # Update convergence tensor for next iteration (still all False) if state is not None: - convergence_tensor = torch.zeros(state.n_batches, dtype=torch.bool) + convergence_tensor = torch.zeros(state.n_graphs, dtype=torch.bool) if iteration_count > max_attempts + 4: raise ValueError("Should have terminated by now") diff --git a/tests/test_correlations.py b/tests/test_correlations.py index f14612de..a5970f4d 100644 --- a/tests/test_correlations.py +++ b/tests/test_correlations.py @@ -31,12 +31,14 @@ def __init__(self, velocities: torch.Tensor, device: torch.device) -> None: self.velocities = velocities self.device = device # Required for TrajectoryReporter - self.n_batches = 1 - self.batch = torch.zeros(velocities.shape[0], device=device, dtype=torch.int64) + self.n_graphs = 1 + self.graph_idx = torch.zeros( + velocities.shape[0], device=device, dtype=torch.int64 + ) def split(self) -> list["MockState"]: - """Split state into batches.""" - # Just return self since 1 batch + """Split state into multiple graphs.""" + # Just return self since 1 graph return [self] diff --git a/tests/test_integrators.py b/tests/test_integrators.py index bad80122..6455865c 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -109,7 +109,7 @@ def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, graph_idx=state.graph_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -172,7 +172,7 @@ def test_npt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, graph_idx=state.graph_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -213,7 +213,7 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, graph_idx=state.graph_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -273,7 +273,7 @@ def test_nvt_langevin_multi_kt( state = update_fn(state=state) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT(state.momenta, state.masses, batch=state.batch) + temp = calc_kT(state.momenta, state.masses, graph_idx=state.graph_idx) energies.append(state.energy) temperatures.append(temp / MetalUnits.temperature) @@ -372,8 +372,8 @@ def test_compare_single_vs_batched_integrators( torch.testing.assert_close(single_state.energy, final_state.energy) -def test_compute_cell_force_atoms_per_batch(): - """Test that compute_cell_force correctly scales by number of atoms per batch. +def test_compute_cell_force_atoms_per_graph(): + """Test that compute_cell_force correctly scales by number of atoms per graph. Covers fix in https://github.com/Radical-AI/torch-sim/pull/153.""" from torch_sim.integrators.npt import _compute_cell_force @@ -389,7 +389,7 @@ def test_compute_cell_force_atoms_per_batch(): masses=torch.ones(72), cell=torch.eye(3).repeat(2, 1, 1), pbc=True, - batch=torch.cat([s1, s2]), + graph_idx=torch.cat([s1, s2]), atomic_numbers=torch.ones(72, dtype=torch.long), stress=torch.zeros((2, 3, 3)), reference_cell=torch.eye(3).repeat(2, 1, 1), diff --git a/tests/test_io.py b/tests/test_io.py index 84f665d5..eed300b3 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -48,9 +48,9 @@ def test_multiple_structures_to_state( assert state.cell.shape == (2, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (16,) - assert state.batch.shape == (16,) + assert state.graph_idx.shape == (16,) assert torch.all( - state.batch == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8) + state.graph_idx == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8) ) @@ -65,8 +65,8 @@ def test_single_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: assert state.cell.shape == (1, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (8,) - assert state.batch.shape == (8,) - assert torch.all(state.batch == 0) + assert state.graph_idx.shape == (8,) + assert torch.all(state.graph_idx == 0) def test_multiple_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: @@ -80,9 +80,10 @@ def test_multiple_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: assert state.cell.shape == (2, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (16,) - assert state.batch.shape == (16,) + assert state.graph_idx.shape == (16,) assert torch.all( - state.batch == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), + state.graph_idx + == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), ) @@ -171,9 +172,10 @@ def test_multiple_phonopy_to_state(si_phonopy_atoms: Any, device: torch.device) assert state.cell.shape == (2, 3, 3) assert state.pbc assert state.atomic_numbers.shape == (16,) - assert state.batch.shape == (16,) + assert state.graph_idx.shape == (16,) assert torch.all( - state.batch == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), + state.graph_idx + == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), ) @@ -235,11 +237,11 @@ def test_state_round_trip( # Get the sim_state fixture dynamically using the name sim_state: ts.SimState = request.getfixturevalue(sim_state_name) to_format_fn, from_format_fn = conversion_functions - unique_batches = torch.unique(sim_state.batch) + unique_graphs = torch.unique(sim_state.graph_idx) # Convert to intermediate format intermediate_format = to_format_fn(sim_state) - assert len(intermediate_format) == len(unique_batches) + assert len(intermediate_format) == len(unique_graphs) # Convert back to state round_trip_state: ts.SimState = from_format_fn(intermediate_format, device, dtype) @@ -248,7 +250,7 @@ def test_state_round_trip( assert torch.allclose(sim_state.positions, round_trip_state.positions) assert torch.allclose(sim_state.cell, round_trip_state.cell) assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers) - assert torch.all(sim_state.batch == round_trip_state.batch) + assert torch.all(sim_state.graph_idx == round_trip_state.graph_idx) assert sim_state.pbc == round_trip_state.pbc if isinstance(intermediate_format[0], Atoms): diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index 880adefb..7081635a 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -50,7 +50,7 @@ def test_generate_permutation( ): swaps = generate_swaps(batched_diverse_state, generator=generator) permutation = swaps_to_permutation(swaps, batched_diverse_state.n_atoms) - validate_permutation(permutation, batched_diverse_state.batch) + validate_permutation(permutation, batched_diverse_state.graph_idx) def test_generate_swaps(batched_diverse_state: ts.SimState, generator: torch.Generator): @@ -64,9 +64,9 @@ def test_generate_swaps(batched_diverse_state: ts.SimState, generator: torch.Gen assert torch.all(swaps >= 0) assert torch.all(swaps < batched_diverse_state.n_atoms) - # Check swaps are within same batch - batch = batched_diverse_state.batch - assert torch.all(batch[swaps[:, 0]] == batch[swaps[:, 1]]) + # Check swaps are within same graph + graph_idx = batched_diverse_state.graph_idx + assert torch.all(graph_idx[swaps[:, 0]] == graph_idx[swaps[:, 1]]) def test_swaps_to_permutation( @@ -95,7 +95,7 @@ def test_validate_permutation(batched_diverse_state: ts.SimState): # Valid permutation swaps = generate_swaps(batched_diverse_state) permutation = swaps_to_permutation(swaps, batched_diverse_state.n_atoms) - validate_permutation(permutation, batched_diverse_state.batch) # Should not raise + validate_permutation(permutation, batched_diverse_state.graph_idx) # Should not raise # Invalid permutation (swap between batches) invalid_perm = permutation.clone() @@ -105,7 +105,7 @@ def test_validate_permutation(batched_diverse_state: ts.SimState): invalid_perm[batched_diverse_state.n_atoms - 1] = 0 with pytest.raises(ValueError, match="Swaps must be between"): - validate_permutation(invalid_perm, batched_diverse_state.batch) + validate_permutation(invalid_perm, batched_diverse_state.graph_idx) def test_monte_carlo( @@ -147,17 +147,17 @@ def test_monte_carlo( # Verify the state has changed after multiple steps assert not torch.allclose(current_state.positions, initial_positions) - # Verify batch assignments remain unchanged - assert torch.all(current_state.batch == batched_diverse_state.batch) + # Verify graph_idx assignments remain unchanged + assert torch.all(current_state.graph_idx == batched_diverse_state.graph_idx) - # Verify atomic numbers distribution remains the same per batch - for batch_idx in torch.unique(current_state.batch): - batch_mask_orig = batched_diverse_state.batch == batch_idx - batch_mask_result = current_state.batch == batch_idx + # Verify atomic numbers distribution remains the same per graph + for idx in torch.unique(current_state.graph_idx): + graph_mask_orig = batched_diverse_state.graph_idx == idx + graph_mask_result = current_state.graph_idx == idx orig_counts = torch.bincount( - batched_diverse_state.atomic_numbers[batch_mask_orig] + batched_diverse_state.atomic_numbers[graph_mask_orig] ) - result_counts = torch.bincount(current_state.atomic_numbers[batch_mask_result]) + result_counts = torch.bincount(current_state.atomic_numbers[graph_mask_result]) assert torch.all(orig_counts == result_counts) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index eebc7391..e682f642 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -37,7 +37,7 @@ def ase_to_torch_batch( - pos: Tensor of atomic positions. - cell: Tensor of unit cell vectors. - pbc: Tensor indicating periodic boundary conditions. - - batch: Tensor indicating the batch index for each atom. + - graph_idx: Tensor indicating the graph index for each atom. - n_atoms: Tensor containing the number of atoms in each structure. """ n_atoms = torch.tensor([len(atoms) for atoms in atoms_list], dtype=torch.long) @@ -49,17 +49,17 @@ def ase_to_torch_batch( pbc = torch.cat([torch.from_numpy(atoms.get_pbc()) for atoms in atoms_list]) stride = torch.cat((torch.tensor([0]), n_atoms.cumsum(0))) - batch = torch.zeros(pos.shape[0], dtype=torch.long) + graph_idx = torch.zeros(pos.shape[0], dtype=torch.long) for ii, (st, end) in enumerate( zip(stride[:-1], stride[1:], strict=True) # noqa: RUF007 ): - batch[st:end] = ii + graph_idx[st:end] = ii n_atoms = torch.Tensor(n_atoms[1:]).to(dtype=torch.long) return ( pos.to(dtype=dtype, device=device), cell.to(dtype=dtype, device=device), pbc.to(device=device), - batch.to(device=device), + graph_idx.to(device=device), n_atoms.to(device=device), ) @@ -556,11 +556,11 @@ def test_neighbor_lists_time_and_memory( start_time = time.perf_counter() if nl_fn in [neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell]: - batch = torch.zeros(n_atoms, dtype=torch.long, device=device) + graph_idx = torch.zeros(n_atoms, dtype=torch.long, device=device) # Fix pbc tensor shape pbc = torch.tensor([[True, True, True]], device=device) mapping, mapping_batch, shifts_idx = nl_fn( - cutoff, pos, cell, pbc, batch, self_interaction=False + cutoff, pos, cell, pbc, graph_idx, self_interaction=False ) else: mapping, shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 87ff8697..2959036f 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -128,7 +128,7 @@ def test_fire_optimization( cell=ar_supercell_sim_state.cell.clone(), pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), - batch=ar_supercell_sim_state.batch.clone(), + graph_idx=ar_supercell_sim_state.graph_idx.clone(), ) initial_state_positions = current_sim_state.positions.clone() @@ -229,7 +229,7 @@ def test_fire_ase_negative_power_branch( state = init_fn(ar_supercell_sim_state) # Save parameters from initial state - initial_dt_batch = state.dt.clone() # per-batch dt + initial_dt_batch = state.dt.clone() # per-graph dt # Manipulate state to ensure P < 0 for the update_fn step # Ensure forces are non-trivial @@ -262,10 +262,10 @@ def test_fire_ase_negative_power_branch( # Assertions for velocity update in ASE P < 0 case: # v_after_mixing_is_0, then v_final = dt_new * F_at_power_calc expected_final_velocities = ( - expected_dt_val * forces_at_power_calc[updated_state.batch == 0] + expected_dt_val * forces_at_power_calc[updated_state.graph_idx == 0] ) assert torch.allclose( - updated_state.velocities[updated_state.batch == 0], + updated_state.velocities[updated_state.graph_idx == 0], expected_final_velocities, atol=1e-6, ) @@ -317,8 +317,8 @@ def test_fire_vv_negative_power_branch( # If P<0 branch was taken, velocities should be zeroed assert torch.allclose( - updated_state.velocities[updated_state.batch == 0], - torch.zeros_like(updated_state.velocities[updated_state.batch == 0]), + updated_state.velocities[updated_state.graph_idx == 0], + torch.zeros_like(updated_state.velocities[updated_state.graph_idx == 0]), atol=1e-7, ) @@ -345,7 +345,7 @@ def test_unit_cell_fire_optimization( cell=current_cell, pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), - batch=ar_supercell_sim_state.batch.clone(), + graph_idx=ar_supercell_sim_state.graph_idx.clone(), ) initial_state_positions = current_sim_state.positions.clone() @@ -429,7 +429,7 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( assert opt_state.forces is not None assert opt_state.stress is not None expected_cf_tensor = torch.full( - (opt_state.n_batches, 1, 1), + (opt_state.n_graphs, 1, 1), float(cell_factor_val), # Ensure float for comparison if int is passed device=lj_model.device, dtype=lj_model.dtype, @@ -452,11 +452,11 @@ def test_cell_optimizer_init_cell_factor_none( ) -> None: """Test cell optimizer init_fn with cell_factor=None.""" init_fn, _ = optimizer_fn(model=lj_model, cell_factor=None) - # Ensure n_batches > 0 for cell_factor calculation from counts - assert ar_supercell_sim_state.n_batches > 0 + # Ensure n_graphs > 0 for cell_factor calculation from counts + assert ar_supercell_sim_state.n_graphs > 0 opt_state = init_fn(ar_supercell_sim_state) # Uses ts.SimState directly assert isinstance(opt_state, expected_state_type) - _, counts = torch.unique(ar_supercell_sim_state.batch, return_counts=True) + _, counts = torch.unique(ar_supercell_sim_state.graph_idx, return_counts=True) expected_cf_tensor = counts.to(dtype=lj_model.dtype).view(-1, 1, 1) assert torch.allclose(opt_state.cell_factor, expected_cf_tensor) assert opt_state.energy is not None @@ -525,7 +525,7 @@ def test_frechet_cell_fire_optimization( cell=current_cell, pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), - batch=ar_supercell_sim_state.batch.clone(), + graph_idx=ar_supercell_sim_state.graph_idx.clone(), ) initial_state_positions = current_sim_state.positions.clone() @@ -874,6 +874,6 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: # position only optimizations for step, energy_unit_cell in enumerate(individual_energies_unit_cell): assert abs(energy_unit_cell - individual_energies_fire[step]) < 1e-4, ( - f"Energy for batch {step} doesn't match position only optimization: " - f"batch={energy_unit_cell}, individual={individual_energies_fire[step]}" + f"Energy for graph {step} doesn't match position only optimization: " + f"graph={energy_unit_cell}, individual={individual_energies_fire[step]}" ) diff --git a/tests/test_runners.py b/tests/test_runners.py index 5d7201c9..fab55de2 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -149,7 +149,7 @@ def test_integrate_many_nvt( lj_model.dtype, ) trajectory_files = [ - tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_batches) + tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_graphs) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -231,7 +231,7 @@ def test_integrate_with_autobatcher_and_reporting( max_memory_scaler=260, ) trajectory_files = [ - tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_batches) + tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_graphs) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -340,7 +340,7 @@ def test_batched_optimize_fire( ) -> None: """Test batched FIRE optimization with LJ potential.""" trajectory_files = [ - tmp_path / f"nvt_{idx}.h5md" for idx in range(ar_double_sim_state.n_batches) + tmp_path / f"nvt_{idx}.h5md" for idx in range(ar_double_sim_state.n_graphs) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -414,7 +414,7 @@ def test_optimize_with_autobatcher_and_reporting( ) trajectory_files = [ - tmp_path / f"opt_{batch}.h5md" for batch in range(triple_state.n_batches) + tmp_path / f"opt_{batch}.h5md" for batch in range(triple_state.n_graphs) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -798,7 +798,7 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: # autobatcher=True, # disabled for CPU-based LJ model in test ) - assert relaxed_state.energy.shape == (final_state.n_batches,) + assert relaxed_state.energy.shape == (final_state.n_graphs,) @pytest.fixture @@ -806,22 +806,20 @@ def mock_state() -> Callable: """Create a mock state for testing convergence functions.""" device = torch.device("cpu") dtype = torch.float64 - n_batches, n_atoms = 2, 8 + n_graphs, n_atoms = 2, 8 torch.manual_seed(0) # deterministic forces class MockState: def __init__(self, *, include_cell_forces: bool = True) -> None: self.forces = torch.randn(n_atoms, 3, device=device, dtype=dtype) - self.batch = torch.repeat_interleave( - torch.arange(n_batches), n_atoms // n_batches + self.graph_idx = torch.repeat_interleave( + torch.arange(n_graphs), n_atoms // n_graphs ) self.device = device self.dtype = dtype - self.n_batches = n_batches + self.n_graphs = n_graphs if include_cell_forces: - self.cell_forces = torch.randn( - n_batches, 3, 3, device=device, dtype=dtype - ) + self.cell_forces = torch.randn(n_graphs, 3, 3, device=device, dtype=dtype) return MockState @@ -858,7 +856,7 @@ def test_generate_force_convergence_fn( if has_cell_forces: ar_supercell_sim_state.cell_forces = torch.randn( - ar_supercell_sim_state.n_batches, + ar_supercell_sim_state.n_graphs, 3, 3, device=ar_supercell_sim_state.device, @@ -877,7 +875,7 @@ def test_generate_force_convergence_fn( result = convergence_fn(state) assert isinstance(result, torch.Tensor) assert result.dtype == torch.bool - assert result.shape == (state.n_batches,) + assert result.shape == (state.n_graphs,) def test_generate_force_convergence_fn_tolerance_ordering( @@ -888,7 +886,7 @@ def test_generate_force_convergence_fn_tolerance_ordering( ar_supercell_sim_state.forces = model_output["forces"] ar_supercell_sim_state.energy = model_output["energy"] ar_supercell_sim_state.cell_forces = torch.randn( - ar_supercell_sim_state.n_batches, + ar_supercell_sim_state.n_graphs, 3, 3, device=ar_supercell_sim_state.device, @@ -926,26 +924,26 @@ def test_generate_force_convergence_fn_logic( ) -> None: """Test convergence logic with controlled force values.""" device, dtype = torch.device("cpu"), torch.float64 - n_batches, n_atoms = len(atomic_forces), 8 + n_graphs, n_atoms = len(atomic_forces), 8 class ControlledMockState: def __init__(self) -> None: - self.n_batches = n_batches + self.n_graphs = n_graphs self.device, self.dtype = device, dtype - self.batch = torch.repeat_interleave( - torch.arange(n_batches), n_atoms // n_batches + self.graph_idx = torch.repeat_interleave( + torch.arange(n_graphs), n_atoms // n_graphs ) - # Set specific force magnitudes per batch + # Set specific force magnitudes per graph self.forces = torch.zeros(n_atoms, 3, device=device, dtype=dtype) - self.cell_forces = torch.zeros(n_batches, 3, 3, device=device, dtype=dtype) + self.cell_forces = torch.zeros(n_graphs, 3, 3, device=device, dtype=dtype) - for batch_idx, (atomic_force, cell_force) in enumerate( + for graph_idx, (atomic_force, cell_force) in enumerate( zip(atomic_forces, cell_forces, strict=False) ): - batch_mask = self.batch == batch_idx - self.forces[batch_mask, 0] = atomic_force - self.cell_forces[batch_idx, 0, 0] = cell_force + graph_mask = self.graph_idx == graph_idx + self.forces[graph_mask, 0] = atomic_force + self.cell_forces[graph_idx, 0, 0] = cell_force state = ControlledMockState() convergence_fn = ts.generate_force_convergence_fn( diff --git a/tests/test_state.py b/tests/test_state.py index 1e5f325b..d3beb68d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -9,7 +9,7 @@ from torch_sim.state import ( DeformGradMixin, SimState, - _normalize_batch_indices, + _normalize_graph_indices, _pop_states, _slice_state, concatenate_states, @@ -28,8 +28,13 @@ def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: """Test inference of property scope.""" scope = infer_property_scope(si_sim_state) assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == {"positions", "masses", "atomic_numbers", "batch"} - assert set(scope["per_batch"]) == {"cell"} + assert set(scope["per_atom"]) == { + "positions", + "masses", + "atomic_numbers", + "graph_idx", + } + assert set(scope["per_graph"]) == {"cell"} def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: @@ -46,19 +51,19 @@ def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: "positions", "masses", "atomic_numbers", - "batch", + "graph_idx", "forces", "momenta", } - assert set(scope["per_batch"]) == {"cell", "energy"} + assert set(scope["per_graph"]) == {"cell", "energy"} def test_slice_substate( si_double_sim_state: ts.SimState, si_sim_state: ts.SimState ) -> None: """Test slicing a substate from the SimState.""" - for batch_index in range(2): - substate = _slice_state(si_double_sim_state, [batch_index]) + for graph_index in range(2): + substate = _slice_state(si_double_sim_state, [graph_index]) assert isinstance(substate, SimState) assert substate.positions.shape == (8, 3) assert substate.masses.shape == (8,) @@ -67,7 +72,7 @@ def test_slice_substate( assert torch.allclose(substate.masses, si_sim_state.masses) assert torch.allclose(substate.cell, si_sim_state.cell) assert torch.allclose(substate.atomic_numbers, si_sim_state.atomic_numbers) - assert torch.allclose(substate.batch, torch.zeros_like(substate.batch)) + assert torch.allclose(substate.graph_idx, torch.zeros_like(substate.graph_idx)) def test_slice_md_substate(si_double_sim_state: ts.SimState) -> None: @@ -77,8 +82,8 @@ def test_slice_md_substate(si_double_sim_state: ts.SimState) -> None: energy=torch.zeros((2,), device=si_double_sim_state.device), forces=torch.randn_like(si_double_sim_state.positions), ) - for batch_index in range(2): - substate = _slice_state(state, [batch_index]) + for graph_index in range(2): + substate = _slice_state(state, [graph_index]) assert isinstance(substate, MDState) assert substate.positions.shape == (8, 3) assert substate.masses.shape == (8,) @@ -101,10 +106,10 @@ def test_concatenate_two_si_states( assert concatenated.masses.shape == si_double_sim_state.masses.shape assert concatenated.cell.shape == si_double_sim_state.cell.shape assert concatenated.atomic_numbers.shape == si_double_sim_state.atomic_numbers.shape - assert concatenated.batch.shape == si_double_sim_state.batch.shape + assert concatenated.graph_idx.shape == si_double_sim_state.graph_idx.shape - # Check batch indices - expected_batch = torch.cat( + # Check graph indices + expected_graph_indices = torch.cat( [ torch.zeros( si_sim_state.n_atoms, dtype=torch.int64, device=si_sim_state.device @@ -114,12 +119,12 @@ def test_concatenate_two_si_states( ), ] ) - assert torch.all(concatenated.batch == expected_batch) + assert torch.all(concatenated.graph_idx == expected_graph_indices) - # Check that positions match (accounting for batch indices) - for batch_idx in range(2): - mask_concat = concatenated.batch == batch_idx - mask_double = si_double_sim_state.batch == batch_idx + # Check that positions match (accounting for graph indices) + for graph_idx in range(2): + mask_concat = concatenated.graph_idx == graph_idx + mask_double = si_double_sim_state.graph_idx == graph_idx assert torch.allclose( concatenated.positions[mask_concat], si_double_sim_state.positions[mask_double], @@ -143,22 +148,22 @@ def test_concatenate_si_and_fe_states( concatenated.masses.shape[0] == si_sim_state.masses.shape[0] + fe_supercell_sim_state.masses.shape[0] ) - assert concatenated.cell.shape[0] == 2 # One cell per batch + assert concatenated.cell.shape[0] == 2 # One cell per graph - # Check batch indices + # Check graph indices si_atoms = si_sim_state.n_atoms fe_atoms = fe_supercell_sim_state.n_atoms - expected_batch = torch.cat( + expected_graph_indices = torch.cat( [ torch.zeros(si_atoms, dtype=torch.int64, device=si_sim_state.device), torch.ones(fe_atoms, dtype=torch.int64, device=fe_supercell_sim_state.device), ] ) - assert torch.all(concatenated.batch == expected_batch) + assert torch.all(concatenated.graph_idx == expected_graph_indices) - # check n_atoms_per_batch + # check n_atoms_per_graph assert torch.all( - concatenated.n_atoms_per_batch + concatenated.n_atoms_per_graph == torch.tensor( [si_sim_state.n_atoms, fe_supercell_sim_state.n_atoms], device=concatenated.device, @@ -192,22 +197,22 @@ def test_concatenate_double_si_and_fe_states( ) assert ( concatenated.cell.shape[0] == 3 - ) # One cell for each original batch (2 Si + 1 Ar) + ) # One cell for each original graph (2 Si + 1 Ar) - # Check batch indices + # Check graph indices fe_atoms = fe_supercell_sim_state.n_atoms - # The double Si state already has batches 0 and 1, so Ar should be batch 2 - expected_batch = torch.cat( + # The double Si state already has graphs 0 and 1, so Ar should be graph 2 + expected_graph_indices = torch.cat( [ - si_double_sim_state.batch, + si_double_sim_state.graph_idx, torch.full( (fe_atoms,), 2, dtype=torch.int64, device=fe_supercell_sim_state.device ), ] ) - assert torch.all(concatenated.batch == expected_batch) - assert torch.unique(concatenated.batch).shape[0] == 3 + assert torch.all(concatenated.graph_idx == expected_graph_indices) + assert torch.unique(concatenated.graph_idx).shape[0] == 3 # Check that we can slice back to the original states si_slice_0 = concatenated[0] @@ -223,14 +228,14 @@ def test_concatenate_double_si_and_fe_states( def test_split_state(si_double_sim_state: ts.SimState) -> None: """Test splitting a state into a list of states.""" states = si_double_sim_state.split() - assert len(states) == si_double_sim_state.n_batches + assert len(states) == si_double_sim_state.n_graphs for state in states: assert isinstance(state, ts.SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) assert state.atomic_numbers.shape == (8,) - assert torch.allclose(state.batch, torch.zeros_like(state.batch)) + assert torch.allclose(state.graph_idx, torch.zeros_like(state.graph_idx)) def test_split_many_states( @@ -248,7 +253,7 @@ def test_split_many_states( assert torch.allclose(sub_state.masses, state.masses) assert torch.allclose(sub_state.cell, state.cell) assert torch.allclose(sub_state.atomic_numbers, state.atomic_numbers) - assert torch.allclose(sub_state.batch, state.batch) + assert torch.allclose(sub_state.graph_idx, state.graph_idx) assert len(states) == 3 @@ -276,7 +281,7 @@ def test_pop_states( assert kept_state.masses.shape == (len_kept,) assert kept_state.cell.shape == (2, 3, 3) assert kept_state.atomic_numbers.shape == (len_kept,) - assert kept_state.batch.shape == (len_kept,) + assert kept_state.graph_idx.shape == (len_kept,) def test_initialize_state_from_structure( @@ -337,8 +342,8 @@ def test_state_pop_method( assert torch.allclose(popped_states[0].positions, ar_supercell_sim_state.positions) # Verify the original state was modified - assert concatenated.n_batches == 2 - assert torch.unique(concatenated.batch).tolist() == [0, 1] + assert concatenated.n_graphs == 2 + assert torch.unique(concatenated.graph_idx).tolist() == [0, 1] # Test popping multiple batches multi_state = concatenate_states(states) @@ -348,8 +353,8 @@ def test_state_pop_method( assert torch.allclose(popped_multi[1].positions, fe_supercell_sim_state.positions) # Verify the original multi-state was modified - assert multi_state.n_batches == 1 - assert torch.unique(multi_state.batch).tolist() == [0] + assert multi_state.n_graphs == 1 + assert torch.unique(multi_state.graph_idx).tolist() == [0] assert torch.allclose(multi_state.positions, ar_supercell_sim_state.positions) @@ -367,19 +372,19 @@ def test_state_getitem( single_state = concatenated[1] assert isinstance(single_state, SimState) assert torch.allclose(single_state.positions, ar_supercell_sim_state.positions) - assert single_state.n_batches == 1 + assert single_state.n_graphs == 1 # Test list indexing multi_state = concatenated[[0, 2]] assert isinstance(multi_state, SimState) - assert multi_state.n_batches == 2 + assert multi_state.n_graphs == 2 assert torch.allclose(multi_state[0].positions, si_sim_state.positions) assert torch.allclose(multi_state[1].positions, fe_supercell_sim_state.positions) # Test slice indexing slice_state = concatenated[1:3] assert isinstance(slice_state, SimState) - assert slice_state.n_batches == 2 + assert slice_state.n_graphs == 2 assert torch.allclose(slice_state[0].positions, ar_supercell_sim_state.positions) assert torch.allclose(slice_state[1].positions, fe_supercell_sim_state.positions) @@ -391,67 +396,67 @@ def test_state_getitem( # Test step in slice step_state = concatenated[::2] assert isinstance(step_state, SimState) - assert step_state.n_batches == 2 + assert step_state.n_graphs == 2 assert torch.allclose(step_state[0].positions, si_sim_state.positions) assert torch.allclose(step_state[1].positions, fe_supercell_sim_state.positions) full_state = concatenated[:] assert torch.allclose(full_state.positions, concatenated.positions) # Verify original state is unchanged - assert concatenated.n_batches == 3 + assert concatenated.n_graphs == 3 -def test_normalize_batch_indices(si_double_sim_state: ts.SimState) -> None: - """Test the _normalize_batch_indices utility method.""" +def test_normalize_graph_indices(si_double_sim_state: ts.SimState) -> None: + """Test the _normalize_graph_indices utility method.""" state = si_double_sim_state # State with 2 batches - n_batches = state.n_batches + n_graphs = state.n_graphs device = state.device # Test integer indexing - assert _normalize_batch_indices(0, n_batches, device).tolist() == [0] - assert _normalize_batch_indices(1, n_batches, device).tolist() == [1] + assert _normalize_graph_indices(0, n_graphs, device).tolist() == [0] + assert _normalize_graph_indices(1, n_graphs, device).tolist() == [1] # Test negative integer indexing - assert _normalize_batch_indices(-1, n_batches, device).tolist() == [1] - assert _normalize_batch_indices(-2, n_batches, device).tolist() == [0] + assert _normalize_graph_indices(-1, n_graphs, device).tolist() == [1] + assert _normalize_graph_indices(-2, n_graphs, device).tolist() == [0] # Test list indexing - assert _normalize_batch_indices([0, 1], n_batches, device).tolist() == [0, 1] + assert _normalize_graph_indices([0, 1], n_graphs, device).tolist() == [0, 1] # Test list with negative indices - assert _normalize_batch_indices([0, -1], n_batches, device).tolist() == [0, 1] - assert _normalize_batch_indices([-2, -1], n_batches, device).tolist() == [0, 1] + assert _normalize_graph_indices([0, -1], n_graphs, device).tolist() == [0, 1] + assert _normalize_graph_indices([-2, -1], n_graphs, device).tolist() == [0, 1] # Test slice indexing - indices = _normalize_batch_indices(slice(0, 2), n_batches, device) + indices = _normalize_graph_indices(slice(0, 2), n_graphs, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0, 1], device=state.device)) # Test slice with negative indices - indices = _normalize_batch_indices(slice(-2, None), n_batches, device) + indices = _normalize_graph_indices(slice(-2, None), n_graphs, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0, 1], device=state.device)) # Test slice with step - indices = _normalize_batch_indices(slice(0, 2, 2), n_batches, device) + indices = _normalize_graph_indices(slice(0, 2, 2), n_graphs, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0], device=state.device)) # Test tensor indexing tensor_indices = torch.tensor([0, 1], device=state.device) - indices = _normalize_batch_indices(tensor_indices, n_batches, device) + indices = _normalize_graph_indices(tensor_indices, n_graphs, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == tensor_indices) # Test tensor with negative indices tensor_indices = torch.tensor([0, -1], device=state.device) - indices = _normalize_batch_indices(tensor_indices, n_batches, device) + indices = _normalize_graph_indices(tensor_indices, n_graphs, device) assert isinstance(indices, torch.Tensor) assert torch.all(indices == torch.tensor([0, 1], device=state.device)) # Test error for unsupported type try: - _normalize_batch_indices((0, 1), n_batches, device) # Tuple is not supported + _normalize_graph_indices((0, 1), n_graphs, device) # Tuple is not supported raise ValueError("Should have raised TypeError") except TypeError: pass @@ -601,7 +606,9 @@ def test_deform_grad_batched(device: torch.device) -> None: atomic_numbers=torch.ones(n_atoms * batch_size, device=device, dtype=torch.long), velocities=torch.randn(n_atoms * batch_size, 3, device=device), reference_cell=reference_cell, - batch=torch.repeat_interleave(torch.arange(batch_size, device=device), n_atoms), + graph_idx=torch.repeat_interleave( + torch.arange(batch_size, device=device), n_atoms + ), ) deform_grad = state.deform_grad() @@ -611,3 +618,28 @@ def test_deform_grad_batched(device: torch.device) -> None: for i in range(batch_size): expected = expected_factors[i] * torch.eye(3, device=device) assert torch.allclose(deform_grad[i], expected) + + +def test_deprecated_batch_properties_equal_to_new_graph_properties( + device: torch.device, +) -> None: + """Test that deprecated batch properties are equal to new graph properties. + + This tests that the rename from batch to graph is not breaking anything.""" + state = SimState( + positions=torch.randn(10, 3, device=device), + masses=torch.ones(10, device=device), + cell=torch.eye(3, device=device).unsqueeze(0).repeat(2, 1, 1), + pbc=True, + atomic_numbers=torch.ones(10, device=device, dtype=torch.long), + graph_idx=torch.repeat_interleave(torch.arange(2, device=device), 5), + ) + assert state.batch is state.graph_idx + assert state.n_batches == state.n_graphs + assert torch.allclose(state.n_atoms_per_batch, state.n_atoms_per_graph) + + # now test that assigning the old .batch property behaves the same + new_graph_idx = torch.arange(4, device=device) + state.batch = new_graph_idx + assert torch.allclose(state.graph_idx, new_graph_idx) + assert torch.allclose(state.batch, new_graph_idx) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 3a662242..35c65160 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -32,7 +32,7 @@ def random_state() -> MDState: ), cell=torch.unsqueeze(torch.eye(3) * 10.0, 0), atomic_numbers=torch.ones(10, dtype=torch.int32), - batch=torch.zeros(10, dtype=torch.int32), + graph_idx=torch.zeros(10, dtype=torch.int32), pbc=True, ) @@ -678,9 +678,9 @@ def test_multi_batch_reporter( # Check that each trajectory has the correct number of atoms # (should be half of the total in the double state) - atoms_per_batch = si_double_sim_state.positions.shape[0] // 2 - assert traj0.get_array("positions").shape[1] == atoms_per_batch - assert traj1.get_array("positions").shape[1] == atoms_per_batch + atoms_per_graph = si_double_sim_state.positions.shape[0] // 2 + assert traj0.get_array("positions").shape[1] == atoms_per_graph + assert traj1.get_array("positions").shape[1] == atoms_per_graph # Check property data assert "ones" in traj0.array_registry @@ -698,11 +698,11 @@ def test_property_model_consistency( """Test property models are consistent for single and multi-batch cases.""" # Create reporters for single and multi-batch cases single_reporters = [] - for batch_idx in range(2): + for graph_idx in range(2): # Extract single batch states - single_state = si_double_sim_state[batch_idx] + single_state = si_double_sim_state[graph_idx] reporter = TrajectoryReporter( - tmp_path / f"single_{batch_idx}.hdf5", + tmp_path / f"single_{graph_idx}.hdf5", state_frequency=1, prop_calculators=prop_calculators, ) @@ -710,7 +710,7 @@ def test_property_model_consistency( reporter.report(single_state, 0) reporter.close() single_reporters.append( - TorchSimTrajectory(tmp_path / f"single_{batch_idx}.hdf5", mode="r") + TorchSimTrajectory(tmp_path / f"single_{graph_idx}.hdf5", mode="r") ) # Create multi-batch reporter @@ -727,14 +727,14 @@ def test_property_model_consistency( TorchSimTrajectory(tmp_path / "multi_1.hdf5", mode="r"), ] - # Compare property values between single and multi-batch approaches - for batch_idx in range(2): - single_ke = single_reporters[batch_idx].get_array("ones")[0] - multi_ke = multi_trajectories[batch_idx].get_array("ones")[0] + # Compare property values between single and multi-graph approaches + for graph_idx in range(2): + single_ke = single_reporters[graph_idx].get_array("ones")[0] + multi_ke = multi_trajectories[graph_idx].get_array("ones")[0] assert torch.allclose(torch.tensor(single_ke), torch.tensor(multi_ke)) - single_com = single_reporters[batch_idx].get_array("center_of_mass")[0] - multi_com = multi_trajectories[batch_idx].get_array("center_of_mass")[0] + single_com = single_reporters[graph_idx].get_array("center_of_mass")[0] + multi_com = multi_trajectories[graph_idx].get_array("center_of_mass")[0] assert torch.allclose(torch.tensor(single_com), torch.tensor(multi_com)) # Close all trajectories @@ -767,12 +767,12 @@ def energy_calculator(state: ts.SimState, model: torch.nn.Module) -> torch.Tenso reporter.close() # Verify properties were returned - assert len(props) == 2 # One dict per batch - for batch_props in props: - assert set(batch_props) == {"energy"} - assert isinstance(batch_props["energy"], torch.Tensor) - assert batch_props["energy"].shape == (1,) - assert batch_props["energy"] == pytest.approx(49.4150) + assert len(props) == 2 # One dict per graph + for graph_props in props: + assert set(graph_props) == {"energy"} + assert isinstance(graph_props["energy"], torch.Tensor) + assert graph_props["energy"].shape == (1,) + assert graph_props["energy"] == pytest.approx(49.4150) # Verify property was calculated correctly trajectories = [ @@ -780,21 +780,21 @@ def energy_calculator(state: ts.SimState, model: torch.nn.Module) -> torch.Tenso TorchSimTrajectory(tmp_path / "model_1.hdf5", mode="r"), ] - for batch_idx, trajectory in enumerate(trajectories): + for graph_idx, trajectory in enumerate(trajectories): # Get the property value from file file_energy = trajectory.get_array("energy")[0] - batch_props = props[batch_idx] + graph_props = props[graph_idx] # Calculate expected value - substate = si_double_sim_state[batch_idx] + substate = si_double_sim_state[graph_idx] expected = lj_model(substate)["energy"] # Compare file contents with expected np.testing.assert_allclose(file_energy, expected) # Compare returned properties with expected - np.testing.assert_allclose(batch_props["energy"], expected) + np.testing.assert_allclose(graph_props["energy"], expected) # Compare returned properties with file contents - np.testing.assert_allclose(batch_props["energy"], file_energy) + np.testing.assert_allclose(graph_props["energy"], file_energy) trajectory.close() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 879a70a4..6e178416 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -250,8 +250,8 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: ts.SimState) -> None # Modify a specific atom's position in each batch to be outside the cell # Get the first atom in each batch - batch_0_mask = state.batch == 0 - batch_1_mask = state.batch == 1 + batch_0_mask = state.graph_idx == 0 + batch_1_mask = state.graph_idx == 1 # Get current cell size (assume cubic for simplicity) cell_size = state.cell[0, 0, 0] @@ -268,7 +268,9 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: ts.SimState) -> None test_positions[idx1, 0] = -0.5 # Apply wrapping - wrapped = tst.pbc_wrap_batched(test_positions, cell=state.cell, batch=state.batch) + wrapped = tst.pbc_wrap_batched( + test_positions, cell=state.cell, graph_idx=state.graph_idx + ) # Check first modified atom is properly wrapped assert wrapped[idx0, 0] < cell_size @@ -317,7 +319,7 @@ def test_pbc_wrap_batched_triclinic(device: torch.device) -> None: cell = torch.stack([cell1, cell2]) # Apply wrapping - wrapped = tst.pbc_wrap_batched(positions, cell=cell, batch=batch) + wrapped = tst.pbc_wrap_batched(positions, cell=cell, graph_idx=batch) # Calculate expected result for first atom (using original algorithm for verification) expected1 = tst.pbc_wrap_general(positions[0:1], cell1) @@ -343,11 +345,11 @@ def test_pbc_wrap_batched_edge_case(device: torch.device) -> None: device=device, ) - # Create batch indices - batch = torch.tensor([0, 1], device=device) + # Create graph indices + graph_idx = torch.tensor([0, 1], device=device) # Apply wrapping - wrapped = tst.pbc_wrap_batched(positions, cell=cell, batch=batch) + wrapped = tst.pbc_wrap_batched(positions, cell=cell, graph_idx=graph_idx) # Expected results (wrapping to 0.0 rather than 2.0) expected = torch.tensor( @@ -367,12 +369,12 @@ def test_pbc_wrap_batched_invalid_inputs(device: torch.device) -> None: # Valid inputs for reference positions = torch.ones(4, 3, device=device) cell = torch.stack([torch.eye(3, device=device)] * 2) - batch = torch.tensor([0, 0, 1, 1], device=device) + graph_idx = torch.tensor([0, 0, 1, 1], device=device) # Test integer tensors with pytest.raises(TypeError): tst.pbc_wrap_batched( - torch.ones(4, 3, dtype=torch.int64, device=device), cell, batch + torch.ones(4, 3, dtype=torch.int64, device=device), cell, graph_idx ) # Test dimension mismatch - positions @@ -380,15 +382,15 @@ def test_pbc_wrap_batched_invalid_inputs(device: torch.device) -> None: tst.pbc_wrap_batched( torch.ones(4, 2, device=device), # Wrong dimension (2 instead of 3) cell, - batch, + graph_idx, ) - # Test mismatch between batch indices and cell + # Test mismatch between graph indices and cell with pytest.raises(ValueError): tst.pbc_wrap_batched( positions, torch.stack([torch.eye(3, device=device)] * 3), # 3 cell but only 2 batches - batch, + graph_idx, ) @@ -399,40 +401,42 @@ def test_pbc_wrap_batched_multi_atom(si_double_sim_state: ts.SimState) -> None: # Get a copy of positions to modify test_positions = state.positions.clone() - # Move all atoms of the first batch outside the cell in +x - batch_0_mask = state.batch == 0 + # Move all atoms of the first graph outside the cell in +x + graph_0_mask = state.graph_idx == 0 cell_size_x = state.cell[0, 0, 0].item() - test_positions[batch_0_mask, 0] += cell_size_x + test_positions[graph_0_mask, 0] += cell_size_x - # Move all atoms of the second batch outside the cell in -y - batch_1_mask = state.batch == 1 + # Move all atoms of the second graph outside the cell in -y + graph_1_mask = state.graph_idx == 1 cell_size_y = state.cell[0, 1, 1].item() - test_positions[batch_1_mask, 1] -= cell_size_y + test_positions[graph_1_mask, 1] -= cell_size_y # Apply wrapping - wrapped = tst.pbc_wrap_batched(test_positions, cell=state.cell, batch=state.batch) + wrapped = tst.pbc_wrap_batched( + test_positions, cell=state.cell, graph_idx=state.graph_idx + ) # Check all positions are within the cell boundaries - for b in range(2): # For each batch - batch_mask = state.batch == b + for b in range(2): # For each graph + graph_mask = state.graph_idx == b # Check x coordinates - assert torch.all(wrapped[batch_mask, 0] >= 0) - assert torch.all(wrapped[batch_mask, 0] < state.cell[b, 0, 0]) + assert torch.all(wrapped[graph_mask, 0] >= 0) + assert torch.all(wrapped[graph_mask, 0] < state.cell[b, 0, 0]) # Check y coordinates - assert torch.all(wrapped[batch_mask, 1] >= 0) - assert torch.all(wrapped[batch_mask, 1] < state.cell[b, 1, 1]) + assert torch.all(wrapped[graph_mask, 1] >= 0) + assert torch.all(wrapped[graph_mask, 1] < state.cell[b, 1, 1]) # Check z coordinates - assert torch.all(wrapped[batch_mask, 2] >= 0) - assert torch.all(wrapped[batch_mask, 2] < state.cell[b, 2, 2]) + assert torch.all(wrapped[graph_mask, 2] >= 0) + assert torch.all(wrapped[graph_mask, 2] < state.cell[b, 2, 2]) def test_pbc_wrap_batched_preserves_relative_positions( si_double_sim_state: ts.SimState, ) -> None: - """Test that relative positions within each batch are preserved after wrapping.""" + """Test that relative positions within each graph are preserved after wrapping.""" state = si_double_sim_state # Get a copy of positions @@ -443,20 +447,22 @@ def test_pbc_wrap_batched_preserves_relative_positions( test_positions += torch.tensor([10.0, 15.0, 20.0], device=state.device) # Apply wrapping - wrapped = tst.pbc_wrap_batched(test_positions, cell=state.cell, batch=state.batch) + wrapped = tst.pbc_wrap_batched( + test_positions, cell=state.cell, graph_idx=state.graph_idx + ) - # Check that relative positions within each batch are preserved + # Check that relative positions within each graph are preserved for b in range(2): # For each batch - batch_mask = state.batch == b + graph_idx_mask = state.graph_idx == b # Calculate pairwise distances before wrapping - atoms_in_batch = torch.sum(batch_mask).item() + atoms_in_batch = torch.sum(graph_idx_mask).item() for n_atoms in range(atoms_in_batch - 1): for j in range(n_atoms + 1, atoms_in_batch): # Get the indices of atoms i and j in this batch - batch_indices = torch.where(batch_mask)[0] - idx_i = batch_indices[n_atoms] - idx_j = batch_indices[j] + graph_indices = torch.where(graph_idx_mask)[0] + idx_i = graph_indices[n_atoms] + idx_j = graph_indices[j] # Original vector from i to j orig_vec = ( @@ -839,11 +845,11 @@ def test_get_fractional_coordinates_batched() -> None: [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]], device=device, dtype=dtype ) - # Test single batch case (should work) - cell_single_batch = torch.tensor( + # Test single graph case (should work) + cell_single_graph = torch.tensor( [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]], device=device, dtype=dtype ) - frac_batched = tst.get_fractional_coordinates(positions, cell_single_batch) + frac_batched = tst.get_fractional_coordinates(positions, cell_single_graph) # Compare with 2D case cell_2d = torch.tensor( @@ -852,11 +858,11 @@ def test_get_fractional_coordinates_batched() -> None: frac_2d = tst.get_fractional_coordinates(positions, cell_2d) assert torch.allclose(frac_batched, frac_2d), ( - "Single batch case should produce same result as 2D case" + "Single graph case should produce same result as 2D case" ) - # Test multi-batch case (should raise NotImplementedError) - cell_multi_batch = torch.tensor( + # Test multi-graph case (should raise NotImplementedError) + cell_multi_graph = torch.tensor( [ [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]], @@ -865,8 +871,8 @@ def test_get_fractional_coordinates_batched() -> None: dtype=dtype, ) - with pytest.raises(NotImplementedError, match="Multiple batched cell tensors"): - tst.get_fractional_coordinates(positions, cell_multi_batch) + with pytest.raises(NotImplementedError, match="Multiple graph cell tensors"): + tst.get_fractional_coordinates(positions, cell_multi_graph) @pytest.mark.parametrize( diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index d436076a..cb05f6b8 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -234,7 +234,7 @@ def measure_model_memory_forward(state: SimState, model: ModelInterface) -> floa print( # noqa: T201 "Model Memory Estimation: Running forward pass on state with " - f"{state.n_atoms} atoms and {state.n_batches} batches.", + f"{state.n_atoms} atoms and {state.n_graphs} graphs.", ) # Clear GPU memory torch.cuda.synchronize() @@ -293,8 +293,8 @@ def determine_max_batch_size( sizes.append(next_size) for i in range(len(sizes)): - n_batches = sizes[i] - concat_state = concatenate_states([state] * n_batches) + n_graphs = sizes[i] + concat_state = concatenate_states([state] * n_graphs) try: measure_model_memory_forward(concat_state, model) @@ -343,7 +343,7 @@ def calculate_memory_scaler( # Calculate memory scaling factor based on atom count and density metric = calculate_memory_scaler(state, memory_scales_with="n_atoms_x_density") """ - if state.n_batches > 1: + if state.n_graphs > 1: return sum(calculate_memory_scaler(s, memory_scales_with) for s in state.split()) if memory_scales_with == "n_atoms": return state.n_atoms @@ -405,8 +405,8 @@ def estimate_max_memory_scaler( print( # noqa: T201 "Model Memory Estimation: Estimating memory from worst case of " f"largest and smallest system. Largest system has {max_state.n_atoms} atoms " - f"and {max_state.n_batches} batches, and smallest system has " - f"{min_state.n_atoms} atoms and {min_state.n_batches} batches.", + f"and {max_state.n_graphs} batches, and smallest system has " + f"{min_state.n_atoms} atoms and {min_state.n_graphs} batches.", ) min_state_max_batches = determine_max_batch_size(min_state, model, **kwargs) max_state_max_batches = determine_max_batch_size(max_state, model, **kwargs) @@ -428,7 +428,7 @@ class BinningAutoBatcher: Attributes: model (ModelInterface): Model used for memory estimation and processing. memory_scales_with (str): Metric type used for memory estimation. - max_memory_scaler (float): Maximum memory metric allowed per batch. + max_memory_scaler (float): Maximum memory metric allowed per graph. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. return_indices (bool): Whether to return original indices with batches. state_slices (list[SimState]): Individual states to be batched. @@ -477,7 +477,7 @@ def __init__( - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density Defaults to "n_atoms_x_density". - max_memory_scaler (float | None): Maximum metric value allowed per batch. If + max_memory_scaler (float | None): Maximum metric value allowed per graph. If None, will be automatically estimated. Defaults to None. return_indices (bool): Whether to return original indices along with batches. Defaults to False. @@ -716,7 +716,7 @@ class InFlightAutoBatcher: Attributes: model (ModelInterface): Model used for memory estimation and processing. memory_scales_with (str): Metric type used for memory estimation. - max_memory_scaler (float): Maximum memory metric allowed per batch. + max_memory_scaler (float): Maximum memory metric allowed per graph. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. return_indices (bool): Whether to return original indices with batches. max_iterations (int | None): Maximum number of iterations per state. @@ -776,7 +776,7 @@ def __init__( - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density Defaults to "n_atoms_x_density". - max_memory_scaler (float | None): Maximum metric value allowed per batch. + max_memory_scaler (float | None): Maximum metric value allowed per graph. If None, will be automatically estimated. Defaults to None. return_indices (bool): Whether to return original indices along with batches. Defaults to False. @@ -933,13 +933,13 @@ def _get_first_batch(self) -> SimState: # if max_metric is not set, estimate it has_max_metric = bool(self.max_memory_scaler) if not has_max_metric: - n_batches = determine_max_batch_size( + n_graphs = determine_max_batch_size( first_state, self.model, max_atoms=self.max_atoms_to_try, scale_factor=self.memory_scaling_factor, ) - self.max_memory_scaler = n_batches * first_metric * 0.8 + self.max_memory_scaler = n_graphs * first_metric * 0.8 states = self._get_next_states() @@ -971,7 +971,7 @@ def next_batch( # noqa: C901 for the first call. Contains shape information specific to the SimState instance. convergence_tensor (torch.Tensor | None): Boolean tensor with shape - [n_batches] indicating which states have converged (True) or not + [n_graphs] indicating which states have converged (True) or not (False). Should be None only for the first call. Returns: @@ -1019,14 +1019,14 @@ def next_batch( # noqa: C901 # assert statements helpful for debugging, should be moved to validate fn # the first two are most important - if len(convergence_tensor) != updated_state.n_batches: - raise ValueError(f"{len(convergence_tensor)=} != {updated_state.n_batches=}") + if len(convergence_tensor) != updated_state.n_graphs: + raise ValueError(f"{len(convergence_tensor)=} != {updated_state.n_graphs=}") if len(self.current_idx) != len(self.current_scalers): raise ValueError(f"{len(self.current_idx)=} != {len(self.current_scalers)=}") if len(convergence_tensor.shape) != 1: raise ValueError(f"{len(convergence_tensor.shape)=} != 1") - if updated_state.n_batches <= 0: - raise ValueError(f"{updated_state.n_batches=} <= 0") + if updated_state.n_graphs <= 0: + raise ValueError(f"{updated_state.n_graphs=} <= 0") # Increment attempt counters and check for max attempts in a single loop for cur_idx, abs_idx in enumerate(self.current_idx): @@ -1057,7 +1057,7 @@ def next_batch( # noqa: C901 ) # concatenate remaining state with next states - if updated_state.n_batches > 0: + if updated_state.n_graphs > 0: next_states = [updated_state, *next_states] next_batch = concatenate_states(next_states) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 8640d965..7d5bf56b 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -21,17 +21,17 @@ class MDState(SimState): Attributes: positions (torch.Tensor): Particle positions [n_particles, n_dim] momenta (torch.Tensor): Particle momenta [n_particles, n_dim] - energy (torch.Tensor): Total energy of the system [n_batches] + energy (torch.Tensor): Total energy of the system [n_graphs] forces (torch.Tensor): Forces on particles [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] - cell (torch.Tensor): Simulation cell matrix [n_batches, n_dim, n_dim] + cell (torch.Tensor): Simulation cell matrix [n_graphs, n_dim, n_dim] pbc (bool): Whether to use periodic boundary conditions - batch (torch.Tensor): Batch indices [n_particles] + graph_idx (torch.Tensor): Graph indices [n_particles] atomic_numbers (torch.Tensor): Atomic numbers [n_particles] Properties: velocities (torch.Tensor): Particle velocities [n_particles, n_dim] - n_batches (int): Number of independent systems in the batch + n_graphs (int): Number of independent systems in the batch device (torch.device): Device on which tensors are stored dtype (torch.dtype): Data type of tensors """ @@ -51,7 +51,7 @@ def velocities(self) -> torch.Tensor: def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, - batch: torch.Tensor, + graph_idx: torch.Tensor, kT: torch.Tensor | float, seed: int | None = None, ) -> torch.Tensor: @@ -64,8 +64,8 @@ def calculate_momenta( Args: positions (torch.Tensor): Particle positions [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] - batch (torch.Tensor): Batch indices [n_particles] - kT (torch.Tensor): Temperature in energy units [n_batches] + graph_idx (torch.Tensor): Graph indices [n_particles] + kT (torch.Tensor): Temperature in energy units [n_graphs] seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: @@ -79,32 +79,32 @@ def calculate_momenta( generator.manual_seed(seed) if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: - # kT is a tensor with shape (n_batches,) - kT = kT[batch] + # kT is a tensor with shape (n_graphs,) + kT = kT[graph_idx] # Generate random momenta from normal distribution momenta = torch.randn( positions.shape, device=device, dtype=dtype, generator=generator ) * torch.sqrt(masses * kT).unsqueeze(-1) - batchwise_momenta = torch.zeros( - (batch[-1] + 1, momenta.shape[1]), device=device, dtype=dtype + graphwise_momenta = torch.zeros( + (graph_idx[-1] + 1, momenta.shape[1]), device=device, dtype=dtype ) - # create 3 copies of batch - batch_3 = batch.view(-1, 1).repeat(1, 3) - bincount = torch.bincount(batch) + # create 3 copies of graph_idx + graph_idx_3 = graph_idx.view(-1, 1).repeat(1, 3) + bincount = torch.bincount(graph_idx) mean_momenta = torch.scatter_reduce( - batchwise_momenta, + graphwise_momenta, dim=0, - index=batch_3, + index=graph_idx_3, src=momenta, reduce="sum", ) / bincount.view(-1, 1) return torch.where( torch.repeat_interleave(bincount > 1, bincount).view(-1, 1), - momenta - mean_momenta[batch], + momenta - mean_momenta[graph_idx], momenta, ) @@ -118,7 +118,7 @@ def momentum_step(state: MDState, dt: torch.Tensor) -> MDState: Args: state (MDState): Current system state containing forces and momenta - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_graphs] Returns: MDState: Updated state with new momenta after force application @@ -138,7 +138,7 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: Args: state (MDState): Current system state containing positions and velocities - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_graphs] Returns: MDState: Updated state with new positions after propagation @@ -147,9 +147,9 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: new_positions = state.positions + state.velocities * dt if state.pbc: - # Split positions and cells by batch + # Split positions and cells by graph new_positions = transforms.pbc_wrap_batched( - new_positions, state.cell, state.batch + new_positions, state.cell, state.graph_idx ) state.positions = new_positions diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index eb7b1f4d..9d1ea5c7 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -31,25 +31,25 @@ class NPTLangevinState(SimState): Attributes: positions (torch.Tensor): Particle positions [n_particles, n_dim] velocities (torch.Tensor): Particle velocities [n_particles, n_dim] - energy (torch.Tensor): Energy of the system [n_batches] + energy (torch.Tensor): Energy of the system [n_graphs] forces (torch.Tensor): Forces on particles [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] - cell (torch.Tensor): Simulation cell matrix [n_batches, n_dim, n_dim] + cell (torch.Tensor): Simulation cell matrix [n_graphs, n_dim, n_dim] pbc (bool): Whether to use periodic boundary conditions - batch (torch.Tensor): Batch indices [n_particles] + graph_idx (torch.Tensor): Graph indices [n_particles] atomic_numbers (torch.Tensor): Atomic numbers [n_particles] - stress (torch.Tensor): Stress tensor [n_batches, n_dim, n_dim] + stress (torch.Tensor): Stress tensor [n_graphs, n_dim, n_dim] reference_cell (torch.Tensor): Original cell vectors used as reference for - scaling [n_batches, n_dim, n_dim] - cell_positions (torch.Tensor): Cell positions [n_batches, n_dim, n_dim] - cell_velocities (torch.Tensor): Cell velocities [n_batches, n_dim, n_dim] + scaling [n_graphs, n_dim, n_dim] + cell_positions (torch.Tensor): Cell positions [n_graphs, n_dim, n_dim] + cell_velocities (torch.Tensor): Cell velocities [n_graphs, n_dim, n_dim] cell_masses (torch.Tensor): Masses associated with the cell degrees of freedom - shape [n_batches] + shape [n_graphs] Properties: momenta (torch.Tensor): Particle momenta calculated as velocities*masses with shape [n_particles, n_dimensions] - n_batches (int): Number of independent systems in the batch + n_graphs (int): Number of independent systems in the batch device (torch.device): Device on which tensors are stored dtype (torch.dtype): Data type of tensors """ @@ -88,12 +88,12 @@ def _compute_cell_force( Args: state (NPTLangevinState): Current NPT state external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_batches, n_dimensions, n_dimensions] + tensor with shape [n_graphs, n_dimensions, n_dimensions] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] + shape [n_graphs] Returns: - torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim] + torch.Tensor: Force acting on the cell [n_graphs, n_dim, n_dim] """ # Convert external_pressure to tensor if it's not already one if not isinstance(external_pressure, torch.Tensor): @@ -106,10 +106,10 @@ def _compute_cell_force( kT = torch.tensor(kT, device=state.device, dtype=state.dtype) # Get current volumes for each batch - volumes = torch.linalg.det(state.cell) # shape: (n_batches,) + volumes = torch.linalg.det(state.cell) # shape: (n_graphs,) # Reshape for broadcasting - volumes = volumes.view(-1, 1, 1) # shape: (n_batches, 1, 1) + volumes = volumes.view(-1, 1, 1) # shape: (n_graphs, 1, 1) # Create pressure tensor (diagonal with external pressure) if external_pressure.ndim == 0: @@ -117,9 +117,9 @@ def _compute_cell_force( pressure_tensor = external_pressure * torch.eye( 3, device=state.device, dtype=state.dtype ) - pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_batches, -1, -1) + pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_graphs, -1, -1) else: - # Already a tensor with shape compatible with n_batches + # Already a tensor with shape compatible with n_graphs pressure_tensor = external_pressure # Calculate virials from stress and external pressure @@ -129,14 +129,14 @@ def _compute_cell_force( # Add kinetic contribution (kT * Identity) batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) + batch_kT = kT.expand(state.n_graphs) e_kin_per_atom = batch_kT.view(-1, 1, 1) * torch.eye( 3, device=state.device, dtype=state.dtype ).unsqueeze(0) - # Correct implementation with scaling by n_atoms_per_batch - return virial + e_kin_per_atom * state.n_atoms_per_batch.view(-1, 1, 1) + # Correct implementation with scaling by n_atoms_per_graph + return virial + e_kin_per_atom * state.n_atoms_per_graph.view(-1, 1, 1) def npt_langevin( # noqa: C901, PLR0915 @@ -164,18 +164,18 @@ def npt_langevin( # noqa: C901, PLR0915 Args: model (torch.nn.Module): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_graphs] external_pressure (torch.Tensor): Target pressure to maintain, either scalar - or shape [n_batches, n_dim, n_dim] for anisotropic pressure + or shape [n_graphs, n_dim, n_dim] for anisotropic pressure alpha (torch.Tensor, optional): Friction coefficient for particle Langevin - thermostat, either scalar or shape [n_batches]. Defaults to 1/(100*dt). + thermostat, either scalar or shape [n_graphs]. Defaults to 1/(100*dt). cell_alpha (torch.Tensor, optional): Friction coefficient for cell Langevin - thermostat, either scalar or shape [n_batches]. Defaults to same as alpha. + thermostat, either scalar or shape [n_graphs]. Defaults to same as alpha. b_tau (torch.Tensor, optional): Barostat time constant controlling how quickly the system responds to pressure differences, either scalar or shape - [n_batches]. Defaults to 1/(1000*dt). + [n_graphs]. Defaults to 1/(1000*dt). seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: @@ -229,24 +229,24 @@ def beta( Args: state (NPTLangevinState): Current NPT state alpha (torch.Tensor): Friction coefficient, either scalar or - shape [n_batches] + shape [n_graphs] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + shape [n_graphs] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] Returns: torch.Tensor: Random noise term for force calculation [n_particles, n_dim] """ - # Generate batch-specific noise with correct shape + # Generate graph-specific noise with correct shape noise = torch.randn_like(state.velocities) - # Calculate the thermal noise amplitude by batch + # Calculate the thermal noise amplitude by graph batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) + batch_kT = kT.expand(state.n_graphs) - # Map batch kT to atoms - atom_kT = batch_kT[state.batch] + # Map graph kT to atoms + atom_kT = batch_kT[state.graph_idx] # Calculate the prefactor for each atom # The standard deviation should be sqrt(2*alpha*kB*T*dt) @@ -269,29 +269,29 @@ def cell_beta( Args: state (NPTLangevinState): Current NPT state cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_batches] + with shape [n_graphs] kT (torch.Tensor): System temperature in energy units, either scalar or - with shape [n_batches] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + with shape [n_graphs] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] Returns: torch.Tensor: Scaled random noise for cell dynamics with shape - [n_batches, n_dimensions, n_dimensions] + [n_graphs, n_dimensions, n_dimensions] """ # Generate standard normal distribution (zero mean, unit variance) noise = torch.randn_like(state.cell_positions, device=device, dtype=dtype) # Ensure cell_alpha and kT have batch dimension if they're scalars if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_batches) + cell_alpha = cell_alpha.expand(state.n_graphs) if kT.ndim == 0: - kT = kT.expand(state.n_batches) + kT = kT.expand(state.n_graphs) # Reshape for broadcasting - cell_alpha = cell_alpha.view(-1, 1, 1) # shape: (n_batches, 1, 1) - kT = kT.view(-1, 1, 1) # shape: (n_batches, 1, 1) + cell_alpha = cell_alpha.view(-1, 1, 1) # shape: (n_graphs, 1, 1) + kT = kT.view(-1, 1, 1) # shape: (n_graphs, 1, 1) if dt.ndim == 0: - dt = dt.expand(state.n_batches).view(-1, 1, 1) + dt = dt.expand(state.n_graphs).view(-1, 1, 1) else: dt = dt.view(-1, 1, 1) @@ -316,12 +316,12 @@ def compute_cell_force( Args: state (NPTLangevinState): Current NPT state external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_batches, n_dimensions, n_dimensions] + tensor with shape [n_graphs, n_dimensions, n_dimensions] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] + shape [n_graphs] Returns: - torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim] + torch.Tensor: Force acting on the cell [n_graphs, n_dim, n_dim] """ return _compute_cell_force(state, external_pressure, kT) @@ -340,25 +340,25 @@ def cell_position_step( Args: state (NPTLangevinState): Current NPT state - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] pressure_force (torch.Tensor): Pressure force for barostat - [n_batches, n_dim, n_dim] + [n_graphs, n_dim, n_dim] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_graphs] cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_batches] + with shape [n_graphs] Returns: NPTLangevinState: Updated state with new cell positions """ # Calculate effective mass term - Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_batches, 1, 1) + Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_graphs, 1, 1) # Ensure parameters have batch dimension if dt.ndim == 0: - dt = dt.expand(state.n_batches) + dt = dt.expand(state.n_graphs) if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_batches) + cell_alpha = cell_alpha.expand(state.n_graphs) # Reshape for broadcasting dt_expanded = dt.view(-1, 1, 1) @@ -403,34 +403,32 @@ def cell_velocity_step( Args: state (NPTLangevinState): Current NPT state F_p_n (torch.Tensor): Initial pressure force with shape - [n_batches, n_dimensions, n_dimensions] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + [n_graphs, n_dimensions, n_dimensions] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] pressure_force (torch.Tensor): Final pressure force - shape [n_batches, n_dim, n_dim] + shape [n_graphs, n_dim, n_dim] cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_batches] + shape [n_graphs] kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_batches] + shape [n_graphs] Returns: NPTLangevinState: Updated state with new cell velocities """ # Ensure parameters have batch dimension if dt.ndim == 0: - dt = dt.expand(state.n_batches) + dt = dt.expand(state.n_graphs) if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_batches) + cell_alpha = cell_alpha.expand(state.n_graphs) if kT.ndim == 0: - kT = kT.expand(state.n_batches) + kT = kT.expand(state.n_graphs) # Reshape for broadcasting - need to maintain 3x3 dimensions - dt_expanded = dt.view(-1, 1, 1) # shape: (n_batches, 1, 1) - cell_alpha_expanded = cell_alpha.view(-1, 1, 1) # shape: (n_batches, 1, 1) + dt_expanded = dt.view(-1, 1, 1) # shape: (n_graphs, 1, 1) + cell_alpha_expanded = cell_alpha.view(-1, 1, 1) # shape: (n_graphs, 1, 1) - # Calculate cell masses per batch - reshape to match 3x3 cell matrices - cell_masses_expanded = state.cell_masses.view( - -1, 1, 1 - ) # shape: (n_batches, 1, 1) + # Calculate cell masses per graph - reshape to match 3x3 cell matrices + cell_masses_expanded = state.cell_masses.view(-1, 1, 1) # shape: (n_graphs, 1, 1) # These factors come from the Langevin integration scheme a = (1 - (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) / ( @@ -439,13 +437,13 @@ def cell_velocity_step( b = 1 / (1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) # Calculate the three terms for velocity update - # a will broadcast from (n_batches, 1, 1) to (n_batches, 3, 3) + # a will broadcast from (n_graphs, 1, 1) to (n_graphs, 3, 3) c_1 = a * state.cell_velocities # Damped old velocity # Force contribution (average of initial and final forces) c_2 = dt_expanded * ((a * F_p_n) + pressure_force) / (2 * cell_masses_expanded) - # Generate batch-specific cell noise with correct shape (n_batches, 3, 3) + # Generate graph-specific cell noise with correct shape (n_graphs, 3, 3) cell_noise = torch.randn_like(state.cell_velocities) # Calculate thermal noise amplitude @@ -463,7 +461,7 @@ def cell_velocity_step( def langevin_position_step( state: NPTLangevinState, - L_n: torch.Tensor, # This should be shape (n_batches,) + L_n: torch.Tensor, # This should be shape (n_graphs,) dt: torch.Tensor, kT: torch.Tensor, ) -> NPTLangevinState: @@ -476,42 +474,40 @@ def langevin_position_step( Args: state (NPTLangevinState): Current NPT state - L_n (torch.Tensor): Previous cell length scale with shape [n_batches] - dt: Integration timestep, either scalar or with shape [n_batches] + L_n (torch.Tensor): Previous cell length scale with shape [n_graphs] + dt: Integration timestep, either scalar or with shape [n_graphs] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_graphs] Returns: NPTLangevinState: Updated state with new positions """ - # Calculate effective mass term by batch + # Calculate effective mass term by graph # Map masses to have batch dimension M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) # Calculate new cell length scale (cube root of volume for isotropic scaling) L_n_new = torch.pow( - state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 - ) # shape: (n_batches,) + state.cell_positions.reshape(state.n_graphs, -1)[:, 0], 1 / 3 + ) # shape: (n_graphs,) - # Map batch-specific L_n and L_n_new to atom-level using batch indices - # Make sure L_n is the right shape (n_batches,) before indexing - if L_n.ndim != 1 or L_n.shape[0] != state.n_batches: + # Map graph-specific L_n and L_n_new to atom-level using graph indices + # Make sure L_n is the right shape (n_graphs,) before indexing + if L_n.ndim != 1 or L_n.shape[0] != state.n_graphs: # If L_n has wrong shape, calculate it again to ensure correct shape - L_n = torch.pow( - state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 - ) + L_n = torch.pow(state.cell_positions.reshape(state.n_graphs, -1)[:, 0], 1 / 3) - # Map batch values to atoms using batch indices - L_n_atoms = L_n[state.batch] # shape: (n_atoms,) - L_n_new_atoms = L_n_new[state.batch] # shape: (n_atoms,) + # Map graph-specific values to atoms using graph indices + L_n_atoms = L_n[state.graph_idx] # shape: (n_atoms,) + L_n_new_atoms = L_n_new[state.graph_idx] # shape: (n_atoms,) # Calculate damping factor alpha_atoms = alpha if alpha.ndim > 0: - alpha_atoms = alpha[state.batch] + alpha_atoms = alpha[state.graph_idx] dt_atoms = dt if dt.ndim > 0: - dt_atoms = dt[state.batch] + dt_atoms = dt[state.graph_idx] b = 1 / (1 + ((alpha_atoms * dt_atoms) / M_2)) @@ -529,8 +525,8 @@ def langevin_position_step( noise = torch.randn_like(state.velocities) batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) - atom_kT = batch_kT[state.batch] + batch_kT = kT.expand(state.n_graphs) + atom_kT = batch_kT[state.graph_idx] # Calculate noise prefactor according to fluctuation-dissipation theorem noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) @@ -549,7 +545,7 @@ def langevin_position_step( # Apply periodic boundary conditions if needed if state.pbc: state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.batch + state.positions, state.cell, state.graph_idx ) return state @@ -569,9 +565,9 @@ def langevin_velocity_step( Args: state (NPTLangevinState): Current NPT state forces: Forces on particles - dt: Integration timestep, either scalar or with shape [n_batches] + dt: Integration timestep, either scalar or with shape [n_graphs] kT: Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_graphs] Returns: NPTLangevinState: Updated state with new velocities @@ -582,10 +578,10 @@ def langevin_velocity_step( # Map batch parameters to atom level alpha_atoms = alpha if alpha.ndim > 0: - alpha_atoms = alpha[state.batch] + alpha_atoms = alpha[state.graph_idx] dt_atoms = dt if dt.ndim > 0: - dt_atoms = dt[state.batch] + dt_atoms = dt[state.graph_idx] # Calculate damping factors for Langevin integration a = (1 - (alpha_atoms * dt_atoms) / M_2) / (1 + (alpha_atoms * dt_atoms) / M_2) @@ -601,8 +597,8 @@ def langevin_velocity_step( noise = torch.randn_like(state.velocities) batch_kT = kT if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) - atom_kT = batch_kT[state.batch] + batch_kT = kT.expand(state.n_graphs) + atom_kT = batch_kT[state.graph_idx] # Calculate noise prefactor according to fluctuation-dissipation theorem noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) @@ -647,7 +643,7 @@ def npt_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.graph_idx, kT, seed), ) # Initialize cell parameters @@ -656,20 +652,20 @@ def npt_init( # Calculate initial cell_positions (volume) cell_positions = ( torch.linalg.det(state.cell).unsqueeze(-1).unsqueeze(-1) - ) # shape: (n_batches, 1, 1) + ) # shape: (n_graphs, 1, 1) # Initialize cell velocities to zero - cell_velocities = torch.zeros((state.n_batches, 3, 3), device=device, dtype=dtype) + cell_velocities = torch.zeros((state.n_graphs, 3, 3), device=device, dtype=dtype) # Calculate cell masses based on system size and temperature # This follows standard NPT barostat mass scaling - n_atoms_per_batch = torch.bincount(state.batch) + n_atoms_per_graph = torch.bincount(state.graph_idx) batch_kT = ( - kT.expand(state.n_batches) + kT.expand(state.n_graphs) if isinstance(kT, torch.Tensor) and kT.ndim == 0 else kT ) - cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau + cell_masses = (n_atoms_per_graph + 1) * batch_kT * b_tau * b_tau # Create the initial state return NPTLangevinState( @@ -681,7 +677,7 @@ def npt_init( masses=state.masses, cell=state.cell, pbc=state.pbc, - batch=state.batch, + graph_idx=state.graph_idx, atomic_numbers=state.atomic_numbers, reference_cell=reference_cell, cell_positions=cell_positions, @@ -706,15 +702,15 @@ def npt_update( Args: state (NPTLangevinState): Current NPT state with particle and cell variables - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] kT (torch.Tensor): Target temperature in energy units, either scalar or - shape [n_batches] + shape [n_graphs] external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_batches, n_dim, n_dim] + tensor with shape [n_graphs, n_dim, n_dim] alpha (torch.Tensor): Position friction coefficient, either scalar or - shape [n_batches] + shape [n_graphs] cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_batches] + shape [n_graphs] Returns: NPTLangevinState: Updated NPT state after one timestep with new positions, @@ -731,12 +727,12 @@ def npt_update( dt = torch.tensor(dt, device=device, dtype=dtype) # Make sure parameters have batch dimension if they're scalars - batch_kT = kT.expand(state.n_batches) if kT.ndim == 0 else kT + batch_kT = kT.expand(state.n_graphs) if kT.ndim == 0 else kT # Update barostat mass based on current temperature # This ensures proper coupling between system and barostat - n_atoms_per_batch = torch.bincount(state.batch) - state.cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau + n_atoms_per_graph = torch.bincount(state.graph_idx) + state.cell_masses = (n_atoms_per_graph + 1) * batch_kT * b_tau * b_tau # Compute model output for current state model_output = model(state) @@ -749,24 +745,24 @@ def npt_update( state=state, external_pressure=external_pressure, kT=kT ) L_n = torch.pow( - state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 - ) # shape: (n_batches,) + state.cell_positions.reshape(state.n_graphs, -1)[:, 0], 1 / 3 + ) # shape: (n_graphs,) # Step 1: Update cell position state = cell_position_step(state=state, dt=dt, pressure_force=F_p_n, kT=kT) # Update cell (currently only isotropic fluctuations) dim = state.positions.shape[1] # Usually 3 for 3D - # V_0 and V are shape: (n_batches,) + # V_0 and V are shape: (n_graphs,) V_0 = torch.linalg.det(state.reference_cell) - V = state.cell_positions.reshape(state.n_batches, -1)[:, 0] + V = state.cell_positions.reshape(state.n_graphs, -1)[:, 0] # Scale cell uniformly in all dimensions - scaling = (V / V_0) ** (1.0 / dim) # shape: (n_batches,) + scaling = (V / V_0) ** (1.0 / dim) # shape: (n_graphs,) # Apply scaling to reference cell to get new cell new_cell = torch.zeros_like(state.cell) - for b in range(state.n_batches): + for b in range(state.n_graphs): new_cell[b] = scaling[b] * state.reference_cell[b] state.cell = new_cell @@ -822,14 +818,14 @@ class NPTNoseHooverState(MDState): forces (torch.Tensor): Forces on particles with shape [n_particles, n_dims] masses (torch.Tensor): Particle masses with shape [n_particles] reference_cell (torch.Tensor): Reference simulation cell matrix with shape - [n_batches, n_dimensions, n_dimensions]. Used to measure relative volume + [n_graphs, n_dimensions, n_dimensions]. Used to measure relative volume changes. - cell_position (torch.Tensor): Logarithmic cell coordinate with shape [n_batches]. + cell_position (torch.Tensor): Logarithmic cell coordinate with shape [n_graphs]. Represents (1/d)ln(V/V_0) where V is current volume and V_0 is reference volume. cell_momentum (torch.Tensor): Cell momentum (velocity) conjugate to cell_position - with shape [n_batches]. Controls volume changes. - cell_mass (torch.Tensor): Mass parameter for cell dynamics with shape [n_batches]. + with shape [n_graphs]. Controls volume changes. + cell_mass (torch.Tensor): Mass parameter for cell dynamics with shape [n_graphs]. Controls coupling between volume fluctuations and pressure. barostat (NoseHooverChain): Chain thermostat coupled to cell dynamics for pressure control @@ -842,7 +838,7 @@ class NPTNoseHooverState(MDState): velocities (torch.Tensor): Particle velocities computed as momenta divided by masses. Shape: [n_particles, n_dimensions] current_cell (torch.Tensor): Current simulation cell matrix derived from - cell_position. Shape: [n_batches, n_dimensions, n_dimensions] + cell_position. Shape: [n_graphs, n_dimensions, n_dimensions] Notes: - The cell parameterization ensures volume positivity @@ -853,10 +849,10 @@ class NPTNoseHooverState(MDState): """ # Cell variables - now with batch dimensions - reference_cell: torch.Tensor # [n_batches, 3, 3] - cell_position: torch.Tensor # [n_batches] - cell_momentum: torch.Tensor # [n_batches] - cell_mass: torch.Tensor # [n_batches] + reference_cell: torch.Tensor # [n_graphs, 3, 3] + cell_position: torch.Tensor # [n_graphs] + cell_momentum: torch.Tensor # [n_graphs] + cell_mass: torch.Tensor # [n_graphs] # Thermostat variables thermostat: NoseHooverChain @@ -885,13 +881,13 @@ def current_cell(self) -> torch.Tensor: Returns: torch.Tensor: Current simulation cell matrix with shape - [n_batches, n_dimensions, n_dimensions] + [n_graphs, n_dimensions, n_dimensions] """ dim = self.positions.shape[1] - V_0 = torch.det(self.reference_cell) # [n_batches] - V = V_0 * torch.exp(dim * self.cell_position) # [n_batches] - scale = (V / V_0) ** (1.0 / dim) # [n_batches] - # Expand scale to [n_batches, 1, 1] for broadcasting + V_0 = torch.det(self.reference_cell) # [n_graphs] + V = V_0 * torch.exp(dim * self.cell_position) # [n_graphs] + scale = (V / V_0) ** (1.0 / dim) # [n_graphs] + # Expand scale to [n_graphs, 1, 1] for broadcasting scale = scale.unsqueeze(-1).unsqueeze(-1) return scale * self.reference_cell @@ -952,9 +948,9 @@ def _npt_cell_info( Returns: tuple: - - torch.Tensor: Current system volume with shape [n_batches] - - callable: Function that takes a volume tensor [n_batches] and returns - the corresponding cell matrix [n_batches, n_dimensions, n_dimensions] + - torch.Tensor: Current system volume with shape [n_graphs] + - callable: Function that takes a volume tensor [n_graphs] and returns + the corresponding cell matrix [n_graphs, n_dimensions, n_dimensions] Notes: - Uses logarithmic cell coordinate parameterization @@ -963,21 +959,21 @@ def _npt_cell_info( - Supports batched operations """ dim = state.positions.shape[1] - ref = state.reference_cell # [n_batches, dim, dim] - V_0 = torch.det(ref) # [n_batches] - Reference volume - V = V_0 * torch.exp(dim * state.cell_position) # [n_batches] - Current volume + ref = state.reference_cell # [n_graphs, dim, dim] + V_0 = torch.det(ref) # [n_graphs] - Reference volume + V = V_0 * torch.exp(dim * state.cell_position) # [n_graphs] - Current volume def volume_to_cell(V: torch.Tensor) -> torch.Tensor: """Compute cell matrix for given volumes. Args: - V (torch.Tensor): Volumes with shape [n_batches] + V (torch.Tensor): Volumes with shape [n_graphs] Returns: - torch.Tensor: Cell matrices with shape [n_batches, dim, dim] + torch.Tensor: Cell matrices with shape [n_graphs, dim, dim] """ - scale = (V / V_0) ** (1.0 / dim) # [n_batches] - # Expand scale to [n_batches, 1, 1] for broadcasting + scale = (V / V_0) ** (1.0 / dim) # [n_graphs] + # Expand scale to [n_graphs, 1, 1] for broadcasting scale = scale.unsqueeze(-1).unsqueeze(-1) return scale * ref @@ -996,7 +992,7 @@ def update_cell_mass( Args: state (NPTNoseHooverState): Current state of the NPT system kT (torch.Tensor): Target temperature in energy units, either scalar or - shape [n_batches] + shape [n_graphs] Returns: NPTNoseHooverState: Updated state with new cell mass @@ -1014,11 +1010,11 @@ def update_cell_mass( kT = torch.tensor(kT, device=device, dtype=dtype) # Handle both scalar and batched kT - kT_batch = kT.expand(state.n_batches) if kT.ndim == 0 else kT + kT_graph = kT.expand(state.n_graphs) if kT.ndim == 0 else kT - # Calculate cell masses for each batch - n_atoms_per_batch = torch.bincount(state.batch, minlength=state.n_batches) - cell_mass = dim * (n_atoms_per_batch + 1) * kT_batch * state.barostat.tau**2 + # Calculate cell masses for each graph + n_atoms_per_graph = torch.bincount(state.graph_idx, minlength=state.n_graphs) + cell_mass = dim * (n_atoms_per_graph + 1) * kT_graph * state.barostat.tau**2 # Update state with new cell masses state.cell_mass = cell_mass.to(device=device, dtype=dtype) @@ -1072,7 +1068,7 @@ def exp_iL1( # noqa: N802 Args: state (NPTNoseHooverState): Current simulation state velocities (torch.Tensor): Particle velocities [n_particles, n_dimensions] - cell_velocity (torch.Tensor): Cell velocity with shape [n_batches] + cell_velocity (torch.Tensor): Cell velocity with shape [n_graphs] dt (torch.Tensor): Integration timestep Returns: @@ -1083,10 +1079,10 @@ def exp_iL1( # noqa: N802 - Properly handles cell scaling through cell_velocity - Maintains time-reversibility of the integration scheme - Applies periodic boundary conditions if state.pbc is True - - Supports batched operations with proper atom-to-batch mapping + - Supports batched operations with proper atom-to-graph mapping """ - # Map batch-level cell velocities to atom level using batch indices - cell_velocity_atoms = cell_velocity[state.batch] # [n_atoms] + # Map graph-level cell velocities to atom level using graph indices + cell_velocity_atoms = cell_velocity[state.graph_idx] # [n_atoms] # Compute cell velocity terms per atom x = cell_velocity_atoms * dt # [n_atoms] @@ -1110,7 +1106,7 @@ def exp_iL1( # noqa: N802 # Apply periodic boundary conditions if needed if state.pbc: return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.batch + new_positions, state.current_cell, state.graph_idx ) return new_positions @@ -1137,7 +1133,7 @@ def exp_iL2( # noqa: N802 alpha (torch.Tensor): Cell scaling parameter momenta (torch.Tensor): Current particle momenta [n_particles, n_dimensions] forces (torch.Tensor): Forces on particles [n_particles, n_dimensions] - cell_velocity (torch.Tensor): Cell velocity with shape [n_batches] + cell_velocity (torch.Tensor): Cell velocity with shape [n_graphs] dt_2 (torch.Tensor): Half timestep (dt/2) Returns: @@ -1148,10 +1144,10 @@ def exp_iL2( # noqa: N802 - Properly handles cell velocity scaling effects - Maintains time-reversibility of the integration scheme - Part of the NPT integration algorithm - - Supports batched operations with proper atom-to-batch mapping + - Supports batched operations with proper atom-to-graph mapping """ - # Map batch-level cell velocities to atom level using batch indices - cell_velocity_atoms = cell_velocity[state.batch] # [n_atoms] + # Map graph-level cell velocities to atom level using graph indices + cell_velocity_atoms = cell_velocity[state.graph_idx] # [n_atoms] # Compute scaling terms per atom x = alpha * cell_velocity_atoms * dt_2 # [n_atoms] @@ -1178,7 +1174,7 @@ def compute_cell_force( masses: torch.Tensor, stress: torch.Tensor, external_pressure: torch.Tensor, - batch: torch.Tensor, + graph_idx: torch.Tensor, ) -> torch.Tensor: """Compute the force on the cell degree of freedom in NPT dynamics. @@ -1190,16 +1186,16 @@ def compute_cell_force( Args: alpha (torch.Tensor): Cell scaling parameter - volume (torch.Tensor): Current system volume with shape [n_batches] + volume (torch.Tensor): Current system volume with shape [n_graphs] positions (torch.Tensor): Particle positions [n_particles, n_dimensions] momenta (torch.Tensor): Particle momenta [n_particles, n_dimensions] masses (torch.Tensor): Particle masses [n_particles] - stress (torch.Tensor): Stress tensor [n_batches, n_dimensions, n_dimensions] + stress (torch.Tensor): Stress tensor [n_graphs, n_dimensions, n_dimensions] external_pressure (torch.Tensor): Target external pressure - batch (torch.Tensor): Batch indices for atoms [n_particles] + graph_idx (torch.Tensor): Graph indices for atoms [n_particles] Returns: - torch.Tensor: Force on the cell degree of freedom with shape [n_batches] + torch.Tensor: Force on the cell degree of freedom with shape [n_graphs] Notes: - Force drives volume changes to maintain target pressure @@ -1209,34 +1205,34 @@ def compute_cell_force( - Supports batched operations """ N, dim = positions.shape - n_batches = len(volume) + n_graphs = len(volume) - # Compute kinetic energy contribution per batch - # Split momenta and masses by batch - KE_per_batch = torch.zeros( - n_batches, device=positions.device, dtype=positions.dtype + # Compute kinetic energy contribution per graph + # Split momenta and masses by graph + KE_per_graph = torch.zeros( + n_graphs, device=positions.device, dtype=positions.dtype ) - for b in range(n_batches): - batch_mask = batch == b - if batch_mask.any(): - batch_momenta = momenta[batch_mask] - batch_masses = masses[batch_mask] - KE_per_batch[b] = calc_kinetic_energy(batch_momenta, batch_masses) - - # Get stress tensor and compute trace per batch + for b in range(n_graphs): + graph_mask = graph_idx == b + if graph_mask.any(): + graph_momenta = momenta[graph_mask] + graph_masses = masses[graph_mask] + KE_per_graph[b] = calc_kinetic_energy(graph_momenta, graph_masses) + + # Get stress tensor and compute trace per graph # Handle stress tensor with batch dimension if stress.ndim == 3: internal_pressure = torch.diagonal(stress, dim1=-2, dim2=-1).sum( dim=-1 - ) # [n_batches] + ) # [n_graphs] else: - # Single batch case - expand to batch dimension - internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_batches) + # Single graph case - expand to batch dimension + internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_graphs) - # Compute force on cell coordinate per batch + # Compute force on cell coordinate per graph # F = alpha * KE - dU/dV - P*V*d return ( - (alpha * KE_per_batch) + (alpha * KE_per_graph) - (internal_pressure * volume) - (external_pressure * volume * dim) ) @@ -1270,9 +1266,9 @@ def npt_inner_step( momenta = state.momenta masses = state.masses forces = state.forces - cell_position = state.cell_position # [n_batches] - cell_momentum = state.cell_momentum # [n_batches] - cell_mass = state.cell_mass # [n_batches] + cell_position = state.cell_position # [n_graphs] + cell_momentum = state.cell_momentum # [n_graphs] + cell_mass = state.cell_mass # [n_graphs] n_particles, dim = positions.shape @@ -1285,8 +1281,8 @@ def npt_inner_step( model_output = model(state) # First half step: Update momenta - n_atoms_per_batch = torch.bincount(state.batch, minlength=state.n_batches) - alpha = 1 + 1 / n_atoms_per_batch # [n_batches] + n_atoms_per_graph = torch.bincount(state.graph_idx, minlength=state.n_graphs) + alpha = 1 + 1 / n_atoms_per_graph # [n_graphs] cell_force_val = compute_cell_force( alpha=alpha, @@ -1296,7 +1292,7 @@ def npt_inner_step( masses=masses, stress=model_output["stress"], external_pressure=external_pressure, - batch=state.batch, + graph_idx=state.graph_idx, ) # Update cell momentum and particle momenta @@ -1331,7 +1327,7 @@ def npt_inner_step( masses=masses, stress=model_output["stress"], external_pressure=external_pressure, - batch=state.batch, + graph_idx=state.graph_idx, ) cell_momentum = cell_momentum + dt_2 * cell_force_val @@ -1410,32 +1406,32 @@ def npt_nose_hoover_init( state = SimState(**state) n_particles, dim = state.positions.shape - n_batches = state.n_batches + n_graphs = state.n_graphs atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) - # Initialize cell variables with proper batch dimensions - cell_position = torch.zeros(n_batches, device=device, dtype=dtype) - cell_momentum = torch.zeros(n_batches, device=device, dtype=dtype) + # Initialize cell variables with proper graph dimensions + cell_position = torch.zeros(n_graphs, device=device, dtype=dtype) + cell_momentum = torch.zeros(n_graphs, device=device, dtype=dtype) # Convert kT to tensor if it's not already one if not isinstance(kT, torch.Tensor): kT = torch.tensor(kT, device=device, dtype=dtype) # Handle both scalar and batched kT - kT_batch = kT.expand(n_batches) if kT.ndim == 0 else kT + kT_graph = kT.expand(n_graphs) if kT.ndim == 0 else kT - # Calculate cell masses for each batch - n_atoms_per_batch = torch.bincount(state.batch, minlength=n_batches) - cell_mass = dim * (n_atoms_per_batch + 1) * kT_batch * b_tau**2 + # Calculate cell masses for each graph + n_atoms_per_graph = torch.bincount(state.graph_idx, minlength=n_graphs) + cell_mass = dim * (n_atoms_per_graph + 1) * kT_graph * b_tau**2 cell_mass = cell_mass.to(device=device, dtype=dtype) - # Calculate cell kinetic energy (using first batch for initialization) + # Calculate cell kinetic energy (using first graph for initialization) KE_cell = calc_kinetic_energy(cell_momentum[:1], cell_mass[:1]) - # Ensure reference_cell has proper batch dimensions + # Ensure reference_cell has proper graph dimensions if state.cell.ndim == 2: # Single cell matrix - expand to batch dimension - reference_cell = state.cell.unsqueeze(0).expand(n_batches, -1, -1).clone() + reference_cell = state.cell.unsqueeze(0).expand(n_graphs, -1, -1).clone() else: # Already has batch dimension reference_cell = state.cell.clone() @@ -1445,7 +1441,7 @@ def npt_nose_hoover_init( state.cell, int | float ): cell_matrix = torch.eye(dim, device=device, dtype=dtype) * state.cell - reference_cell = cell_matrix.unsqueeze(0).expand(n_batches, -1, -1).clone() + reference_cell = cell_matrix.unsqueeze(0).expand(n_graphs, -1, -1).clone() state.cell = reference_cell # Get model output @@ -1463,7 +1459,7 @@ def npt_nose_hoover_init( atomic_numbers=atomic_numbers, cell=state.cell, pbc=state.pbc, - batch=state.batch, + graph_idx=state.graph_idx, reference_cell=reference_cell, cell_position=cell_position, cell_momentum=cell_momentum, @@ -1478,14 +1474,14 @@ def npt_nose_hoover_init( momenta = kwargs.get( "momenta", calculate_momenta( - npt_state.positions, npt_state.masses, npt_state.batch, kT, seed + npt_state.positions, npt_state.masses, npt_state.graph_idx, kT, seed ), ) # Initialize thermostat npt_state.momenta = momenta KE = calc_kinetic_energy( - npt_state.momenta, npt_state.masses, batch=npt_state.batch + npt_state.momenta, npt_state.masses, graph_idx=npt_state.graph_idx ) npt_state.thermostat = thermostat_fns.initialize( npt_state.positions.numel(), KE, kT @@ -1542,7 +1538,7 @@ def npt_nose_hoover_update( ) # Update kinetic energies for thermostats - KE = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + KE = calc_kinetic_energy(state.momenta, state.masses, graph_idx=state.graph_idx) state.thermostat.kinetic_energy = KE KE_cell = calc_kinetic_energy(state.cell_momentum, state.cell_mass) @@ -1588,44 +1584,46 @@ def npt_nose_hoover_invariant( Returns: torch.Tensor: The conserved quantity (extended Hamiltonian) of the NPT system. - Returns a scalar for single batch or tensor with shape [n_batches] for - multiple batches. + Returns a scalar for a single graph or tensor with shape [n_graphs] for + multiple graphs. """ # Calculate volume and potential energy - volume = torch.det(state.current_cell) # [n_batches] - e_pot = state.energy # Should be scalar or [n_batches] + volume = torch.det(state.current_cell) # [n_graphs] + e_pot = state.energy # Should be scalar or [n_graphs] - # Calculate kinetic energy of particles per batch - e_kin_per_batch = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + # Calculate kinetic energy of particles per graph + e_kin_per_graph = calc_kinetic_energy( + state.momenta, state.masses, graph_idx=state.graph_idx + ) - # Calculate degrees of freedom per batch - n_atoms_per_batch = torch.bincount(state.batch) - DOF_per_batch = ( - n_atoms_per_batch * state.positions.shape[-1] + # Calculate degrees of freedom per graph + n_atoms_per_graph = torch.bincount(state.graph_idx) + DOF_per_graph = ( + n_atoms_per_graph * state.positions.shape[-1] ) # n_atoms * n_dimensions # Initialize total energy with PE + KE if isinstance(e_pot, torch.Tensor) and e_pot.ndim > 0: - e_tot = e_pot + e_kin_per_batch # [n_batches] + e_tot = e_pot + e_kin_per_graph # [n_graphs] else: - e_tot = e_pot + e_kin_per_batch # [n_batches] + e_tot = e_pot + e_kin_per_graph # [n_graphs] # Add thermostat chain contributions - # Note: These are global thermostat variables, so we add them to each batch + # Note: These are global thermostat variables, so we add them to each graph # Start thermostat_energy as a tensor with the right shape thermostat_energy = torch.zeros_like(e_tot) thermostat_energy += (state.thermostat.momenta[0] ** 2) / ( 2 * state.thermostat.masses[0] ) - # Ensure kT can broadcast properly with DOF_per_batch + # Ensure kT can broadcast properly with DOF_per_graph if isinstance(kT, torch.Tensor) and kT.ndim == 0: - # Scalar kT - expand to match DOF_per_batch shape - kT_expanded = kT.expand_as(DOF_per_batch) + # Scalar kT - expand to match DOF_per_graph shape + kT_expanded = kT.expand_as(DOF_per_graph) else: kT_expanded = kT - thermostat_energy += DOF_per_batch * kT_expanded * state.thermostat.positions[0] + thermostat_energy += DOF_per_graph * kT_expanded * state.thermostat.positions[0] # Add remaining thermostat terms for pos, momentum, mass in zip( @@ -1660,11 +1658,11 @@ def npt_nose_hoover_invariant( e_tot = e_tot + barostat_energy - # Add PV term and cell kinetic energy (both are per batch) + # Add PV term and cell kinetic energy (both are per graph) e_tot += external_pressure * volume e_tot += (state.cell_momentum**2) / (2 * state.cell_mass) - # Return scalar if single batch, otherwise return per-batch values - if state.n_batches == 1: + # Return scalar if single graph, otherwise return per-graph values + if state.n_graphs == 1: return e_tot.squeeze() return e_tot diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index f10e2f73..5a003e55 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -37,9 +37,9 @@ def nve( Args: model (torch.nn.Module): Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_graphs] kT (torch.Tensor): Temperature in energy units for initializing momenta, - either scalar or with shape [n_batches] + either scalar or with shape [n_graphs] seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: @@ -72,7 +72,7 @@ def nve_init( containing positions, masses, cell, pbc, and other required state variables kT (torch.Tensor): Temperature in energy units for initializing momenta, - scalar or with shape [n_batches] + scalar or with shape [n_graphs] seed (int, optional): Random seed for reproducibility Returns: @@ -88,7 +88,7 @@ def nve_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.graph_idx, kT, seed), ) initial_state = MDState( @@ -99,7 +99,7 @@ def nve_init( masses=state.masses, cell=state.cell, pbc=state.pbc, - batch=state.batch, + graph_idx=state.graph_idx, atomic_numbers=state.atomic_numbers, ) return initial_state # noqa: RET504 @@ -116,7 +116,7 @@ def nve_update(state: MDState, dt: torch.Tensor = dt, **_) -> MDState: Args: state (MDState): Current system state containing positions, momenta, forces - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] **_: Additional unused keyword arguments (for compatibility) Returns: diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index e446929d..218d2ff3 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -45,11 +45,11 @@ def nvt_langevin( # noqa: C901 Args: model (torch.nn.Module): Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or with shape [n_graphs] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_graphs] gamma (torch.Tensor, optional): Friction coefficient for Langevin thermostat, - either scalar or with shape [n_batches]. Defaults to 1/(100*dt). + either scalar or with shape [n_graphs]. Defaults to 1/(100*dt). seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: @@ -93,11 +93,11 @@ def ou_step( Args: state (MDState): Current system state containing positions, momenta, etc. - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_graphs] gamma (torch.Tensor): Friction coefficient controlling noise strength, - either scalar or with shape [n_batches] + either scalar or with shape [n_graphs] Returns: MDState: Updated state with new momenta after stochastic step @@ -114,12 +114,12 @@ def ou_step( c1 = torch.exp(-gamma * dt) if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: - # kT is a tensor with shape (n_batches,) - kT = kT[state.batch] + # kT is a tensor with shape (n_graphs,) + kT = kT[state.graph_idx] - # Index c1 and c2 with state.batch to align shapes with state.momenta + # Index c1 and c2 with state.graph_idx to align shapes with state.momenta if isinstance(c1, torch.Tensor) and len(c1.shape) > 0: - c1 = c1[state.batch] + c1 = c1[state.graph_idx] c2 = torch.sqrt(kT * (1 - c1**2)).unsqueeze(-1) @@ -147,7 +147,7 @@ def langevin_init( state (SimState | StateDict): Either a SimState object or a dictionary containing positions, masses, cell, pbc, and other required state vars kT (torch.Tensor): Temperature in energy units for initializing momenta, - either scalar or with shape [n_batches] + either scalar or with shape [n_graphs] seed (int, optional): Random seed for reproducibility Returns: @@ -167,7 +167,7 @@ def langevin_init( momenta = getattr( state, "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.graph_idx, kT, seed), ) initial_state = MDState( @@ -178,7 +178,7 @@ def langevin_init( masses=state.masses, cell=state.cell, pbc=state.pbc, - batch=state.batch, + graph_idx=state.graph_idx, atomic_numbers=state.atomic_numbers, ) return initial_state # noqa: RET504 @@ -202,11 +202,11 @@ def langevin_update( Args: state (MDState): Current system state containing positions, momenta, forces - dt (torch.Tensor): Integration timestep, either scalar or shape [n_batches] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_graphs] kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_batches] + with shape [n_graphs] gamma (torch.Tensor): Friction coefficient for Langevin thermostat, - either scalar or with shape [n_batches] + either scalar or with shape [n_graphs] Returns: MDState: Updated state after one complete Langevin step with new positions, @@ -363,21 +363,21 @@ def nvt_nose_hoover_init( model_output = model(state) momenta = kwargs.get( "momenta", - calculate_momenta(state.positions, state.masses, state.batch, kT, seed), + calculate_momenta(state.positions, state.masses, state.graph_idx, kT, seed), ) - # Calculate initial kinetic energy per batch - KE = calc_kinetic_energy(momenta, state.masses, batch=state.batch) + # Calculate initial kinetic energy per graph + KE = calc_kinetic_energy(momenta, state.masses, graph_idx=state.graph_idx) - # Calculate degrees of freedom per batch - n_atoms_per_batch = torch.bincount(state.batch) - dof_per_batch = ( - n_atoms_per_batch * state.positions.shape[-1] + # Calculate degrees of freedom per graph + n_atoms_per_graph = torch.bincount(state.graph_idx) + dof_per_graph = ( + n_atoms_per_graph * state.positions.shape[-1] ) # n_atoms * n_dimensions - # For now, sum the per-batch DOF as chain expects a single int + # For now, sum the per-graph DOF as chain expects a single int # This is a limitation that should be addressed in the chain implementation - total_dof = int(dof_per_batch.sum().item()) + total_dof = int(dof_per_graph.sum().item()) # Initialize state state = NVTNoseHooverState( @@ -431,8 +431,8 @@ def nvt_nose_hoover_update( # Full velocity Verlet step state = velocity_verlet(state=state, dt=dt, model=model) - # Update chain kinetic energy per batch - KE = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + # Update chain kinetic energy per graph + KE = calc_kinetic_energy(state.momenta, state.masses, graph_idx=state.graph_idx) chain.kinetic_energy = KE # Second half-step of chain evolution @@ -474,13 +474,13 @@ def nvt_nose_hoover_invariant( - Includes both physical and thermostat degrees of freedom - Useful for debugging thermostat behavior """ - # Calculate system energy terms per batch + # Calculate system energy terms per graph e_pot = state.energy - e_kin = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + e_kin = calc_kinetic_energy(state.momenta, state.masses, graph_idx=state.graph_idx) - # Get system degrees of freedom per batch - n_atoms_per_batch = torch.bincount(state.batch) - dof = n_atoms_per_batch * state.positions.shape[-1] # n_atoms * n_dimensions + # Get system degrees of freedom per graph + n_atoms_per_graph = torch.bincount(state.graph_idx) + dof = n_atoms_per_graph * state.positions.shape[-1] # n_atoms * n_dimensions # Start with system energy e_tot = e_pot + e_kin diff --git a/torch_sim/io.py b/torch_sim/io.py index f9505062..77eca718 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -33,7 +33,7 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: state (SimState): Batched state containing positions, cell, and atomic numbers Returns: - list[Atoms]: ASE Atoms objects, one per batch + list[Atoms]: ASE Atoms objects, one per graph Raises: ImportError: If ASE is not installed @@ -50,22 +50,22 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: # Convert tensors to numpy arrays on CPU positions = state.positions.detach().cpu().numpy() - cell = state.cell.detach().cpu().numpy() # Shape: (n_batches, 3, 3) + cell = state.cell.detach().cpu().numpy() # Shape: (n_graphs, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - batch = state.batch.detach().cpu().numpy() + graph_idx = state.graph_idx.detach().cpu().numpy() atoms_list = [] - for batch_idx in np.unique(batch): - mask = batch == batch_idx - batch_positions = positions[mask] - batch_numbers = atomic_numbers[mask] - batch_cell = cell[batch_idx].T # Transpose for ASE convention + for idx in np.unique(graph_idx): + mask = graph_idx == idx + graph_positions = positions[mask] + graph_numbers = atomic_numbers[mask] + graph_cell = cell[idx].T # Transpose for ASE convention # Convert atomic numbers to chemical symbols - symbols = [chemical_symbols[z] for z in batch_numbers] + symbols = [chemical_symbols[z] for z in graph_numbers] atoms = Atoms( - symbols=symbols, positions=batch_positions, cell=batch_cell, pbc=state.pbc + symbols=symbols, positions=graph_positions, cell=graph_cell, pbc=state.pbc ) atoms_list.append(atoms) @@ -79,7 +79,7 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: state (SimState): Batched state containing positions, cell, and atomic numbers Returns: - list[Structure]: Pymatgen Structure objects, one per batch + list[Structure]: Pymatgen Structure objects, one per graph Raises: ImportError: If Pymatgen is not installed @@ -98,29 +98,29 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: # Convert tensors to numpy arrays on CPU positions = state.positions.detach().cpu().numpy() - cell = state.cell.detach().cpu().numpy() # Shape: (n_batches, 3, 3) + cell = state.cell.detach().cpu().numpy() # Shape: (n_graphs, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - batch = state.batch.detach().cpu().numpy() + graph_idx = state.graph_idx.detach().cpu().numpy() - # Get unique batch indices and counts - unique_batches = np.unique(batch) + # Get unique graph indices and counts + unique_graphs = np.unique(graph_idx) structures = [] - for batch_idx in unique_batches: - # Get mask for current batch - mask = batch == batch_idx - batch_positions = positions[mask] - batch_numbers = atomic_numbers[mask] - batch_cell = cell[batch_idx].T # Transpose for conventional form + for unique_graph_idx in unique_graphs: + # Get mask for current graph + mask = graph_idx == unique_graph_idx + graph_positions = positions[mask] + graph_numbers = atomic_numbers[mask] + graph_cell = cell[unique_graph_idx].T # Transpose for conventional form # Create species list from atomic numbers - species = [Element.from_Z(z) for z in batch_numbers] + species = [Element.from_Z(z) for z in graph_numbers] - # Create structure for this batch + # Create structure for this graph struct = Structure( - lattice=Lattice(batch_cell), + lattice=Lattice(graph_cell), species=species, - coords=batch_positions, + coords=graph_positions, coords_are_cartesian=True, ) structures.append(struct) @@ -135,7 +135,7 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: state (SimState): Batched state containing positions, cell, and atomic numbers Returns: - list[PhonopyAtoms]: PhonopyAtoms objects, one per batch + list[PhonopyAtoms]: PhonopyAtoms objects, one per graph Raises: ImportError: If Phonopy is not installed @@ -152,24 +152,24 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: # Convert tensors to numpy arrays on CPU positions = state.positions.detach().cpu().numpy() - cell = state.cell.detach().cpu().numpy() # Shape: (n_batches, 3, 3) + cell = state.cell.detach().cpu().numpy() # Shape: (n_graphs, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - batch = state.batch.detach().cpu().numpy() + graph_idx = state.graph_idx.detach().cpu().numpy() phonopy_atoms_list = [] - for batch_idx in np.unique(batch): - mask = batch == batch_idx - batch_positions = positions[mask] - batch_numbers = atomic_numbers[mask] - batch_cell = cell[batch_idx].T # Transpose for Phonopy convention + for idx in np.unique(graph_idx): + mask = graph_idx == idx + graph_positions = positions[mask] + graph_numbers = atomic_numbers[mask] + graph_cell = cell[idx].T # Transpose for Phonopy convention # Convert atomic numbers to chemical symbols - symbols = [chemical_symbols[z] for z in batch_numbers] + symbols = [chemical_symbols[z] for z in graph_numbers] phonopy_atoms_list.append( PhonopyAtoms( symbols=symbols, - positions=batch_positions, - cell=batch_cell, + positions=graph_positions, + cell=graph_cell, pbc=state.pbc, ) ) @@ -225,10 +225,10 @@ def atoms_to_state( np.stack([a.cell.array.T for a in atoms_list]), dtype=dtype, device=device ) - # Create batch indices using repeat_interleave - atoms_per_batch = torch.tensor([len(a) for a in atoms_list], device=device) - batch = torch.repeat_interleave( - torch.arange(len(atoms_list), device=device), atoms_per_batch + # Create graph indices using repeat_interleave + atoms_per_graph = torch.tensor([len(a) for a in atoms_list], device=device) + graph_idx = torch.repeat_interleave( + torch.arange(len(atoms_list), device=device), atoms_per_graph ) # Verify consistent pbc @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - batch=batch, + graph_idx=graph_idx, ) @@ -297,10 +297,10 @@ def structures_to_state( device=device, ) - # Create batch indices - atoms_per_batch = torch.tensor([len(s) for s in struct_list], device=device) - batch = torch.repeat_interleave( - torch.arange(len(struct_list), device=device), atoms_per_batch + # Create graph indices + atoms_per_graph = torch.tensor([len(s) for s in struct_list], device=device) + graph_idx = torch.repeat_interleave( + torch.arange(len(struct_list), device=device), atoms_per_graph ) return ts.SimState( @@ -309,7 +309,7 @@ def structures_to_state( cell=cell, pbc=True, # Structures are always periodic atomic_numbers=atomic_numbers, - batch=batch, + graph_idx=graph_idx, ) @@ -368,10 +368,10 @@ def phonopy_to_state( np.stack([a.cell.T for a in phonopy_atoms_list]), dtype=dtype, device=device ) - # Create batch indices using repeat_interleave - atoms_per_batch = torch.tensor([len(a) for a in phonopy_atoms_list], device=device) - batch = torch.repeat_interleave( - torch.arange(len(phonopy_atoms_list), device=device), atoms_per_batch + # Create graph indices using repeat_interleave + atoms_per_graph = torch.tensor([len(a) for a in phonopy_atoms_list], device=device) + graph_idx = torch.repeat_interleave( + torch.arange(len(phonopy_atoms_list), device=device), atoms_per_graph ) """ @@ -387,5 +387,5 @@ def phonopy_to_state( cell=cell, pbc=True, atomic_numbers=atomic_numbers, - batch=batch, + graph_idx=graph_idx, ) diff --git a/torch_sim/math.py b/torch_sim/math.py index 40228ba4..cfb085c4 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -1000,7 +1000,7 @@ def batched_vdot( batch_indices: Tensor of shape [N_total_entities] indicating batch membership. Returns: - Tensor: shape [n_batches] where each element is the sum(x_i * y_i) + Tensor: shape [n_graphs] where each element is the sum(x_i * y_i) for entities belonging to that batch, summed over all components D and all entities in the batch. """ diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 3197c6aa..4e6e386a 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -349,8 +349,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict: if state.device != self._device: state = state.to(self._device) - if state.batch is None: - state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int) + if state.graph_idx is None: + state.graph_idx = torch.zeros(state.positions.shape[0], dtype=torch.int) if self.pbc != state.pbc: raise ValueError( @@ -358,8 +358,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict: "For FairChemModel PBC needs to be defined in the model class." ) - natoms = torch.bincount(state.batch) - fixed = torch.zeros((state.batch.size(0), natoms.sum()), dtype=torch.int) + natoms = torch.bincount(state.graph_idx) + fixed = torch.zeros((state.graph_idx.size(0), natoms.sum()), dtype=torch.int) data_list = [] for i, (n, c) in enumerate( zip(natoms, torch.cumsum(natoms, dim=0), strict=False) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index af88e94d..539e262c 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -67,8 +67,8 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra """ graphs = [] - for i in range(state.n_batches): - batch_mask = state.batch == i + for i in range(state.n_graphs): + batch_mask = state.graph_idx == i R = state.positions[batch_mask] Z = state.atomic_numbers[batch_mask] cell = state.row_vector_cell[i] diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index a6ba406b..701cc699 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -65,9 +65,9 @@ class ModelInterface(ABC): output = model(sim_state) # Access computed properties - energy = output["energy"] # Shape: [n_batches] + energy = output["energy"] # Shape: [n_graphs] forces = output["forces"] # Shape: [n_atoms, 3] - stress = output["stress"] # Shape: [n_batches, 3, 3] + stress = output["stress"] # Shape: [n_graphs, 3, 3] ``` """ @@ -174,16 +174,16 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens dictionary is dependent on the model but typically must contain the following keys: - "positions": Atomic positions with shape [n_atoms, 3] - - "cell": Unit cell vectors with shape [n_batches, 3, 3] - - "batch": Batch indices for each atom with shape [n_atoms] + - "cell": Unit cell vectors with shape [n_graphs, 3, 3] + - "graph_idx": Graph indices for each atom with shape [n_atoms] - "atomic_numbers": Atomic numbers with shape [n_atoms] (optional) **kwargs: Additional model-specific parameters. Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_graphs] - "forces": Atomic forces with shape [n_atoms, 3] - - "stress": Stress tensor with shape [n_batches, 3, 3] (if + - "stress": Stress tensor with shape [n_graphs, 3, 3] (if compute_stress=True) - May include additional model-specific outputs @@ -256,7 +256,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() - og_batch = sim_state.batch.clone() + og_graph_idx = sim_state.graph_idx.clone() og_atomic_numbers = sim_state.atomic_numbers.clone() model_output = model.forward(sim_state) @@ -266,8 +266,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_positions=} != {sim_state.positions=}") if not torch.allclose(og_cell, sim_state.cell): raise ValueError(f"{og_cell=} != {sim_state.cell=}") - if not torch.allclose(og_batch, sim_state.batch): - raise ValueError(f"{og_batch=} != {sim_state.batch=}") + if not torch.allclose(og_graph_idx, sim_state.graph_idx): + raise ValueError(f"{og_graph_idx=} != {sim_state.graph_idx=}") if not torch.allclose(og_atomic_numbers, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_numbers=} != {sim_state.atomic_numbers=}") diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 2e98ad22..ee60b0c4 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -357,7 +357,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: """Compute Lennard-Jones energies, forces, and stresses for a system. Main entry point for Lennard-Jones calculations that handles batched states by - dispatching each batch to the unbatched implementation and combining results. + dispatching each graph to the unbatched implementation and combining results. Args: state (SimState | StateDict): Input state containing atomic positions, @@ -366,10 +366,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_graphs] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] (if + - "stress": Stress tensor with shape [n_graphs, 3, 3] (if compute_stress=True) - "energies": Per-atom energies with shape [n_atoms] (if per_atom_energies=True) @@ -377,7 +377,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: per_atom_stresses=True) Raises: - ValueError: If batch cannot be inferred for multi-cell systems. + ValueError: If graph cannot be inferred for multi-cell systems. Example:: @@ -385,19 +385,19 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: model = LennardJonesModel(compute_stress=True) results = model(sim_state) - energy = results["energy"] # Shape: [n_batches] + energy = results["energy"] # Shape: [n_graphs] forces = results["forces"] # Shape: [n_atoms, 3] - stress = results["stress"] # Shape: [n_batches, 3, 3] + stress = results["stress"] # Shape: [n_graphs, 3, 3] energies = results["energies"] # Shape: [n_atoms] stresses = results["stresses"] # Shape: [n_atoms, 3, 3] """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + if state.graph_idx is None and state.cell.shape[0] > 1: + raise ValueError("Graph can only be inferred for batch size 1.") - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_graphs)] properties = outputs[0] # we always return tensors diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index d9d92b9b..a6b82bf3 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -91,7 +91,7 @@ class MaceModel(torch.nn.Module, ModelInterface): model (torch.nn.Module): The underlying MACE neural network model. neighbor_list_fn (Callable): Function used to compute neighbor lists. atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]. - batch (torch.Tensor): Batch indices with shape [n_atoms]. + graph_idx (torch.Tensor): Graph indices with shape [n_atoms]. n_systems (int): Number of systems in the batch. n_atoms_per_system (list[int]): Number of atoms in each system. ptr (torch.Tensor): Pointers to the start of each system in the batch with @@ -112,13 +112,13 @@ def __init__( compute_stress: bool = True, enable_cueq: bool = False, atomic_numbers: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + graph_idx: torch.Tensor | None = None, ) -> None: """Initialize the MACE model for energy and force calculations. Sets up the MACE model for energy, force, and stress calculations within the TorchSim framework. The model can be initialized with atomic numbers - and batch indices, or these can be provided during the forward pass. + and graph indices, or these can be provided during the forward pass. Args: model (str | Path | torch.nn.Module | None): The MACE neural network model, @@ -129,7 +129,7 @@ def __init__( Defaults to torch.float64. atomic_numbers (torch.Tensor | None): Atomic numbers with shape [n_atoms]. If provided at initialization, cannot be provided again during forward. - batch (torch.Tensor | None): Batch indices with shape [n_atoms] indicating + graph_idx (torch.Tensor | None): Graph indices with shape [n_atoms] indicating which system each atom belongs to. If not provided with atomic_numbers, all atoms are assumed to be in the same system. neighbor_list_fn (Callable): Function to compute neighbor lists. @@ -186,38 +186,40 @@ def __init__( # Set up batch information if atomic numbers are provided if atomic_numbers is not None: - if batch is None: + if graph_idx is None: # If batch is not provided, assume all atoms belong to same system - batch = torch.zeros( + graph_idx = torch.zeros( len(atomic_numbers), dtype=torch.long, device=self.device ) - self.setup_from_batch(atomic_numbers, batch) + self.setup_from_batch(atomic_numbers, graph_idx) - def setup_from_batch(self, atomic_numbers: torch.Tensor, batch: torch.Tensor) -> None: - """Set up internal state from atomic numbers and batch indices. + def setup_from_batch( + self, atomic_numbers: torch.Tensor, graph_idx: torch.Tensor + ) -> None: + """Set up internal state from atomic numbers and graph indices. - Processes the atomic numbers and batch indices to prepare the model for + Processes the atomic numbers and graph indices to prepare the model for forward pass calculations. Creates the necessary data structures for batched processing of multiple systems. Args: atomic_numbers (torch.Tensor): Atomic numbers tensor with shape [n_atoms]. - batch (torch.Tensor): Batch indices tensor with shape [n_atoms] indicating + graph_idx (torch.Tensor): Graph indices tensor with shape [n_atoms] indicating which system each atom belongs to. """ self.atomic_numbers = atomic_numbers - self.batch = batch + self.graph_idx = graph_idx # Determine number of systems and atoms per system - self.n_systems = batch.max().item() + 1 + self.n_systems = graph_idx.max().item() + 1 - # Create ptr tensor for batch boundaries + # Create ptr tensor for graph boundaries self.n_atoms_per_system = [] ptr = [0] - for b in range(self.n_systems): - batch_mask = batch == b - n_atoms = batch_mask.sum().item() + for i in range(self.n_systems): + graph_mask = graph_idx == i + n_atoms = graph_mask.sum().item() self.n_atoms_per_system.append(n_atoms) ptr.append(ptr[-1] + n_atoms) @@ -260,7 +262,7 @@ def forward( # noqa: C901 Raises: ValueError: If atomic numbers are not provided either in the constructor or in the forward pass, or if provided in both places. - ValueError: If batch indices are not provided when needed. + ValueError: If graph indices are not provided when needed. """ # Extract required data from input if isinstance(state, dict): @@ -276,13 +278,13 @@ def forward( # noqa: C901 "Atomic numbers cannot be provided in both the constructor and forward." ) - # Use batch from init if not provided - if state.batch is None: - if not hasattr(self, "batch"): + # Use graph_idx from init if not provided + if state.graph_idx is None: + if not hasattr(self, "graph_idx"): raise ValueError( - "Batch indices must be provided if not set during initialization" + "Graph indices must be provided if not set during initialization" ) - state.batch = self.batch + state.graph_idx = self.graph_idx # Update batch information if new atomic numbers are provided if ( @@ -293,7 +295,7 @@ def forward( # noqa: C901 getattr(self, "atomic_numbers", torch.zeros(0, device=self.device)), ) ): - self.setup_from_batch(state.atomic_numbers, state.batch) + self.setup_from_batch(state.atomic_numbers, state.graph_idx) # Process each system's neighbor list separately edge_indices = [] @@ -303,7 +305,7 @@ def forward( # noqa: C901 # TODO (AG): Currently doesn't work for batched neighbor lists for b in range(self.n_systems): - batch_mask = state.batch == b + batch_mask = state.graph_idx == b # Calculate neighbor list for this system edge_idx, shifts_idx = self.neighbor_list_fn( positions=state.positions[batch_mask], @@ -332,7 +334,7 @@ def forward( # noqa: C901 dict( ptr=self.ptr, node_attrs=self.node_attrs, - batch=state.batch, + batch=state.graph_idx, pbc=state.pbc, cell=state.row_vector_cell, positions=state.positions, diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 31d26966..91f989c9 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -74,7 +74,7 @@ def __init__( Sets up a metatomic model for energy, force, and stress calculations within the TorchSim framework. The model can be initialized with atomic numbers - and batch indices, or these can be provided during the forward pass. + and graph indices, or these can be provided during the forward pass. Args: model (str | Path | None): Path to the metatomic model file or a @@ -200,7 +200,7 @@ def forward( # noqa: C901, PLR0915 systems: list[System] = [] strains = [] for b in range(len(cell)): - system_mask = state.batch == b + system_mask = state.graph_idx == b system_positions = positions[system_mask] system_cell = cell[b] system_pbc = torch.tensor( diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 8b446a3c..b55d16f6 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -356,10 +356,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_graphs] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] + - "stress": Stress tensor with shape [n_graphs, 3, 3] (if compute_stress=True) - May include additional outputs based on configuration @@ -372,17 +372,17 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: model = MorseModel(compute_forces=True) results = model(sim_state) - energy = results["energy"] # Shape: [n_batches] + energy = results["energy"] # Shape: [n_graphs] forces = results["forces"] # Shape: [n_atoms, 3] ``` """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - if state.batch is None and state.cell.shape[0] > 1: + if state.graph_idx is None and state.cell.shape[0] > 1: raise ValueError("Batch can only be inferred for batch size 1.") - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_graphs)] properties = outputs[0] # we always return tensors diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 551a9b01..7b54b4d6 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -102,7 +102,7 @@ def state_to_atom_graphs( # noqa: PLR0915 system_config = SystemConfig(radius=6.0, max_num_neighbors=20) # Handle batch information if present - n_node = torch.bincount(state.batch) + n_node = torch.bincount(state.graph_idx) # Set default dtype if not provided output_dtype = torch.get_default_dtype() if output_dtype is None else output_dtype @@ -143,7 +143,7 @@ def state_to_atom_graphs( # noqa: PLR0915 if wrap and (torch.any(row_vector_cell != 0) and torch.any(pbc)): positions = feat_util.batch_map_to_pbc_cell(positions, row_vector_cell, n_node) - n_systems = state.batch.max().item() + 1 + n_systems = state.graph_idx.max().item() + 1 # Prepare lists to collect data from each system all_edges = [] @@ -157,7 +157,7 @@ def state_to_atom_graphs( # noqa: PLR0915 # Process each system in a single loop offset = 0 for i in range(n_systems): - batch_mask = state.batch == i + batch_mask = state.graph_idx == i positions_per_system = positions[batch_mask] atomic_numbers_per_system = atomic_numbers[batch_mask] atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask] diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 48110a86..30b0d61d 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -223,10 +223,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_graphs] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] (if + - "stress": Stress tensor with shape [n_graphs, 3, 3] (if compute_stress=True) - "energies": Per-atom energies with shape [n_atoms] (if per_atom_energies=True) @@ -239,10 +239,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - if state.batch is None and state.cell.shape[0] > 1: + if state.graph_idx is None and state.cell.shape[0] > 1: raise ValueError("Batch can only be inferred for batch size 1.") - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_graphs)] properties = outputs[0] # we always return tensors diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 45e59f9b..87fed99f 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -181,8 +181,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: state = state.clone() data_list = [] - for b in range(state.batch.max().item() + 1): - batch_mask = state.batch == b + for b in range(state.graph_idx.max().item() + 1): + batch_mask = state.graph_idx == b pos = state.positions[batch_mask] # SevenNet uses row vector cell convention for neighbor list @@ -245,7 +245,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: results["energy"] = energy.detach() else: results["energy"] = torch.zeros( - state.batch.max().item() + 1, device=self.device + state.graph_idx.max().item() + 1, device=self.device ) forces = output[key.PRED_FORCE] diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 8cbc0e7f..e12b8e08 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -381,7 +381,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: """Compute soft sphere potential energies, forces, and stresses for a system. Main entry point for soft sphere potential calculations that handles batched - states by dispatching each batch to the unbatched implementation and combining + states by dispatching each graph to the unbatched implementation and combining results. Args: @@ -391,15 +391,15 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_graphs] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] + - "stress": Stress tensor with shape [n_graphs, 3, 3] (if compute_stress=True) - May include additional outputs based on configuration Raises: - ValueError: If batch cannot be inferred for multi-cell systems. + ValueError: If graph indices cannot be inferred for multi-cell systems. Examples: ```py @@ -407,18 +407,18 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: model = SoftSphereModel(compute_forces=True) results = model(sim_state) - energy = results["energy"] # Shape: [n_batches] + energy = results["energy"] # Shape: [n_graphs] forces = results["forces"] # Shape: [n_atoms, 3] ``` """ if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) - # Handle batch indices if not provided - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + # Handle Graph indices if not provided + if state.graph_idx is None and state.cell.shape[0] > 1: + raise ValueError("graph_idx can only be inferred if there is only one graph") - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_graphs)] properties = outputs[0] # Combine results @@ -816,10 +816,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Returns: dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_batches] + - "energy": Potential energy with shape [n_graphs] - "forces": Atomic forces with shape [n_atoms, 3] (if compute_forces=True) - - "stress": Stress tensor with shape [n_batches, 3, 3] + - "stress": Stress tensor with shape [n_graphs, 3, 3] (if compute_stress=True) - May include additional outputs based on configuration @@ -854,11 +854,11 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: elif state.pbc != self.pbc: raise ValueError("PBC mismatch between model and state") - # Handle batch indices if not provided - if state.batch is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + # Handle graph indices if not provided + if state.graph_idx is None and state.cell.shape[0] > 1: + raise ValueError("graph_idx can only be inferred if there is only one graph") - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_batches)] + outputs = [self.unbatched_forward(state[i]) for i in range(state.n_graphs)] properties = outputs[0] # Combine results diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 389ce687..b0d3d969 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -41,7 +41,7 @@ def generate_swaps( ) -> torch.Tensor: """Generate atom swaps for a given batched system. - Generates proposed swaps between atoms of different types within the same batch. + Generates proposed swaps between atoms of different types within the same graph. The function ensures that swaps only occur between atoms with different atomic numbers. @@ -51,48 +51,48 @@ def generate_swaps( reproducibility. Defaults to None. Returns: - torch.Tensor: A tensor of proposed swaps with shape [n_batches, 2], + torch.Tensor: A tensor of proposed swaps with shape [n_graphs, 2], where each row contains indices of atoms to be swapped """ - batch = state.batch + graph = state.graph_idx atomic_numbers = state.atomic_numbers - batch_lengths = batch.bincount() + graph_lengths = graph.bincount() - # change batch_lengths to batch - batch = torch.repeat_interleave( - torch.arange(len(batch_lengths), device=batch.device), batch_lengths + # change graph_lengths to graph + graph = torch.repeat_interleave( + torch.arange(len(graph_lengths), device=graph.device), graph_lengths ) # Create ragged weights tensor without loops - max_length = torch.max(batch_lengths).item() - n_batches = len(batch_lengths) + max_length = torch.max(graph_lengths).item() + n_graphs = len(graph_lengths) - # Create a range tensor for each batch - range_tensor = torch.arange(max_length, device=batch.device).expand( - n_batches, max_length + # Create a range tensor for each graph + range_tensor = torch.arange(max_length, device=graph.device).expand( + n_graphs, max_length ) - # Create a mask where values are less than the batch length - batch_lengths_expanded = batch_lengths.unsqueeze(1).expand(n_batches, max_length) - weights = (range_tensor < batch_lengths_expanded).float() + # Create a mask where values are less than the max graph length + graph_lengths_expanded = graph_lengths.unsqueeze(1).expand(n_graphs, max_length) + weights = (range_tensor < graph_lengths_expanded).float() first_index = torch.multinomial(weights, 1, replacement=False, generator=generator) - # Process each batch - we need this loop because of ragged batches - batch_starts = batch_lengths.cumsum(dim=0) - batch_lengths[0] + # Process each graph - we need this loop because of ragged graphs + graph_starts = graph_lengths.cumsum(dim=0) - graph_lengths[0] - for b in range(n_batches): + for b in range(n_graphs): # Get global index of selected atom - first_idx = first_index[b, 0].item() + batch_starts[b].item() + first_idx = first_index[b, 0].item() + graph_starts[b].item() first_type = atomic_numbers[first_idx] - # Get indices of atoms in this batch - batch_start = batch_starts[b].item() - batch_end = batch_start + batch_lengths[b].item() + # Get indices of atoms in this graph + graph_start = graph_starts[b].item() + graph_end = graph_start + graph_lengths[b].item() # Create mask for same-type atoms - same_type = atomic_numbers[batch_start:batch_end] == first_type + same_type = atomic_numbers[graph_start:graph_end] == first_type # Zero out weights for same-type atoms (accounting for padding) weights[b, : len(same_type)][same_type] = 0.0 @@ -100,7 +100,7 @@ def generate_swaps( second_index = torch.multinomial(weights, 1, replacement=False, generator=generator) zeroed_swaps = torch.concatenate([first_index, second_index], dim=1) - return zeroed_swaps + (batch_lengths.cumsum(dim=0) - batch_lengths[0]).unsqueeze(1) + return zeroed_swaps + (graph_lengths.cumsum(dim=0) - graph_lengths[0]).unsqueeze(1) def swaps_to_permutation(swaps: torch.Tensor, n_atoms: int) -> torch.Tensor: @@ -124,21 +124,21 @@ def swaps_to_permutation(swaps: torch.Tensor, n_atoms: int) -> torch.Tensor: return permutation -def validate_permutation(permutation: torch.Tensor, batch: torch.Tensor) -> None: - """Validate that permutations only swap atoms within the same batch. +def validate_permutation(permutation: torch.Tensor, graph_idx: torch.Tensor) -> None: + """Validate that permutations only swap atoms within the same graph. - Confirms that no swaps are attempted between atoms in different batches, + Confirms that no swaps are attempted between atoms in different graphs, which would lead to physically invalid configurations. Args: permutation (torch.Tensor): Permutation tensor of shape [n_atoms] - batch (torch.Tensor): Batch assignments for each atom of shape [n_atoms] + graph_idx (torch.Tensor): graph_idx for each atom of shape [n_atoms] Raises: - ValueError: If any swaps are between atoms in different batches + ValueError: If any swaps are between atoms in different graphs """ - if not torch.all(batch == batch[permutation]): - raise ValueError("Swaps must be between atoms in the same batch") + if not torch.all(graph_idx == graph_idx[permutation]): + raise ValueError("Swaps must be between atoms in the same graph") def metropolis_criterion( @@ -233,7 +233,7 @@ def init_swap_mc_state(state: SimState) -> SwapMCState: cell=state.cell, pbc=state.pbc, atomic_numbers=state.atomic_numbers, - batch=state.batch, + graph_idx=state.graph_idx, energy=model_output["energy"], last_permutation=torch.arange(state.n_atoms, device=state.device), ) @@ -260,12 +260,12 @@ def swap_monte_carlo_step( Notes: The function handles batched systems and ensures that swaps only occur - within the same batch. + within the same graph. """ swaps = generate_swaps(state, generator=generator) permutation = swaps_to_permutation(swaps, state.n_atoms) - validate_permutation(permutation, state.batch) + validate_permutation(permutation, state.graph_idx) energies_old = state.energy.clone() state.positions = state.positions[permutation].clone() diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index cd4c439d..f344fd86 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -766,7 +766,7 @@ def torch_nl_linked_cell( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - batch: torch.Tensor, + graph_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the neighbor list for a set of atomic structures using the linked @@ -784,7 +784,7 @@ def torch_nl_linked_cell( pbc (torch.Tensor [n_structure, 3] bool): A tensor indicating the periodic boundary conditions to apply. Partial PBC are not supported yet. - batch (torch.Tensor [n_atom,] torch.long): + graph_idx (torch.Tensor [n_atom,] torch.long): A tensor containing the index of the structure to which each atom belongs. self_interaction (bool, optional): A flag to indicate whether to keep the center atoms as their own neighbors. @@ -806,7 +806,7 @@ def torch_nl_linked_cell( References: - https://github.com/felixmusil/torch_nl """ - n_atoms = torch.bincount(batch) + n_atoms = torch.bincount(graph_idx) mapping, batch_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 5c98dafe..74673712 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -44,12 +44,12 @@ class GDState(SimState): Attributes: positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_graphs, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + graph_idx (torch.Tensor): Graph indices with shape [n_atoms] forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Potential energy with shape [n_batches] + energy (torch.Tensor): Potential energy with shape [n_graphs] """ forces: torch.Tensor @@ -68,8 +68,8 @@ def gradient_descent( Args: model (torch.nn.Module): Model that computes energies and forces lr (torch.Tensor | float): Learning rate(s) for optimization. Can be a single - float applied to all batches or a tensor with shape [n_batches] for - batch-specific rates + float applied to all graphs or a tensor with shape [n_graphs] for + graph-specific rates Returns: tuple: A pair of functions: @@ -113,7 +113,7 @@ def gd_init( cell=state.cell, pbc=state.pbc, atomic_numbers=atomic_numbers, - batch=state.batch, + graph_idx=state.graph_idx, ) def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: @@ -129,9 +129,9 @@ def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: """ # Get per-atom learning rates by mapping batch learning rates to atoms if isinstance(lr, float): - lr = torch.full((state.n_batches,), lr, device=device, dtype=dtype) + lr = torch.full((state.n_graphs,), lr, device=device, dtype=dtype) - atom_lr = lr[state.batch].unsqueeze(-1) # shape: (total_atoms, 1) + atom_lr = lr[state.graph_idx].unsqueeze(-1) # shape: (total_atoms, 1) # Update positions using forces and per-atom learning rates state.positions = state.positions + atom_lr * state.forces @@ -160,25 +160,25 @@ class UnitCellGDState(GDState, DeformGradMixin): # Inherited from GDState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_graphs, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + graph_idx (torch.Tensor): Graph indices with shape [n_atoms] forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Potential energy with shape [n_batches] + energy (torch.Tensor): Potential energy with shape [n_graphs] # Additional attributes for cell optimization - stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3] + stress (torch.Tensor): Stress tensor with shape [n_graphs, 3, 3] reference_cell (torch.Tensor): Reference unit cells with shape - [n_batches, 3, 3] + [n_graphs, 3, 3] cell_factor (torch.Tensor): Scaling factor for cell optimization with shape - [n_batches, 1, 1] + [n_graphs, 1, 1] hydrostatic_strain (bool): Whether to only allow hydrostatic deformation constant_volume (bool): Whether to maintain constant volume - pressure (torch.Tensor): Applied pressure tensor with shape [n_batches, 3, 3] - cell_positions (torch.Tensor): Cell positions with shape [n_batches, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3] + pressure (torch.Tensor): Applied pressure tensor with shape [n_graphs, 3, 3] + cell_positions (torch.Tensor): Cell positions with shape [n_graphs, 3, 3] + cell_forces (torch.Tensor): Cell forces with shape [n_graphs, 3, 3] + cell_masses (torch.Tensor): Cell masses with shape [n_graphs, 3] """ # Required attributes not in BatchedGDState @@ -224,7 +224,7 @@ def unit_cell_gradient_descent( # noqa: PLR0915, C901 is 0.01. cell_lr (float): Learning rate for unit cell optimization. Default is 0.1. cell_factor (float | torch.Tensor | None): Scaling factor for cell - optimization. If None, defaults to number of atoms per batch + optimization. If None, defaults to number of atoms per graph hydrostatic_strain (bool): Whether to only allow hydrostatic deformation (isotropic scaling). Default is False. constant_volume (bool): Whether to maintain constant volume during optimization @@ -270,25 +270,25 @@ def gd_init( if not isinstance(state, SimState): state = SimState(**state) - n_batches = state.n_batches + n_graphs = state.n_graphs # Setup cell_factor if cell_factor is None: - # Count atoms per batch - _, counts = torch.unique(state.batch, return_counts=True) + # Count atoms per graph + _, counts = torch.unique(state.graph_idx, return_counts=True) cell_factor = counts.to(dtype=dtype) if isinstance(cell_factor, int | float): - # Use same factor for all batches + # Use same factor for all graphs cell_factor = torch.full( - (state.n_batches,), cell_factor, device=device, dtype=dtype + (state.n_graphs,), cell_factor, device=device, dtype=dtype ) - # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_batches, 1, 1) + # Reshape to (n_graphs, 1, 1) for broadcasting + cell_factor = cell_factor.view(n_graphs, 1, 1) scalar_pressure = torch.full( - (state.n_batches, 1, 1), scalar_pressure, device=device, dtype=dtype + (state.n_graphs, 1, 1), scalar_pressure, device=device, dtype=dtype ) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device) @@ -297,11 +297,11 @@ def gd_init( model_output = model(state) energy = model_output["energy"] forces = model_output["forces"] - stress = model_output["stress"] # Already shape: (n_batches, 3, 3) + stress = model_output["stress"] # Already shape: (n_graphs, 3, 3) # Create cell masses cell_masses = torch.ones( - (state.n_batches, 3), device=device, dtype=dtype + (state.n_graphs, 3), device=device, dtype=dtype ) # One mass per cell DOF # Get current deformation gradient @@ -311,27 +311,27 @@ def gd_init( # Calculate cell positions cell_factor_expanded = cell_factor.expand( - state.n_batches, 3, 1 - ) # shape: (n_batches, 3, 1) + state.n_graphs, 3, 1 + ) # shape: (n_graphs, 3, 1) cell_positions = ( - cur_deform_grad.reshape(state.n_batches, 3, 3) * cell_factor_expanded - ) # shape: (n_batches, 3, 3) + cur_deform_grad.reshape(state.n_graphs, 3, 3) * cell_factor_expanded + ) # shape: (n_graphs, 3, 3) # Calculate virial - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_graphs, 1, 1) virial = -volumes * (stress + pressure) if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(state.n_batches, -1, -1) + ).expand(state.n_graphs, -1, -1) if constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(state.n_batches, -1, -1) + ).unsqueeze(0).expand(state.n_graphs, -1, -1) return UnitCellGDState( positions=state.positions, @@ -347,7 +347,7 @@ def gd_init( constant_volume=constant_volume, pressure=pressure, atomic_numbers=state.atomic_numbers, - batch=state.batch, + graph_idx=state.graph_idx, cell_positions=cell_positions, cell_forces=virial / cell_factor, cell_masses=cell_masses, @@ -371,29 +371,29 @@ def gd_step( Updated UnitCellGDState after one optimization step """ # Get dimensions - n_batches = state.n_batches + n_graphs = state.n_graphs - # Get per-atom learning rates by mapping batch learning rates to atoms + # Get per-atom learning rates by mapping graph learning rates to atoms if isinstance(positions_lr, float): positions_lr = torch.full( - (state.n_batches,), positions_lr, device=device, dtype=dtype + (state.n_graphs,), positions_lr, device=device, dtype=dtype ) if isinstance(cell_lr, float): - cell_lr = torch.full((state.n_batches,), cell_lr, device=device, dtype=dtype) + cell_lr = torch.full((state.n_graphs,), cell_lr, device=device, dtype=dtype) # Get current deformation gradient cur_deform_grad = state.deform_grad() # Calculate cell positions from deformation gradient - cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) + cell_factor_expanded = state.cell_factor.expand(n_graphs, 3, 1) cell_positions = ( - cur_deform_grad.reshape(n_batches, 3, 3) * cell_factor_expanded - ) # shape: (n_batches, 3, 3) + cur_deform_grad.reshape(n_graphs, 3, 3) * cell_factor_expanded + ) # shape: (n_graphs, 3, 3) # Get per-atom and per-cell learning rates - atom_wise_lr = positions_lr[state.batch].unsqueeze(-1) - cell_wise_lr = cell_lr.view(n_batches, 1, 1) # shape: (n_batches, 1, 1) + atom_wise_lr = positions_lr[state.graph_idx].unsqueeze(-1) + cell_wise_lr = cell_lr.view(n_graphs, 1, 1) # shape: (n_graphs, 1, 1) # Update atomic and cell positions atomic_positions_new = state.positions + atom_wise_lr * state.forces @@ -415,18 +415,18 @@ def gd_step( state.stress = model_output["stress"] # Calculate virial for cell forces - volumes = torch.linalg.det(new_row_vector_cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(new_row_vector_cell).view(n_graphs, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(n_batches, -1, -1) + ).expand(n_graphs, -1, -1) if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_graphs, -1, -1) # Update cell forces state.cell_positions = cell_positions_new @@ -450,21 +450,21 @@ class FireState(SimState): # Inherited from SimState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_graphs, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + graph_idx (torch.Tensor): Graph indices with shape [n_atoms] # Atomic quantities forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - energy (torch.Tensor): Energy per batch with shape [n_batches] + energy (torch.Tensor): Energy per graph with shape [n_graphs] # FIRE optimization parameters - dt (torch.Tensor): Current timestep per batch with shape [n_batches] - alpha (torch.Tensor): Current mixing parameter per batch with shape [n_batches] - n_pos (torch.Tensor): Number of positive power steps per batch with shape - [n_batches] + dt (torch.Tensor): Current timestep per graph with shape [n_graphs] + alpha (torch.Tensor): Current mixing parameter per graph with shape [n_graphs] + n_pos (torch.Tensor): Number of positive power steps per graph with shape + [n_graphs] Properties: momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], @@ -558,8 +558,8 @@ def fire_init( Args: state: Input state as SimState object or state parameter dict - dt_start: Initial timestep per batch - alpha_start: Initial mixing parameter per batch + dt_start: Initial timestep per graph + alpha_start: Initial mixing parameter per graph Returns: FireState with initialized optimization tensors @@ -568,18 +568,18 @@ def fire_init( state = SimState(**state) # Get dimensions - n_batches = state.n_batches + n_graphs = state.n_graphs # Get initial forces and energy from model model_output = model(state) - energy = model_output["energy"] # [n_batches] + energy = model_output["energy"] # [n_graphs] forces = model_output["forces"] # [n_total_atoms, 3] # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) + dt_start = torch.full((n_graphs,), dt_start, device=device, dtype=dtype) + alpha_start = torch.full((n_graphs,), alpha_start, device=device, dtype=dtype) + n_pos = torch.zeros((n_graphs,), device=device, dtype=torch.int32) return FireState( # Create initial state # Copy SimState attributes @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - batch=state.batch.clone(), + graph_idx=state.graph_idx.clone(), pbc=state.pbc, velocities=None, forces=forces, @@ -630,36 +630,36 @@ class UnitCellFireState(SimState, DeformGradMixin): # Inherited from SimState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_graphs, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + graph_idx (torch.Tensor): Graph indices with shape [n_atoms] # Atomic quantities forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - energy (torch.Tensor): Energy per batch with shape [n_batches] - stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3] + energy (torch.Tensor): Energy per graph with shape [n_graphs] + stress (torch.Tensor): Stress tensor with shape [n_graphs, 3, 3] # Cell quantities - cell_positions (torch.Tensor): Cell positions with shape [n_batches, 3, 3] - cell_velocities (torch.Tensor): Cell velocities with shape [n_batches, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3] + cell_positions (torch.Tensor): Cell positions with shape [n_graphs, 3, 3] + cell_velocities (torch.Tensor): Cell velocities with shape [n_graphs, 3, 3] + cell_forces (torch.Tensor): Cell forces with shape [n_graphs, 3, 3] + cell_masses (torch.Tensor): Cell masses with shape [n_graphs, 3] # Cell optimization parameters - reference_cell (torch.Tensor): Original unit cells with shape [n_batches, 3, 3] + reference_cell (torch.Tensor): Original unit cells with shape [n_graphs, 3, 3] cell_factor (torch.Tensor): Cell optimization scaling factor with shape - [n_batches, 1, 1] - pressure (torch.Tensor): Applied pressure tensor with shape [n_batches, 3, 3] + [n_graphs, 1, 1] + pressure (torch.Tensor): Applied pressure tensor with shape [n_graphs, 3, 3] hydrostatic_strain (bool): Whether to only allow hydrostatic deformation constant_volume (bool): Whether to maintain constant volume # FIRE optimization parameters - dt (torch.Tensor): Current timestep per batch with shape [n_batches] - alpha (torch.Tensor): Current mixing parameter per batch with shape [n_batches] - n_pos (torch.Tensor): Number of positive power steps per batch with shape - [n_batches] + dt (torch.Tensor): Current timestep per graph with shape [n_graphs] + alpha (torch.Tensor): Current mixing parameter per graph with shape [n_graphs] + n_pos (torch.Tensor): Number of positive power steps per graph with shape + [n_graphs] Properties: momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], @@ -728,7 +728,7 @@ def unit_cell_fire( alpha_start (float): Initial velocity mixing parameter f_alpha (float): Factor for mixing parameter decrease cell_factor (float | None): Scaling factor for cell optimization. - If None, defaults to number of atoms per batch + If None, defaults to number of atoms per graph hydrostatic_strain (bool): Whether to only allow hydrostatic deformation (isotropic scaling) constant_volume (bool): Whether to maintain constant volume during optimization @@ -782,11 +782,11 @@ def fire_init( Args: state: Input state as SimState object or state parameter dict - cell_factor: Cell optimization scaling factor. If None, uses atoms per batch. - Single value or tensor of shape [n_batches]. + cell_factor: Cell optimization scaling factor. If None, uses atoms per graph. + Single value or tensor of shape [n_graphs]. scalar_pressure: Applied pressure in energy units - dt_start: Initial timestep per batch - alpha_start: Initial mixing parameter per batch + dt_start: Initial timestep per graph + alpha_start: Initial mixing parameter per graph Returns: UnitCellFireState with initialized optimization tensors @@ -795,64 +795,64 @@ def fire_init( state = SimState(**state) # Get dimensions - n_batches = state.n_batches + n_graphs = state.n_graphs # Setup cell_factor if cell_factor is None: - # Count atoms per batch - _, counts = torch.unique(state.batch, return_counts=True) + # Count atoms per graph + _, counts = torch.unique(state.graph_idx, return_counts=True) cell_factor = counts.to(dtype=dtype) if isinstance(cell_factor, int | float): - # Use same factor for all batches + # Use same factor for all graphs cell_factor = torch.full( - (state.n_batches,), cell_factor, device=device, dtype=dtype + (state.n_graphs,), cell_factor, device=device, dtype=dtype ) - # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_batches, 1, 1) + # Reshape to (n_graphs, 1, 1) for broadcasting + cell_factor = cell_factor.view(n_graphs, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) - pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1) + pressure = pressure.unsqueeze(0).expand(n_graphs, -1, -1) # Get initial forces and energy from model model_output = model(state) - energy = model_output["energy"] # [n_batches] + energy = model_output["energy"] # [n_graphs] forces = model_output["forces"] # [n_total_atoms, 3] - stress = model_output["stress"] # [n_batches, 3, 3] + stress = model_output["stress"] # [n_graphs, 3, 3] - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_graphs, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(n_batches, -1, -1) + ).expand(n_graphs, -1, -1) if constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_graphs, -1, -1) cell_forces = virial / cell_factor - # Sum masses per batch using segment_reduce + # Sum masses per graph using segment_reduce # TODO (AG): check this - batch_counts = torch.bincount(state.batch) + graph_counts = torch.bincount(state.graph_idx) cell_masses = torch.segment_reduce( - state.masses, reduce="sum", lengths=batch_counts - ) # shape: (n_batches,) - cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_batches, 3) + state.masses, reduce="sum", lengths=graph_counts + ) # shape: (n_graphs,) + cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_graphs, 3) # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) + dt_start = torch.full((n_graphs,), dt_start, device=device, dtype=dtype) + alpha_start = torch.full((n_graphs,), alpha_start, device=device, dtype=dtype) + n_pos = torch.zeros((n_graphs,), device=device, dtype=torch.int32) return UnitCellFireState( # Create initial state # Copy SimState attributes @@ -860,14 +860,14 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - batch=state.batch.clone(), + graph_idx=state.graph_idx.clone(), pbc=state.pbc, velocities=None, forces=forces, energy=energy, stress=stress, # Cell attributes - cell_positions=torch.zeros(n_batches, 3, 3, device=device, dtype=dtype), + cell_positions=torch.zeros(n_graphs, 3, 3, device=device, dtype=dtype), cell_velocities=None, cell_forces=cell_forces, cell_masses=cell_masses, @@ -913,37 +913,37 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Inherited from SimState positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3] + cell (torch.Tensor): Unit cell vectors with shape [n_graphs, 3, 3] pbc (bool): Whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - batch (torch.Tensor): Batch indices with shape [n_atoms] + graph_idx (torch.Tensor): Graph indices with shape [n_atoms] # Additional atomic quantities forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Energy per batch with shape [n_batches] + energy (torch.Tensor): Energy per graph with shape [n_graphs] velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3] + stress (torch.Tensor): Stress tensor with shape [n_graphs, 3, 3] # Optimization-specific attributes - reference_cell (torch.Tensor): Original unit cell with shape [n_batches, 3, 3] + reference_cell (torch.Tensor): Original unit cell with shape [n_graphs, 3, 3] cell_factor (torch.Tensor): Scaling factor for cell optimization with shape - [n_batches, 1, 1] - pressure (torch.Tensor): Applied pressure tensor with shape [n_batches, 3, 3] + [n_graphs, 1, 1] + pressure (torch.Tensor): Applied pressure tensor with shape [n_graphs, 3, 3] hydrostatic_strain (bool): Whether to only allow hydrostatic deformation constant_volume (bool): Whether to maintain constant volume # Cell attributes using log parameterization cell_positions (torch.Tensor): Cell positions using log parameterization with - shape [n_batches, 3, 3] - cell_velocities (torch.Tensor): Cell velocities with shape [n_batches, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3] + shape [n_graphs, 3, 3] + cell_velocities (torch.Tensor): Cell velocities with shape [n_graphs, 3, 3] + cell_forces (torch.Tensor): Cell forces with shape [n_graphs, 3, 3] + cell_masses (torch.Tensor): Cell masses with shape [n_graphs, 3] # FIRE algorithm parameters - dt (torch.Tensor): Current timestep per batch with shape [n_batches] - alpha (torch.Tensor): Current mixing parameter per batch with shape [n_batches] - n_pos (torch.Tensor): Number of positive power steps per batch with shape - [n_batches] + dt (torch.Tensor): Current timestep per graph with shape [n_graphs] + alpha (torch.Tensor): Current mixing parameter per graph with shape [n_graphs] + n_pos (torch.Tensor): Number of positive power steps per graph with shape + [n_graphs] Properties: momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], @@ -1013,7 +1013,7 @@ def frechet_cell_fire( alpha_start (float): Initial velocity mixing parameter f_alpha (float): Factor for mixing parameter decrease cell_factor (float | None): Scaling factor for cell optimization. - If None, defaults to number of atoms per batch + If None, defaults to number of atoms per graph hydrostatic_strain (bool): Whether to only allow hydrostatic deformation (isotropic scaling) constant_volume (bool): Whether to maintain constant volume during optimization @@ -1067,11 +1067,11 @@ def fire_init( Args: state: Input state as SimState object or state parameter dict - cell_factor: Cell optimization scaling factor. If None, uses atoms per batch. - Single value or tensor of shape [n_batches]. + cell_factor: Cell optimization scaling factor. If None, uses atoms per graph. + Single value or tensor of shape [n_graphs]. scalar_pressure: Applied pressure in energy units - dt_start: Initial timestep per batch - alpha_start: Initial mixing parameter per batch + dt_start: Initial timestep per graph + alpha_start: Initial mixing parameter per graph Returns: FrechetCellFIREState with initialized optimization tensors @@ -1080,78 +1080,78 @@ def fire_init( state = SimState(**state) # Get dimensions - n_batches = state.n_batches + n_graphs = state.n_graphs # Setup cell_factor if cell_factor is None: - # Count atoms per batch - _, counts = torch.unique(state.batch, return_counts=True) + # Count atoms per graph + _, counts = torch.unique(state.graph_idx, return_counts=True) cell_factor = counts.to(dtype=dtype) if isinstance(cell_factor, int | float): - # Use same factor for all batches + # Use same factor for all graphs cell_factor = torch.full( - (state.n_batches,), cell_factor, device=device, dtype=dtype + (state.n_graphs,), cell_factor, device=device, dtype=dtype ) - # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_batches, 1, 1) + # Reshape to (n_graphs, 1, 1) for broadcasting + cell_factor = cell_factor.view(n_graphs, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) - pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1) + pressure = pressure.unsqueeze(0).expand(n_graphs, -1, -1) # Get initial forces and energy from model model_output = model(state) - energy = model_output["energy"] # [n_batches] + energy = model_output["energy"] # [n_graphs] forces = model_output["forces"] # [n_total_atoms, 3] - stress = model_output["stress"] # [n_batches, 3, 3] + stress = model_output["stress"] # [n_graphs, 3, 3] # Calculate initial cell positions using matrix logarithm # Calculate current deformation gradient (identity matrix at start) cur_deform_grad = DeformGradMixin._deform_grad( # noqa: SLF001 state.row_vector_cell, state.row_vector_cell - ) # shape: (n_batches, 3, 3) + ) # shape: (n_graphs, 3, 3) # For identity matrix, logm gives zero matrix # Initialize cell positions to zeros - cell_positions = torch.zeros((n_batches, 3, 3), device=device, dtype=dtype) + cell_positions = torch.zeros((n_graphs, 3, 3), device=device, dtype=dtype) # Calculate virial for cell forces - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_graphs, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( 0 - ).expand(n_batches, -1, -1) + ).expand(n_graphs, -1, -1) if constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_graphs, -1, -1) # Calculate UCF-style cell gradient ucf_cell_grad = torch.zeros_like(virial) - for b in range(n_batches): + for b in range(n_graphs): ucf_cell_grad[b] = virial[b] @ torch.linalg.inv(cur_deform_grad[b].T) # Calculate cell forces using Frechet derivative approach (all zeros for identity) cell_forces = ucf_cell_grad / cell_factor - # Sum masses per batch - batch_counts = torch.bincount(state.batch) + # Sum masses per graph + graph_counts = torch.bincount(state.graph_idx) cell_masses = torch.segment_reduce( - state.masses, reduce="sum", lengths=batch_counts - ) # shape: (n_batches,) - cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_batches, 3) + state.masses, reduce="sum", lengths=graph_counts + ) # shape: (n_graphs,) + cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_graphs, 3) # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) + dt_start = torch.full((n_graphs,), dt_start, device=device, dtype=dtype) + alpha_start = torch.full((n_graphs,), alpha_start, device=device, dtype=dtype) + n_pos = torch.zeros((n_graphs,), device=device, dtype=torch.int32) return FrechetCellFIREState( # Create initial state # Copy SimState attributes @@ -1159,7 +1159,7 @@ def fire_init( masses=state.masses, cell=state.cell, atomic_numbers=state.atomic_numbers, - batch=state.batch, + graph_idx=state.graph_idx, pbc=state.pbc, velocities=None, forces=forces, @@ -1239,7 +1239,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 Returns: Updated state after performing one VV-FIRE step. """ - n_batches = state.n_batches + n_graphs = state.n_graphs device = state.positions.device dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None @@ -1252,14 +1252,14 @@ def _vv_fire_step( # noqa: C901, PLR0915 f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) state.cell_velocities = torch.zeros( - (n_batches, 3, 3), device=device, dtype=dtype + (n_graphs, 3, 3), device=device, dtype=dtype ) - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, dtype=dtype + alpha_start_graph = torch.full( + (n_graphs,), alpha_start.item(), device=device, dtype=dtype ) - atom_wise_dt = state.dt[state.batch].unsqueeze(-1) + atom_wise_dt = state.dt[state.graph_idx].unsqueeze(-1) state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: @@ -1271,13 +1271,13 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + atom_wise_dt * state.velocities if is_cell_optimization: - cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) + cell_factor_reshaped = state.cell_factor.view(n_graphs, 1, 1) if is_frechet: if not isinstance(state, expected_cls := FrechetCellFIREState): raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") cur_deform_grad = state.deform_grad() deform_grad_log = torch.zeros_like(cur_deform_grad) - for b in range(n_batches): + for b in range(n_graphs): deform_grad_log[b] = tsm.matrix_log_33(cur_deform_grad[b]) cell_positions_log_scaled = deform_grad_log * cell_factor_reshaped @@ -1295,9 +1295,9 @@ def _vv_fire_step( # noqa: C901, PLR0915 if not isinstance(state, expected_cls := UnitCellFireState): raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") cur_deform_grad = state.deform_grad() - cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) + cell_factor_expanded = state.cell_factor.expand(n_graphs, 3, 1) current_cell_positions_scaled = ( - cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded + cur_deform_grad.view(n_graphs, 3, 3) * cell_factor_expanded ) cell_positions_scaled_new = ( @@ -1316,19 +1316,19 @@ def _vv_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_graphs, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_graphs, -1, -1) if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_graphs, -1, -1) if is_frechet: if not isinstance(state, expected_cls := FrechetCellFIREState): @@ -1341,7 +1341,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 directions[idx, mu, nu] = 1.0 new_cell_forces = torch.zeros_like(ucf_cell_grad) - for b in range(n_batches): + for b in range(n_graphs): expm_derivs = torch.stack( [ tsm.expm_frechet( @@ -1366,49 +1366,51 @@ def _vv_fire_step( # noqa: C901, PLR0915 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) - batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) + graph_power = tsm.batched_vdot(state.forces, state.velocities, state.graph_idx) if is_cell_optimization: - batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + graph_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # 2. Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch + pos_mask_graph = graph_power > 0.0 + neg_mask_graph = ~pos_mask_graph - state.n_pos[pos_mask_batch] += 1 - inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.n_pos[pos_mask_graph] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_graph state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) state.alpha[inc_mask] *= f_alpha - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 + state.dt[neg_mask_graph] *= f_dec + state.alpha[neg_mask_graph] = alpha_start_graph[neg_mask_graph] + state.n_pos[neg_mask_graph] = 0 - v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) - f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) + v_scaling_graph = tsm.batched_vdot( + state.velocities, state.velocities, state.graph_idx + ) + f_scaling_graph = tsm.batched_vdot(state.forces, state.forces, state.graph_idx) if is_cell_optimization: - v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_graph += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_graph += state.cell_forces.pow(2).sum(dim=(1, 2)) - v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_scaling_cell = torch.sqrt(v_scaling_graph.view(n_graphs, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_graph.view(n_graphs, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_batches, 1, 1) + alpha_cell_bc = state.alpha.view(n_graphs, 1, 1) state.cell_velocities = torch.where( - pos_mask_batch.view(n_batches, 1, 1), + pos_mask_graph.view(n_graphs, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_graph[state.graph_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_graph[state.graph_idx].unsqueeze(-1)) v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + alpha_atom = state.alpha[state.graph_idx].unsqueeze(-1) # per-atom alpha state.velocities = torch.where( - pos_mask_batch[state.batch].unsqueeze(-1), + pos_mask_graph[state.graph_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) @@ -1455,7 +1457,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 Updated state after performing one ASE-FIRE step. """ device, dtype = state.positions.device, state.positions.dtype - n_batches = state.n_batches + n_graphs = state.n_graphs cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError @@ -1468,92 +1470,92 @@ def _ase_fire_step( # noqa: C901, PLR0915 f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) state.cell_velocities = torch.zeros( - (n_batches, 3, 3), device=device, dtype=dtype + (n_graphs, 3, 3), device=device, dtype=dtype ) cur_deform_grad = state.deform_grad() else: - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, dtype=dtype + alpha_start_graph = torch.full( + (n_graphs,), alpha_start.item(), device=device, dtype=dtype ) if is_cell_optimization: cur_deform_grad = state.deform_grad() forces = torch.bmm( - state.forces.unsqueeze(1), cur_deform_grad[state.batch] + state.forces.unsqueeze(1), cur_deform_grad[state.graph_idx] ).squeeze(1) else: forces = state.forces - # 1. Current power (F·v) per batch (atoms + cell) - batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) + # 1. Current power (F·v) per graph (atoms + cell) + graph_power = tsm.batched_vdot(forces, state.velocities, state.graph_idx) if is_cell_optimization: - batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + graph_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # 2. Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch + pos_mask_graph = graph_power > 0.0 + neg_mask_graph = ~pos_mask_graph - inc_mask = (state.n_pos > n_min) & pos_mask_batch + inc_mask = (state.n_pos > n_min) & pos_mask_graph state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) state.alpha[inc_mask] *= f_alpha - state.n_pos[pos_mask_batch] += 1 + state.n_pos[pos_mask_graph] += 1 - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 + state.dt[neg_mask_graph] *= f_dec + state.alpha[neg_mask_graph] = alpha_start_graph[neg_mask_graph] + state.n_pos[neg_mask_graph] = 0 # 3. Velocity mixing BEFORE acceleration (ASE ordering) - v_scaling_batch = tsm.batched_vdot( - state.velocities, state.velocities, state.batch + v_scaling_graph = tsm.batched_vdot( + state.velocities, state.velocities, state.graph_idx ) - f_scaling_batch = tsm.batched_vdot(forces, forces, state.batch) + f_scaling_graph = tsm.batched_vdot(forces, forces, state.graph_idx) if is_cell_optimization: - v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_graph += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_graph += state.cell_forces.pow(2).sum(dim=(1, 2)) - v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_scaling_cell = torch.sqrt(v_scaling_graph.view(n_graphs, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_graph.view(n_graphs, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_batches, 1, 1) + alpha_cell_bc = state.alpha.view(n_graphs, 1, 1) state.cell_velocities = torch.where( - pos_mask_batch.view(n_batches, 1, 1), + pos_mask_graph.view(n_graphs, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_graph[state.graph_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_graph[state.graph_idx].unsqueeze(-1)) v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + alpha_atom = state.alpha[state.graph_idx].unsqueeze(-1) # per-atom alpha state.velocities = torch.where( - pos_mask_batch[state.batch].unsqueeze(-1), + pos_mask_graph[state.graph_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - state.velocities += forces * state.dt[state.batch].unsqueeze(-1) - dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) - dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) + state.velocities += forces * state.dt[state.graph_idx].unsqueeze(-1) + dr_atom = state.velocities * state.dt[state.graph_idx].unsqueeze(-1) + dr_scaling_graph = tsm.batched_vdot(dr_atom, dr_atom, state.graph_idx) if is_cell_optimization: - state.cell_velocities += state.cell_forces * state.dt.view(n_batches, 1, 1) - dr_cell = state.cell_velocities * state.dt.view(n_batches, 1, 1) + state.cell_velocities += state.cell_forces * state.dt.view(n_graphs, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(n_graphs, 1, 1) - dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2)) - dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) + dr_scaling_graph += dr_cell.pow(2).sum(dim=(1, 2)) + dr_scaling_cell = torch.sqrt(dr_scaling_graph).view(n_graphs, 1, 1) dr_cell = torch.where( dr_scaling_cell > max_step, max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) + dr_scaling_atom = torch.sqrt(dr_scaling_graph)[state.graph_idx].unsqueeze(-1) dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom @@ -1562,7 +1564,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.positions = ( torch.linalg.solve( - cur_deform_grad[state.batch], state.positions.unsqueeze(-1) + cur_deform_grad[state.graph_idx], state.positions.unsqueeze(-1) ).squeeze(-1) + dr_atom ) @@ -1580,7 +1582,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if not isinstance(state, expected_cls := UnitCellFireState): raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") F_current = state.deform_grad() - cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) + cell_factor_exp_mult = state.cell_factor.expand(n_graphs, 3, 1) current_F_scaled = F_current * cell_factor_exp_mult F_new_scaled = current_F_scaled + dr_cell @@ -1590,7 +1592,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.row_vector_cell = new_row_vector_cell state.positions = torch.bmm( - state.positions.unsqueeze(1), F_new[state.batch].mT + state.positions.unsqueeze(1), F_new[state.graph_idx].mT ).squeeze(1) else: state.positions = state.positions + dr_atom @@ -1602,7 +1604,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_graphs, 1, 1) if torch.any(volumes <= 0): bad_indices = torch.where(volumes <= 0)[0].tolist() print( # noqa: T201 @@ -1616,13 +1618,13 @@ def _ase_fire_step( # noqa: C901, PLR0915 diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_graphs, -1, -1) if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_batches, -1, -1) + ).unsqueeze(0).expand(n_graphs, -1, -1) if is_frechet: if not isinstance(state, expected_cls := FrechetCellFIREState): @@ -1645,7 +1647,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 directions[idx, mu, nu] = 1.0 new_cell_forces_log_space = torch.zeros_like(state.cell_forces) - for b_idx in range(n_batches): + for b_idx in range(n_graphs): expm_derivs = torch.stack( [ tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 2dbcc52b..7394b19e 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -24,7 +24,7 @@ def calc_kT( # noqa: N802 momenta: torch.Tensor, masses: torch.Tensor, velocities: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + graph_idx: torch.Tensor | None = None, ) -> torch.Tensor: """Calculate temperature in energy units from momenta/velocities and masses. @@ -32,7 +32,7 @@ def calc_kT( # noqa: N802 momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) - batch (torch.Tensor | None): Optional tensor indicating batch membership of + graph_idx (torch.Tensor | None): Optional tensor indicating graph membership of each particle Returns: @@ -51,29 +51,29 @@ def calc_kT( # noqa: N802 # If momentum provided, calculate v^2 = p^2/m^2 squared_term = (momenta**2) / masses.unsqueeze(-1) - if batch is None: + if graph_idx is None: # Count total degrees of freedom dof = count_dof(squared_term) return torch.sum(squared_term) / dof - # Sum squared terms for each batch + # Sum squared terms for each graph flattened_squared = torch.sum(squared_term, dim=-1) - # Count degrees of freedom per batch - batch_sizes = torch.bincount(batch) - dof_per_batch = batch_sizes * squared_term.shape[-1] # multiply by n_dimensions + # Count degrees of freedom per graph + graph_sizes = torch.bincount(graph_idx) + dof_per_graph = graph_sizes * squared_term.shape[-1] # multiply by n_dimensions - # Calculate temperature per batch - batch_sums = torch.segment_reduce( - flattened_squared, reduce="sum", lengths=batch_sizes + # Calculate temperature per graph + graph_sums = torch.segment_reduce( + flattened_squared, reduce="sum", lengths=graph_sizes ) - return batch_sums / dof_per_batch + return graph_sums / dof_per_graph def calc_temperature( momenta: torch.Tensor, masses: torch.Tensor, velocities: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + graph_idx: torch.Tensor | None = None, units: object = MetalUnits.temperature, ) -> torch.Tensor: """Calculate temperature from momenta/velocities and masses. @@ -82,14 +82,14 @@ def calc_temperature( momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) - batch (torch.Tensor | None): Optional tensor indicating batch membership of + graph_idx (torch.Tensor | None): Optional tensor indicating graph membership of each particle units (object): Units to return the temperature in Returns: torch.Tensor: Temperature value in specified units """ - return calc_kT(momenta, masses, velocities, batch) / units + return calc_kT(momenta, masses, velocities, graph_idx) / units # @torch.jit.script @@ -97,7 +97,7 @@ def calc_kinetic_energy( momenta: torch.Tensor, masses: torch.Tensor, velocities: torch.Tensor | None = None, - batch: torch.Tensor | None = None, + graph_idx: torch.Tensor | None = None, ) -> torch.Tensor: """Computes the kinetic energy of a system. @@ -105,12 +105,12 @@ def calc_kinetic_energy( momenta (torch.Tensor): Particle momenta, shape (n_particles, n_dim) masses (torch.Tensor): Particle masses, shape (n_particles,) velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) - batch (torch.Tensor | None): Optional tensor indicating batch membership of + graph_idx (torch.Tensor | None): Optional tensor indicating graph membership of each particle Returns: - If batch is None: Scalar tensor containing the total kinetic energy - If batch is provided: Tensor of kinetic energies per batch + If graph_idx is None: Scalar tensor containing the total kinetic energy + If graph_idx is provided: Tensor of kinetic energies per graph """ if momenta is not None and velocities is not None: raise ValueError("Must pass either momenta or velocities, not both") @@ -122,11 +122,11 @@ def calc_kinetic_energy( else: # Using momenta squared_term = (momenta**2) / masses.unsqueeze(-1) - if batch is None: + if graph_idx is None: return 0.5 * torch.sum(squared_term) flattened_squared = torch.sum(squared_term, dim=-1) return 0.5 * torch.segment_reduce( - flattened_squared, reduce="sum", lengths=torch.bincount(batch) + flattened_squared, reduce="sum", lengths=torch.bincount(graph_idx) ) @@ -142,18 +142,18 @@ def get_pressure( def batchwise_max_force(state: SimState) -> torch.Tensor: - """Compute the maximum force per batch. + """Compute the maximum force per graph. Args: - state (SimState): State to compute the maximum force per batch for. + state (SimState): State to compute the maximum force per graph for. Returns: - torch.Tensor: Maximum forces per batch + torch.Tensor: Maximum forces per graph """ - batch_wise_max_force = torch.zeros( - state.n_batches, device=state.device, dtype=state.dtype + graph_wise_max_force = torch.zeros( + state.n_graphs, device=state.device, dtype=state.dtype ) max_forces = state.forces.norm(dim=1) - return batch_wise_max_force.scatter_reduce( - dim=0, index=state.batch, src=max_forces, reduce="amax" + return graph_wise_max_force.scatter_reduce( + dim=0, index=state.graph_idx, src=max_forces, reduce="amax" ) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 291034b7..7013a2d3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -172,7 +172,7 @@ def integrate( pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Integrate") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs) + tqdm_pbar = tqdm(total=state.n_graphs, **pbar_kwargs) for state, batch_indices in batch_iterator: state = init_fn(state) @@ -194,7 +194,7 @@ def integrate( # finish the trajectory reporter final_states.append(state) if tqdm_pbar: - tqdm_pbar.update(state.n_batches) + tqdm_pbar.update(state.n_graphs) if trajectory_reporter: trajectory_reporter.finish() @@ -307,7 +307,7 @@ def convergence_fn( """Check if the system has converged. Returns: - torch.Tensor: Boolean tensor of shape (n_batches,) indicating + torch.Tensor: Boolean tensor of shape (n_graphs,) indicating convergence status for each batch. """ force_conv = batchwise_max_force(state) < force_tol @@ -343,7 +343,7 @@ def convergence_fn( """Check if the system has converged. Returns: - torch.Tensor: Boolean tensor of shape (n_batches,) indicating + torch.Tensor: Boolean tensor of shape (n_graphs,) indicating convergence status for each batch. """ return torch.abs(state.energy - last_energy) < energy_tol @@ -372,7 +372,7 @@ def optimize( # noqa: C901 model (ModelInterface): Neural network model module optimizer (Callable): Optimization algorithm function convergence_fn (Callable | None): Condition for convergence, should return a - boolean tensor of length n_batches + boolean tensor of length n_graphs optimizer_kwargs: Additional keyword arguments for optimizer init function trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking optimization trajectory. If a dict, will be passed to the @@ -434,7 +434,7 @@ def optimize( # noqa: C901 pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Optimize") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs) + tqdm_pbar = tqdm(total=state.n_graphs, **pbar_kwargs) while (result := autobatcher.next_batch(state, convergence_tensor))[0] is not None: state, converged_states, batch_indices = result @@ -545,7 +545,7 @@ class StaticState(type(state)): pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Static") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs) + tqdm_pbar = tqdm(total=state.n_graphs, **pbar_kwargs) for sub_state, batch_indices in batch_iterator: # set up trajectory reporters @@ -568,7 +568,7 @@ class StaticState(type(state)): all_props.extend(props) if tqdm_pbar: - tqdm_pbar.update(sub_state.n_batches) + tqdm_pbar.update(sub_state.n_graphs) trajectory_reporter.finish() diff --git a/torch_sim/state.py b/torch_sim/state.py index f01c3ca5..ac2de1f0 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -29,48 +29,49 @@ class SimState: Contains the fundamental properties needed to describe an atomistic system: positions, masses, unit cell, periodic boundary conditions, and atomic numbers. Supports batched operations where multiple atomistic systems can be processed - simultaneously, managed through batch indices. + simultaneously, managed through graph indices. States support slicing, cloning, splitting, popping, and movement to other data structures or devices. Slicing is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state containing only the first three - batches. The other operations are available through the `pop`, `split`, `clone`, + graphs. The other operations are available through the `pop`, `split`, `clone`, and `to` methods. Attributes: positions (torch.Tensor): Atomic positions with shape (n_atoms, 3) masses (torch.Tensor): Atomic masses with shape (n_atoms,) - cell (torch.Tensor): Unit cell vectors with shape (n_batches, 3, 3). + cell (torch.Tensor): Unit cell vectors with shape (n_graphs, 3, 3). Note that we use a column vector convention, i.e. the cell vectors are stored as `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]` as opposed to the row vector convention `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]` used by ASE. pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) - batch (torch.Tensor, optional): Batch indices with shape (n_atoms,), - defaults to None, must be unique consecutive integers starting from 0 + graph_idx (torch.Tensor, optional): Maps each atom index to its graph index. + Has shape (n_atoms,), defaults to None, must be unique consecutive + integers starting from 0 Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary conditions device (torch.device): Device of the positions tensor dtype (torch.dtype): Data type of the positions tensor - n_atoms (int): Total number of atoms across all batches - n_batches (int): Number of unique batches in the system + n_atoms (int): Total number of atoms across all graphs + n_graphs (int): Number of unique graphs in the system Notes: - positions, masses, and atomic_numbers must have shape (n_atoms, 3). - cell must be in the conventional matrix form. - - batch indices must be unique consecutive integers starting from 0. + - graph indices must be unique consecutive integers starting from 0. Examples: >>> state = initialize_state( ... [ase_atoms_1, ase_atoms_2, ase_atoms_3], device, dtype ... ) - >>> state.n_batches + >>> state.n_graphs 3 >>> new_state = state[[0, 1]] - >>> new_state.n_batches + >>> new_state.n_graphs 2 >>> cloned_state = state.clone() """ @@ -80,11 +81,11 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - batch: torch.Tensor | None = field(default=None, kw_only=True) + graph_idx: torch.Tensor | None = field(default=None, kw_only=True) def __post_init__(self) -> None: """Validate and process the state after initialization.""" - # data validation and fill batch + # data validation and fill graph_idx # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way devices = { @@ -106,23 +107,27 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if self.cell.ndim != 3 and self.batch is None: + if self.cell.ndim != 3 and self.graph_idx is None: self.cell = self.cell.unsqueeze(0) if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_batches, 3, 3)") + raise ValueError("Cell must have shape (n_graphs, 3, 3)") - if self.batch is None: - self.batch = torch.zeros(self.n_atoms, device=self.device, dtype=torch.int64) + if self.graph_idx is None: + self.graph_idx = torch.zeros( + self.n_atoms, device=self.device, dtype=torch.int64 + ) else: - # assert that batch indices are unique consecutive integers - _, counts = torch.unique_consecutive(self.batch, return_counts=True) - if not torch.all(counts == torch.bincount(self.batch)): - raise ValueError("Batch indices must be unique consecutive integers") - - if self.cell.shape[0] != self.n_batches: + # assert that graph indices are unique consecutive integers + # TODO(curtis): I feel like this logic is not reliable. + # I'll come up with something better later. + _, counts = torch.unique_consecutive(self.graph_idx, return_counts=True) + if not torch.all(counts == torch.bincount(self.graph_idx)): + raise ValueError("Graph indices must be unique consecutive integers") + + if self.cell.shape[0] != self.n_graphs: raise ValueError( - f"Cell must have shape (n_batches, 3, 3), got {self.cell.shape}" + f"Cell must have shape (n_graphs, 3, 3), got {self.cell.shape}" ) @property @@ -145,22 +150,58 @@ def dtype(self) -> torch.dtype: @property def n_atoms(self) -> int: - """Total number of atoms in the system across all batches.""" + """Total number of atoms in the system across all graphs.""" return self.positions.shape[0] @property - def n_atoms_per_batch(self) -> torch.Tensor: - """Number of atoms per batch.""" + def n_atoms_per_graph(self) -> torch.Tensor: + """Number of atoms per graph.""" return ( - self.batch.bincount() - if self.batch is not None + self.graph_idx.bincount() + if self.graph_idx is not None else torch.tensor([self.n_atoms], device=self.device) ) + @property + def n_atoms_per_batch(self) -> torch.Tensor: + """Number of atoms per batch. + + deprecated:: + Use :attr:`n_atoms_per_graph` instead. + """ + return self.n_atoms_per_graph + + @property + def batch(self) -> torch.Tensor | None: + """Graph indices. + + deprecated:: + Use :attr:`graph_idx` instead. + """ + return self.graph_idx + + @batch.setter + def batch(self, graph_idx: torch.Tensor) -> None: + """Set the graph indices from a batch index. + + deprecated:: + Use :attr:`graph_idx` instead. + """ + self.graph_idx = graph_idx + @property def n_batches(self) -> int: - """Number of batches in the system.""" - return torch.unique(self.batch).shape[0] + """Number of batches in the system. + + deprecated:: + Use :attr:`n_graphs` instead. + """ + return self.n_graphs + + @property + def n_graphs(self) -> int: + """Number of graphs in the system.""" + return torch.unique(self.graph_idx).shape[0] @property def volume(self) -> torch.Tensor: @@ -217,7 +258,7 @@ def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. Returns: - list[Atoms]: A list of ASE Atoms objects, one per batch + list[Atoms]: A list of ASE Atoms objects, one per graph """ return ts.io.state_to_atoms(self) @@ -225,7 +266,7 @@ def to_structures(self) -> list["Structure"]: """Convert the SimState to a list of pymatgen Structure objects. Returns: - list[Structure]: A list of pymatgen Structure objects, one per batch + list[Structure]: A list of pymatgen Structure objects, one per graph """ return ts.io.state_to_structures(self) @@ -233,43 +274,43 @@ def to_phonopy(self) -> list["PhonopyAtoms"]: """Convert the SimState to a list of PhonopyAtoms objects. Returns: - list[PhonopyAtoms]: A list of PhonopyAtoms objects, one per batch + list[PhonopyAtoms]: A list of PhonopyAtoms objects, one per graph """ return ts.io.state_to_phonopy(self) def split(self) -> list[Self]: - """Split the SimState into a list of single-batch SimStates. + """Split the SimState into a list of single-graph SimStates. - Divides the current state into separate states, each containing a single batch, - preserving all properties appropriately for each batch. + Divides the current state into separate states, each containing a single graph, + preserving all properties appropriately for each graph. Returns: - list[SimState]: A list of SimState objects, one per batch + list[SimState]: A list of SimState objects, one per graph """ return _split_state(self) - def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: - """Pop off states with the specified batch indices. + def pop(self, graph_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: + """Pop off states with the specified graph indices. This method modifies the original state object by removing the specified - batches and returns the removed batches as separate SimState objects. + graphs and returns the removed graphs as separate SimState objects. Args: - batch_indices (int | list[int] | slice | torch.Tensor): The batch indices + graph_indices (int | list[int] | slice | torch.Tensor): The graph indices to pop Returns: - list[SimState]: Popped SimState objects, one per batch index + list[SimState]: Popped SimState objects, one per graph index Notes: This method modifies the original SimState in-place. """ - batch_indices = _normalize_batch_indices( - batch_indices, self.n_batches, self.device + graph_indices = _normalize_graph_indices( + graph_indices, self.n_graphs, self.device ) # Get the modified state and popped states - modified_state, popped_states = _pop_states(self, batch_indices) + modified_state, popped_states = _pop_states(self, graph_indices) # Update all attributes of self with the modified state's attributes for attr_name, attr_value in vars(modified_state).items(): @@ -293,23 +334,23 @@ def to( """ return state_to_device(self, device, dtype) - def __getitem__(self, batch_indices: int | list[int] | slice | torch.Tensor) -> Self: + def __getitem__(self, graph_indices: int | list[int] | slice | torch.Tensor) -> Self: """Enable standard Python indexing syntax for slicing batches. Args: - batch_indices (int | list[int] | slice | torch.Tensor): The batch indices + graph_indices (int | list[int] | slice | torch.Tensor): The graph indices to include Returns: - SimState: A new SimState containing only the specified batches + SimState: A new SimState containing only the specified graphs """ # TODO: need to document that slicing is supported # Reuse the existing slice method - batch_indices = _normalize_batch_indices( - batch_indices, self.n_batches, self.device + graph_indices = _normalize_graph_indices( + graph_indices, self.n_graphs, self.device ) - return _slice_state(self, batch_indices) + return _slice_state(self, graph_indices) class DeformGradMixin: @@ -356,44 +397,44 @@ def deform_grad(self) -> torch.Tensor: return self._deform_grad(self.reference_row_vector_cell, self.row_vector_cell) -def _normalize_batch_indices( - batch_indices: int | list[int] | slice | torch.Tensor, - n_batches: int, +def _normalize_graph_indices( + graph_indices: int | list[int] | slice | torch.Tensor, + n_graphs: int, device: torch.device, ) -> torch.Tensor: - """Normalize batch indices to handle negative indices and different input types. + """Normalize graph indices to handle negative indices and different input types. - Converts various batch index representations to a consistent tensor format, + Converts various graph index representations to a consistent tensor format, handling negative indices in the Python style (counting from the end). Args: - batch_indices (int | list[int] | slice | torch.Tensor): The batch indices to + graph_indices (int | list[int] | slice | torch.Tensor): The graph indices to normalize - n_batches (int): Total number of batches in the system + n_graphs (int): Total number of graphs in the system device (torch.device): Device to place the output tensor on Returns: - torch.Tensor: Normalized batch indices as a tensor + torch.Tensor: Normalized graph indices as a tensor Raises: - TypeError: If batch_indices is of an unsupported type + TypeError: If graph_indices is of an unsupported type """ - if isinstance(batch_indices, int): + if isinstance(graph_indices, int): # Handle negative integer indexing - if batch_indices < 0: - batch_indices = n_batches + batch_indices - return torch.tensor([batch_indices], device=device) - if isinstance(batch_indices, list): + if graph_indices < 0: + graph_indices = n_graphs + graph_indices + return torch.tensor([graph_indices], device=device) + if isinstance(graph_indices, list): # Handle negative indices in lists - normalized = [idx if idx >= 0 else n_batches + idx for idx in batch_indices] + normalized = [idx if idx >= 0 else n_graphs + idx for idx in graph_indices] return torch.tensor(normalized, device=device) - if isinstance(batch_indices, slice): + if isinstance(graph_indices, slice): # Let PyTorch handle the slice conversion with negative indices - return torch.arange(n_batches, device=device)[batch_indices] - if isinstance(batch_indices, torch.Tensor): + return torch.arange(n_graphs, device=device)[graph_indices] + if isinstance(graph_indices, torch.Tensor): # Handle negative indices in tensors - return torch.where(batch_indices < 0, n_batches + batch_indices, batch_indices) - raise TypeError(f"Unsupported index type: {type(batch_indices)}") + return torch.where(graph_indices < 0, n_graphs + graph_indices, graph_indices) + raise TypeError(f"Unsupported index type: {type(graph_indices)}") def state_to_device( @@ -435,8 +476,8 @@ def state_to_device( def infer_property_scope( state: SimState, ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error", -) -> dict[Literal["global", "per_atom", "per_batch"], list[str]]: - """Infer whether a property is global, per-atom, or per-batch. +) -> dict[Literal["global", "per_atom", "per_graph"], list[str]]: + """Infer whether a property is global, per-atom, or per-graph. Analyzes the shapes of tensor attributes to determine their scope within the atomistic system representation. @@ -450,27 +491,27 @@ def infer_property_scope( - "globalize_warn": Treat ambiguous properties as global with a warning Returns: - dict[Literal["global", "per_atom", "per_batch"], list[str]]: Map of scope + dict[Literal["global", "per_atom", "per_graph"], list[str]]: Map of scope category to list of property names Raises: - ValueError: If n_atoms equals n_batches (making scope inference ambiguous) or + ValueError: If n_atoms equals n_graphs (making scope inference ambiguous) or if ambiguous_handling="error" and an ambiguous property is encountered """ # TODO: this cannot effectively resolve global properties with - # length of n_atoms or n_batches, they will be classified incorrectly, + # length of n_atoms or n_graphs, they will be classified incorrectly, # no clear fix - if state.n_atoms == state.n_batches: + if state.n_atoms == state.n_graphs: raise ValueError( - f"n_atoms ({state.n_atoms}) and n_batches ({state.n_batches}) are equal, " + f"n_atoms ({state.n_atoms}) and n_graphs ({state.n_graphs}) are equal, " "which means shapes cannot be inferred unambiguously." ) scope = { "global": [], "per_atom": [], - "per_batch": [], + "per_graph": [], } # Iterate through all attributes @@ -489,15 +530,15 @@ def infer_property_scope( # Vector/matrix with first dimension matching number of atoms elif shape[0] == state.n_atoms: scope["per_atom"].append(attr_name) - # Tensor with first dimension matching number of batches - elif shape[0] == state.n_batches: - scope["per_batch"].append(attr_name) + # Tensor with first dimension matching number of graphs + elif shape[0] == state.n_graphs: + scope["per_graph"].append(attr_name) # Any other shape is ambiguous elif ambiguous_handling == "error": raise ValueError( f"Cannot categorize property '{attr_name}' with shape {shape}. " f"Expected first dimension to be either {state.n_atoms} (per-atom) or " - f"{state.n_batches} (per-batch), or a scalar (global)." + f"{state.n_graphs} (per-graph), or a scalar (global)." ) elif ambiguous_handling in ("globalize", "globalize_warn"): scope["global"].append(attr_name) @@ -516,10 +557,10 @@ def infer_property_scope( def _get_property_attrs( state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error" ) -> dict[str, dict]: - """Get global, per-atom, and per-batch attributes from a state. + """Get global, per-atom, and per-graph attributes from a state. Categorizes all attributes of the state based on their scope - (global, per-atom, or per-batch). + (global, per-atom, or per-graph). Args: state (SimState): The state to extract attributes from @@ -527,12 +568,12 @@ def _get_property_attrs( properties Returns: - dict[str, dict]: Keys are 'global', 'per_atom', and 'per_batch', each + dict[str, dict]: Keys are 'global', 'per_atom', and 'per_graph', each containing a dictionary of attribute names to values """ scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) - attrs = {"global": {}, "per_atom": {}, "per_batch": {}} + attrs = {"global": {}, "per_atom": {}, "per_graph": {}} # Process global properties for attr_name in scope["global"]: @@ -542,9 +583,9 @@ def _get_property_attrs( for attr_name in scope["per_atom"]: attrs["per_atom"][attr_name] = getattr(state, attr_name) - # Process per-batch properties - for attr_name in scope["per_batch"]: - attrs["per_batch"][attr_name] = getattr(state, attr_name) + # Process per-graph properties + for attr_name in scope["per_graph"]: + attrs["per_graph"][attr_name] = getattr(state, attr_name) return attrs @@ -552,19 +593,19 @@ def _get_property_attrs( def _filter_attrs_by_mask( attrs: dict[str, dict], atom_mask: torch.Tensor, - batch_mask: torch.Tensor, + graph_mask: torch.Tensor, ) -> dict: - """Filter attributes by atom and batch masks. + """Filter attributes by atom and graph masks. - Selects subsets of attributes based on boolean masks for atoms and batches. + Selects subsets of attributes based on boolean masks for atoms and graphs. Args: - attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_batch', each + attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_graph', each containing a dictionary of attribute names to values atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) - batch_mask (torch.Tensor): Boolean mask for batches to include with shape - (n_batches,) + graph_mask (torch.Tensor): Boolean mask for graphs to include with shape + (n_graphs,) Returns: dict: Filtered attributes with appropriate handling for each scope @@ -576,31 +617,31 @@ def _filter_attrs_by_mask( # Filter per-atom attributes for attr_name, attr_value in attrs["per_atom"].items(): - if attr_name == "batch": - # Get the old batch indices for the selected atoms - old_batch = attr_value[atom_mask] + if attr_name == "graph_idx": + # Get the old graph indices for the selected atoms + old_graph_idxs = attr_value[atom_mask] - # Get the batch indices that are kept + # Get the graph indices that are kept kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[ - batch_mask + graph_mask ] - # Create a mapping from old batch indices to new consecutive indices - batch_map = {idx.item(): i for i, idx in enumerate(kept_indices)} + # Create a mapping from old graph indices to new consecutive indices + graph_idx_map = {idx.item(): i for i, idx in enumerate(kept_indices)} - # Create new batch tensor with remapped indices - new_batch = torch.tensor( - [batch_map[b.item()] for b in old_batch], + # Create new graph tensor with remapped indices + new_graph_idxs = torch.tensor( + [graph_idx_map[b.item()] for b in old_graph_idxs], device=attr_value.device, dtype=attr_value.dtype, ) - filtered_attrs[attr_name] = new_batch + filtered_attrs[attr_name] = new_graph_idxs else: filtered_attrs[attr_name] = attr_value[atom_mask] - # Filter per-batch attributes - for attr_name, attr_value in attrs["per_batch"].items(): - filtered_attrs[attr_name] = attr_value[batch_mask] + # Filter per-graph attributes + for attr_name, attr_value in attrs["per_graph"].items(): + filtered_attrs[attr_name] = attr_value[graph_mask] return filtered_attrs @@ -609,10 +650,10 @@ def _split_state( state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error", ) -> list[SimState]: - """Split a SimState into a list of states, each containing a single batch element. + """Split a SimState into a list of states, each containing a single graph. - Divides a multi-batch state into individual single-batch states, preserving - appropriate properties for each batch. + Divides a multi-graph state into individual single-graph states, preserving + appropriate properties for each graph. Args: state (SimState): The SimState to split @@ -623,37 +664,39 @@ def _split_state( Returns: list[SimState]: A list of SimState objects, each containing a single - batch element + graph """ attrs = _get_property_attrs(state, ambiguous_handling) - batch_sizes = torch.bincount(state.batch).tolist() + graph_sizes = torch.bincount(state.graph_idx).tolist() - # Split per-atom attributes by batch + # Split per-atom attributes by graph split_per_atom = {} for attr_name, attr_value in attrs["per_atom"].items(): - if attr_name == "batch": + if attr_name == "graph_idx": continue - split_per_atom[attr_name] = torch.split(attr_value, batch_sizes, dim=0) + split_per_atom[attr_name] = torch.split(attr_value, graph_sizes, dim=0) - # Split per-batch attributes into individual elements - split_per_batch = {} - for attr_name, attr_value in attrs["per_batch"].items(): - split_per_batch[attr_name] = torch.split(attr_value, 1, dim=0) + # Split per-graph attributes into individual elements + split_per_graph = {} + for attr_name, attr_value in attrs["per_graph"].items(): + split_per_graph[attr_name] = torch.split(attr_value, 1, dim=0) - # Create a state for each batch + # Create a state for each graph states = [] - for i in range(state.n_batches): - batch_attrs = { - # Create a batch tensor with all zeros for this batch - "batch": torch.zeros(batch_sizes[i], device=state.device, dtype=torch.int64), + for i in range(state.n_graphs): + graph_attrs = { + # Create a graph tensor with all zeros for this graph + "graph_idx": torch.zeros( + graph_sizes[i], device=state.device, dtype=torch.int64 + ), # Add the split per-atom attributes **{attr_name: split_per_atom[attr_name][i] for attr_name in split_per_atom}, - # Add the split per-batch attributes - **{attr_name: split_per_batch[attr_name][i] for attr_name in split_per_batch}, + # Add the split per-graph attributes + **{attr_name: split_per_graph[attr_name][i] for attr_name in split_per_graph}, # Add the global attributes **attrs["global"], } - states.append(type(state)(**batch_attrs)) + states.append(type(state)(**graph_attrs)) return states @@ -665,11 +708,11 @@ def _pop_states( ) -> tuple[SimState, list[SimState]]: """Pop off the states with the specified indices. - Extracts and removes the specified batch indices from the state. + Extracts and removes the specified graph indices from the state. Args: state (SimState): The SimState to modify - pop_indices (list[int] | torch.Tensor): The batch indices to extract and remove + pop_indices (list[int] | torch.Tensor): The graph indices to extract and remove ambiguous_handling ("error" | "globalize"): How to handle ambiguous properties. If "error", an error is raised if a property has ambiguous scope. If "globalize", properties with ambiguous scope are treated as @@ -677,8 +720,8 @@ def _pop_states( Returns: tuple[SimState, list[SimState]]: A tuple containing: - - The modified original state with specified batches removed - - A list of the extracted SimStates, one per popped batch + - The modified original state with specified graphs removed + - A list of the extracted SimStates, one per popped graph Notes: Unlike the pop method, this function does not modify the input state. @@ -691,17 +734,17 @@ def _pop_states( attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and batches to keep and pop - batch_range = torch.arange(state.n_batches, device=state.device) - pop_batch_mask = torch.isin(batch_range, pop_indices) - keep_batch_mask = ~pop_batch_mask + # Create masks for the atoms and graphs to keep and pop + graph_range = torch.arange(state.n_graphs, device=state.device) + pop_graph_mask = torch.isin(graph_range, pop_indices) + keep_graph_mask = ~pop_graph_mask - pop_atom_mask = torch.isin(state.batch, pop_indices) + pop_atom_mask = torch.isin(state.graph_idx, pop_indices) keep_atom_mask = ~pop_atom_mask # Filter attributes for keep and pop states - keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_batch_mask) - pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_batch_mask) + keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_graph_mask) + pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_graph_mask) # Create the keep state keep_state = type(state)(**keep_attrs) @@ -715,17 +758,17 @@ def _pop_states( def _slice_state( state: SimState, - batch_indices: list[int] | torch.Tensor, + graph_indices: list[int] | torch.Tensor, ambiguous_handling: Literal["error", "globalize"] = "error", ) -> SimState: - """Slice a substate from the SimState containing only the specified batch indices. + """Slice a substate from the SimState containing only the specified graph indices. - Creates a new SimState containing only the specified batches, preserving + Creates a new SimState containing only the specified graphs, preserving all relevant properties. Args: state (SimState): The state to slice - batch_indices (list[int] | torch.Tensor): Batch indices to include in the + graph_indices (list[int] | torch.Tensor): Graph indices to include in the sliced state ambiguous_handling ("error" | "globalize"): How to handle ambiguous properties. If "error", an error is raised if a property has ambiguous @@ -733,28 +776,28 @@ def _slice_state( global. Returns: - SimState: A new SimState object containing only the specified batches + SimState: A new SimState object containing only the specified graphs Raises: - ValueError: If batch_indices is empty + ValueError: If graph_indices is empty """ - if isinstance(batch_indices, list): - batch_indices = torch.tensor( - batch_indices, device=state.device, dtype=torch.int64 + if isinstance(graph_indices, list): + graph_indices = torch.tensor( + graph_indices, device=state.device, dtype=torch.int64 ) - if len(batch_indices) == 0: - raise ValueError("batch_indices cannot be empty") + if len(graph_indices) == 0: + raise ValueError("graph_indices cannot be empty") attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and batches to include - batch_range = torch.arange(state.n_batches, device=state.device) - batch_mask = torch.isin(batch_range, batch_indices) - atom_mask = torch.isin(state.batch, batch_indices) + # Create masks for the atoms and graphs to include + graph_range = torch.arange(state.n_graphs, device=state.device) + graph_mask = torch.isin(graph_range, graph_indices) + atom_mask = torch.isin(state.graph_idx, graph_indices) # Filter attributes - filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, batch_mask) + filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, graph_mask) # Create the sliced state return type(state)(**filtered_attrs) @@ -765,8 +808,8 @@ def concatenate_states( ) -> SimState: """Concatenate a list of SimStates into a single SimState. - Combines multiple states into a single state with multiple batches. - Global properties are taken from the first state, and per-atom and per-batch + Combines multiple states into a single state with multiple graphs. + Global properties are taken from the first state, and per-atom and per-graph properties are concatenated. Args: @@ -775,7 +818,7 @@ def concatenate_states( Defaults to the device of the first state. Returns: - SimState: A new SimState containing all input states as separate batches + SimState: A new SimState containing all input states as separate graphs Raises: ValueError: If states is empty @@ -796,20 +839,20 @@ def concatenate_states( target_device = device or first_state.device # Get property scopes from the first state to identify - # global/per-atom/per-batch properties + # global/per-atom/per-graph properties first_scope = infer_property_scope(first_state) global_props = set(first_scope["global"]) per_atom_props = set(first_scope["per_atom"]) - per_batch_props = set(first_scope["per_batch"]) + per_graph_props = set(first_scope["per_graph"]) # Initialize result with global properties from first state concatenated = {prop: getattr(first_state, prop) for prop in global_props} # Pre-allocate lists for tensors to concatenate per_atom_tensors = {prop: [] for prop in per_atom_props} - per_batch_tensors = {prop: [] for prop in per_batch_props} - new_batch_indices = [] - batch_offset = 0 + per_graph_tensors = {prop: [] for prop in per_graph_props} + new_graph_indices = [] + graph_offset = 0 # Process all states in a single pass for state in states: @@ -822,28 +865,28 @@ def concatenate_states( # if hasattr(state, prop): per_atom_tensors[prop].append(getattr(state, prop)) - # Collect per-batch properties - for prop in per_batch_props: + # Collect per-graph properties + for prop in per_graph_props: # if hasattr(state, prop): - per_batch_tensors[prop].append(getattr(state, prop)) + per_graph_tensors[prop].append(getattr(state, prop)) - # Update batch indices - num_batches = state.n_batches - new_indices = state.batch + batch_offset - new_batch_indices.append(new_indices) - batch_offset += num_batches + # Update graph indices + num_graphs = state.n_graphs + new_indices = state.graph_idx + graph_offset + new_graph_indices.append(new_indices) + graph_offset += num_graphs # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) - for prop, tensors in per_batch_tensors.items(): + for prop, tensors in per_graph_tensors.items(): # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) - # Concatenate batch indices - concatenated["batch"] = torch.cat(new_batch_indices) + # Concatenate graph indices + concatenated["graph_idx"] = torch.cat(new_graph_indices) # Create a new instance of the same class return state_class(**concatenated) @@ -877,10 +920,10 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(state.n_batches == 1 for state in system): + if not all(state.n_graphs == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, " - "all states must have n_batches == 1. To fix this, you can split the " + "all states must have n_graphs == 1. To fix this, you can split the " "states into individual states with the split_state function." ) return concatenate_states(system) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index bc24de63..57a7e8a1 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -208,11 +208,11 @@ def report( """Report a state and step to the trajectory files. Writes states and calculated properties to all trajectory files at the - specified frequencies. Splits multi-batch states across separate trajectory - files. The number of batches must match the number of trajectory files. + specified frequencies. Splits multi-graph states across separate trajectory + files. The number of graphs must match the number of trajectory files. Args: - state (SimState): Current system state with n_batches equal to + state (SimState): Current system state with n_graphs equal to len(filenames) step (int): Current simulation step, setting step to 0 will write the state and all properties. @@ -224,27 +224,27 @@ def report( are being collected separately. Returns: - list[dict[str, torch.Tensor]]: Map of property names to tensors for each batch + list[dict[str, torch.Tensor]]: Map of property names to tensors for each graph Raises: - ValueError: If number of batches doesn't match number of trajectory files + ValueError: If number of graphs doesn't match number of trajectory files """ - # Get unique batch indices - batch_indices = range(state.n_batches) - # batch_indices = torch.unique(state.batch).cpu().tolist() + # Get unique graph indices + graph_indices = range(state.n_graphs) + # graph_indices = torch.unique(state.graph_idx).cpu().tolist() # Ensure we have the right number of trajectories - if self.filenames is not None and len(batch_indices) != len(self.trajectories): + if self.filenames is not None and len(graph_indices) != len(self.trajectories): raise ValueError( - f"Number of batches ({len(batch_indices)}) doesn't match " + f"Number of graphs ({len(graph_indices)}) doesn't match " f"number of trajectory files ({len(self.trajectories)})" ) split_states = state.split() all_props: list[dict[str, torch.Tensor]] = [] - # Process each batch separately + # Process each graph separately for idx, substate in enumerate(split_states): - # Slice the state once to get only the data for this batch + # Slice the state once to get only the data for this graph self.shape_warned = True # Write state to trajectory if it's time @@ -256,7 +256,7 @@ def report( self.trajectories[idx].write_state(substate, step, **self.state_kwargs) all_state_props = {} - # Process property calculators for this batch + # Process property calculators for this graph for report_frequency, calculators in self.prop_calculators.items(): if step % report_frequency != 0 or report_frequency == 0: continue @@ -672,7 +672,7 @@ def write_state( # noqa: C901 self, state: SimState | list[SimState], steps: int | list[int], - batch_index: int | None = None, + graph_index: int | None = None, *, save_velocities: bool = False, save_forces: bool = False, @@ -692,7 +692,7 @@ def write_state( # noqa: C901 Args: state (SimState | list[SimState]): SimState or list of SimStates to write steps (int | list[int]): Step number(s) for the frame(s) - batch_index (int, optional): Batch index to save. + graph_index (int, optional): Graph index to save. save_velocities (bool, optional): Whether to save velocities. save_forces (bool, optional): Whether to save forces. variable_cell (bool, optional): Whether the cell varies between frames. @@ -712,16 +712,14 @@ def write_state( # noqa: C901 if isinstance(steps, int): steps = [steps] - if isinstance(batch_index, int): - batch_index = [batch_index] - sub_states = [state[batch_index] for state in state] - elif batch_index is None and torch.unique(state[0].batch) == 0: - batch_index = 0 + if isinstance(graph_index, int): + graph_index = [graph_index] + sub_states = [state[graph_index] for state in state] + elif graph_index is None and torch.unique(state[0].graph_idx) == 0: + graph_index = 0 sub_states = state else: - raise ValueError( - "Batch index must be specified if there are multiple batches" - ) + raise ValueError("Graph index must be specified if there are multiple graphs") if len(sub_states) != len(steps): raise ValueError(f"{len(sub_states)=} must match the {len(steps)=}") diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 44d27181..1d3eaeb9 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -23,9 +23,9 @@ def get_fractional_coordinates( Args: positions (torch.Tensor): Atomic positions in Cartesian coordinates. - Shape: [..., 3] where ... represents optional batch dimensions. + Shape: [..., 3] where ... represents optional graph dimensions. cell (torch.Tensor): Unit cell matrix with lattice vectors as rows. - Shape: [..., 3, 3] where ... matches positions' batch dimensions. + Shape: [..., 3, 3] where ... matches positions' graph dimensions. Returns: torch.Tensor: Atomic positions in fractional coordinates with same shape as input @@ -42,21 +42,21 @@ def get_fractional_coordinates( """ if cell.ndim == 3: # Handle batched cell tensors # For batched cells, we need to determine if this is: - # 1. A single batch (n_batches=1) - can be squeezed and handled normally - # 2. Multiple batches - need proper batch handling + # 1. A single graph (n_graphs=1) - can be squeezed and handled normally + # 2. Multiple graphs - need proper graph handling if cell.shape[0] == 1: - # Single batch case - squeeze and use the 2D implementation + # Single graph case - squeeze and use the 2D implementation cell_2d = cell.squeeze(0) # Remove batch dimension return torch.linalg.solve(cell_2d.mT, positions.mT).mT - # Multiple batches case - this would require batch indices to know which - # atoms belong to which batch. For now, this is not implemented. + # Multiple graphs case - this would require graph indices to know which + # atoms belong to which graph. For now, this is not implemented. raise NotImplementedError( - f"Multiple batched cell tensors with shape {cell.shape} are not yet " - "supported in get_fractional_coordinates. For multiple batch systems, " - "you need to provide batch indices to determine which atoms belong to " - "which batch. For single batch systems, consider squeezing the batch " - "dimension or using individual calls per batch." + f"Multiple graph cell tensors with shape {cell.shape} are not yet " + "supported in get_fractional_coordinates. For multiple graph systems, " + "you need to provide graph indices to determine which atoms belong to " + "which graph. For single graph systems, consider squeezing the batch " + "dimension or using individual calls per graph." ) # Original case for 2D cell matrix @@ -155,20 +155,20 @@ def pbc_wrap_general( def pbc_wrap_batched( - positions: torch.Tensor, cell: torch.Tensor, batch: torch.Tensor + positions: torch.Tensor, cell: torch.Tensor, graph_idx: torch.Tensor ) -> torch.Tensor: """Apply periodic boundary conditions to batched systems. This function handles wrapping positions for multiple atomistic systems - (batches) in one operation. It uses the batch indices to determine which + (graphs) in one operation. It uses the graph indices to determine which atoms belong to which system and applies the appropriate cell vectors. Args: positions (torch.Tensor): Tensor of shape (n_atoms, 3) containing particle positions in real space. - cell (torch.Tensor): Tensor of shape (n_batches, 3, 3) containing + cell (torch.Tensor): Tensor of shape (n_graphs, 3, 3) containing lattice vectors as column vectors. - batch (torch.Tensor): Tensor of shape (n_atoms,) containing batch + graph_idx (torch.Tensor): Tensor of shape (n_atoms,) containing graph indices for each atom. Returns: @@ -182,33 +182,33 @@ def pbc_wrap_batched( if positions.shape[-1] != cell.shape[-1]: raise ValueError("Position dimensionality must match lattice vectors.") - # Get unique batch indices and counts - unique_batches = torch.unique(batch) - n_batches = len(unique_batches) + # Get unique graph indices and counts + unique_graphs = torch.unique(graph_idx) + n_graphs = len(unique_graphs) - if n_batches != cell.shape[0]: + if n_graphs != cell.shape[0]: raise ValueError( - f"Number of unique batches ({n_batches}) doesn't " + f"Number of unique graphs ({n_graphs}) doesn't " f"match number of cells ({cell.shape[0]})" ) # Efficient approach without explicit loops - # Get the cell for each atom based on its batch index - B = torch.linalg.inv(cell) # Shape: (n_batches, 3, 3) - B_per_atom = B[batch] # Shape: (n_atoms, 3, 3) + # Get the cell for each atom based on its graph index + B = torch.linalg.inv(cell) # Shape: (n_graphs, 3, 3) + B_per_atom = B[graph_idx] # Shape: (n_atoms, 3, 3) # Transform to fractional coordinates: f = B·r - # For each atom, multiply its position by its batch's inverse cell matrix + # For each atom, multiply its position by its graph's inverse cell matrix frac_coords = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2) # Wrap to reference cell [0,1) using modulo wrapped_frac = frac_coords % 1.0 # Transform back to real space: r = A·f - # Get the cell for each atom based on its batch index - cell_per_atom = cell[batch] # Shape: (n_atoms, 3, 3) + # Get the cell for each atom based on its graph index + cell_per_atom = cell[graph_idx] # Shape: (n_atoms, 3, 3) - # For each atom, multiply its wrapped fractional coords by its batch's cell matrix + # For each atom, multiply its wrapped fractional coords by its graph's cell matrix return torch.bmm(cell_per_atom, wrapped_frac.unsqueeze(2)).squeeze(2) @@ -535,7 +535,7 @@ def compute_distances_with_cell_shifts( def compute_cell_shifts( - cell: torch.Tensor, shifts_idx: torch.Tensor, batch_mapping: torch.Tensor + cell: torch.Tensor, shifts_idx: torch.Tensor, graph_mapping: torch.Tensor ) -> torch.Tensor: """Compute the cell shifts based on the provided indices and cell matrix. @@ -547,18 +547,18 @@ def compute_cell_shifts( representing the unit cell matrices. shifts_idx (torch.Tensor): A tensor of shape (n_shifts, 3) representing the indices for shifts. - batch_mapping (torch.Tensor): A tensor of shape (n_batches,) + graph_mapping (torch.Tensor): A tensor of shape (n_graphs,) that maps the shifts to the corresponding cells. Returns: - torch.Tensor: A tensor of shape (n_batches, 3) containing + torch.Tensor: A tensor of shape (n_graphs, 3) containing the computed cell shifts. """ if cell is None: cell_shifts = None else: cell_shifts = torch.einsum( - "jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[batch_mapping] + "jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[graph_mapping] ) return cell_shifts @@ -625,7 +625,7 @@ def build_naive_neighborhood( This function computes a neighborhood list of atoms within a specified cutoff distance, considering periodic boundary conditions defined by the unit cell. It returns the mapping of atom pairs, - the batch mapping for each structure, and the corresponding shifts. + the graph mapping for each structure, and the corresponding shifts. Args: positions (torch.Tensor): A tensor of shape (n_atoms, 3) @@ -645,7 +645,7 @@ def build_naive_neighborhood( tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: - mapping (torch.Tensor): A tensor of shape (n_pairs, 2) representing the pairs of indices for neighboring atoms. - - batch_mapping (torch.Tensor): A tensor of shape (n_pairs,) + - graph_mapping (torch.Tensor): A tensor of shape (n_pairs,) indicating the structure index for each pair. - shifts_idx (torch.Tensor): A tensor of shape (n_pairs, 3) representing the shifts applied for periodic boundary @@ -659,7 +659,7 @@ def build_naive_neighborhood( stride = strides_of(n_atoms) ids = torch.arange(positions.shape[0], device=device, dtype=torch.long) - mapping, batch_mapping, shifts_idx_ = [], [], [] + mapping, graph_mapping, shifts_idx_ = [], [], [] for i_structure in range(n_atoms.shape[0]): num_repeats = num_repeats_[i_structure] shifts_idx = get_cell_shift_idx(num_repeats, dtype) @@ -669,7 +669,7 @@ def build_naive_neighborhood( i_ids=i_ids, shifts_idx=shifts_idx, self_interaction=self_interaction ) mapping.append(s_mapping) - batch_mapping.append( + graph_mapping.append( torch.full( (s_mapping.shape[0],), i_structure, @@ -680,7 +680,7 @@ def build_naive_neighborhood( shifts_idx_.append(shifts_idx) return ( torch.cat(mapping, dim=0).t(), - torch.cat(batch_mapping, dim=0), + torch.cat(graph_mapping, dim=0), torch.cat(shifts_idx_, dim=0), ) @@ -998,7 +998,7 @@ def build_linked_cell_neighborhood( - mapping (torch.Tensor): A tensor containing pairs of indices where mapping[0] represents the central atom indices and mapping[1] represents their corresponding neighbor indices. - - batch_mapping (torch.Tensor): A tensor containing the structure indices + - graph_mapping (torch.Tensor): A tensor containing the structure indices corresponding to each neighbor atom. - cell_shifts_idx (torch.Tensor): A tensor containing the cell shift indices for each neighbor atom, which are necessary for @@ -1014,7 +1014,7 @@ def build_linked_cell_neighborhood( stride = strides_of(n_atoms) - mapping, batch_mapping, cell_shifts_idx = [], [], [] + mapping, graph_mapping, cell_shifts_idx = [], [], [] for i_structure in range(n_structure): # Compute the neighborhood with the linked cell algorithm neigh_atom, neigh_shift_idx = linked_cell( @@ -1025,7 +1025,7 @@ def build_linked_cell_neighborhood( self_interaction, ) - batch_mapping.append( + graph_mapping.append( i_structure * torch.ones(neigh_atom.shape[1], dtype=torch.long, device=device) ) # Shift the mapping indices to access positions @@ -1034,7 +1034,7 @@ def build_linked_cell_neighborhood( return ( torch.cat(mapping, dim=1), - torch.cat(batch_mapping, dim=0), + torch.cat(graph_mapping, dim=0), torch.cat(cell_shifts_idx, dim=0), ) diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 13f0db94..79da1a73 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -15,7 +15,7 @@ MemoryScaling = Literal["n_atoms_x_density", "n_atoms"] -StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "batch"] +StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "graph_idx"] StateDict = dict[StateKey, torch.Tensor] SimStateVar = TypeVar("SimStateVar", bound="SimState") diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 140c0031..f4006378 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -730,9 +730,9 @@ def get_unit_cell_relaxed_structure( device, dtype = model.device, model.dtype logger = { - "energy": torch.zeros((max_iter, state.n_batches), device=device, dtype=dtype), + "energy": torch.zeros((max_iter, state.n_graphs), device=device, dtype=dtype), "stress": torch.zeros( - (max_iter, state.n_batches, 3, 3), device=device, dtype=dtype + (max_iter, state.n_graphs, 3, 3), device=device, dtype=dtype ), } From be83288838fcf58bfe1216004587591d60343bde Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Mon, 7 Jul 2025 16:16:00 -0700 Subject: [PATCH 2/2] code rabbit reviewgs --- torch_sim/models/morse.py | 2 +- torch_sim/models/particle_life.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index b55d16f6..1ff73e29 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -380,7 +380,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) if state.graph_idx is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + raise ValueError("graph_idx can only be inferred if there is only one graph.") outputs = [self.unbatched_forward(state[i]) for i in range(state.n_graphs)] properties = outputs[0] diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 30b0d61d..62b91050 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -240,7 +240,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) if state.graph_idx is None and state.cell.shape[0] > 1: - raise ValueError("Batch can only be inferred for batch size 1.") + raise ValueError("graph_idx can only be inferred if there is only one graph.") outputs = [self.unbatched_forward(state[i]) for i in range(state.n_graphs)] properties = outputs[0]