
Conversation

@t-reents (Contributor) commented Jul 17, 2025

The velocities and cell_velocities are initialized to None in the (FrechetCell)FIREState. However, when using the InFlightAutoBatcher during an optimization, the current and new states are concatenated in torch_sim.state.concatenate_states.

https://github.com/Radical-AI/torch-sim/blob/317985c731170aad578673ebe69a9334f5abe5be/torch_sim/state.py#L902-L909

When trying to merge states that were already processed for a few iterations (i.e., velocities are not None anymore) with newly initialized ones, an error is raised because the code tries to concatenate a Tensor with None.

Here, I initialize the (cell_)velocities as tensors filled with NaN instead, so that already processed and newly initialized states can be merged. During the first step, the fire methods look for NaN rows and replace them with zeros.
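
To illustrate the idea, here is a minimal sketch of the pattern, assuming per-atom velocities with the same shape as positions (the actual code in torch_sim/optimizers.py differs in detail):

```python
import torch

def init_velocities_placeholder(positions: torch.Tensor) -> torch.Tensor:
    # NaN-filled tensor with the same shape/dtype/device as positions; unlike
    # None, it can be concatenated with velocities of already-processed states.
    return torch.full_like(positions, torch.nan)

def zero_out_first_step(velocities: torch.Tensor) -> torch.Tensor:
    # Rows that are still NaN belong to systems at their first step;
    # replace them with zeros before applying the FIRE update.
    nan_rows = torch.isnan(velocities).any(dim=1)
    velocities = velocities.clone()
    velocities[nan_rows] = 0.0
    return velocities
```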

The error would also have been caught by the existing test_autobatching test if the number of iterations between swaps had been set to a smaller value (I changed it here). The reason is simply that, within the current 10 iterations, all states converge, so a completely new set of states is selected in the next iteration and already processed and newly initialized states never get merged.

You can either change the test as done here or use the following code snippet to reproduce the error with the current version:

Code to reproduce
import torch
import torch_sim as ts
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp
from torch_sim.autobatching import calculate_memory_scaler, estimate_max_memory_scaler
from torch_sim.models.mace import MaceModel


si_atoms = bulk("Si", "fcc", a=3.26, cubic=True)
si_atoms.rattle(0.05)

cu_atoms = bulk("Cu", "fcc", a=5.26, cubic=True)
cu_atoms.rattle(0.5)

# mix of weakly and strongly rattled structures, so systems converge at different times
many_cu_atoms = [si_atoms] * 5 + [cu_atoms] * 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64)

mace = mace_mp(model="small", return_raw_model=True)
mace_model = MaceModel(model=mace, device=device)

fire_init, fire_update = ts.optimizers.fire(mace_model)
fire_state = fire_init(state)

batcher = ts.InFlightAutoBatcher(
    model=mace_model,
    memory_scales_with="n_atoms",
    max_memory_scaler=40,
    max_iterations=10000,  # Optional: maximum convergence attempts per state
)

batcher.load_states(fire_state)

convergence_fn = ts.generate_force_convergence_fn(5e-3, include_cell_forces=False)

all_converged_states, convergence_tensor = [], None
while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None:
    fire_state, converged_states = result
    all_converged_states.extend(converged_states)

    for _ in range(3):
        fire_state = fire_update(fire_state)

    convergence_tensor = convergence_fn(fire_state, None)
    print(f"Convergence tensor: {convergence_tensor}")
    print(f"Convergence tensor: {batcher.current_idx}")

else:
    all_converged_states.extend(result[1])

final_states = batcher.restore_original_order(all_converged_states)

Summary by CodeRabbit

  • Bug Fixes

    • Improved handling of velocity initialization to ensure more robust and consistent behavior during optimization steps.
  • Tests

    • Updated a test to reduce the number of simulation state updates for improved efficiency.

@cla-bot added the cla-signed (Contributor license agreement signed) label on Jul 17, 2025
@coderabbitai bot commented Jul 17, 2025

Walkthrough

The updates modify the initialization and handling of velocity tensors in optimizer state classes by using NaN-filled tensors instead of None, and update corresponding checks to detect NaNs. Additionally, a test was adjusted to reduce the number of simulation iterations. No public APIs or function signatures were changed.

Changes

| File(s) | Change Summary |
| --- | --- |
| torch_sim/optimizers.py | Velocity and cell velocity initialization changed from None to NaN tensors; checks updated to detect NaNs and initialize zeros where needed. |
| tests/test_autobatching.py | Reduced simulation state update loop iterations in test_in_flight_with_fire from 10 to 5. |

Poem

A hop, a skip, a NaN or two,
Now velocities know just what to do!
No more None to trip or stall,
Just check for NaNs and fix them all.
With fewer loops, our tests take flight—
The code hops forward, feeling light!
🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 317985c and 059e57d.

📒 Files selected for processing (2)
  • tests/test_autobatching.py (1 hunks)
  • torch_sim/optimizers.py (5 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (37 checks).
🔇 Additional comments (8)
torch_sim/optimizers.py (7)

592-594: LGTM: Consistent NaN initialization for velocities

The change from None to torch.nan initialization maintains semantic equivalence while enabling proper tensor concatenation. The shape and dtype match the positions tensor appropriately.


867-869: LGTM: Consistent NaN initialization for atomic velocities

The initialization pattern matches the FireState implementation, using NaN-filled tensors instead of None for proper concatenation support.


875-877: LGTM: Consistent NaN initialization for cell velocities

The cell velocities are properly initialized with NaN values matching the cell_forces shape, maintaining consistency with the atomic velocities initialization pattern.


1170-1172: LGTM: Consistent NaN initialization in FrechetCellFIREState

The atomic velocities initialization follows the same pattern as other state classes, using NaN-filled tensors matching the positions shape.


1178-1180: LGTM: Consistent NaN initialization for Frechet cell velocities

The cell velocities are properly initialized with NaN values matching the cell_forces shape, maintaining consistency across all FIRE state implementations.


1257-1270: LGTM: Proper NaN detection and zero initialization

The logic correctly:

  • Detects NaN values in velocities using isnan().any(dim=1)
  • Initializes NaN rows to zero tensors with matching shapes
  • Handles both atomic and cell velocities consistently
  • Maintains the same semantic behavior as the previous None checks
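
As a rough, self-contained illustration of the same pattern for per-system cell velocities (shapes are assumed here; this is not the literal code from optimizers.py):

```python
import torch

cell_velocities = torch.full((4, 3, 3), torch.nan)  # four freshly initialized systems
cell_velocities[:2] = 0.1                            # pretend two systems were already updated

# Flag systems whose cell velocities are still the NaN placeholder and zero them
nan_systems = torch.isnan(cell_velocities.reshape(len(cell_velocities), -1)).any(dim=1)
cell_velocities[nan_systems] = 0.0
```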

1478-1492: LGTM: Consistent NaN detection in ASE FIRE step

The NaN detection and initialization logic mirrors the velocity-verlet implementation:

  • Proper detection of NaN values in both velocities and cell_velocities
  • Zero initialization for NaN rows
  • Consistent handling across atomic and cell optimization modes
tests/test_autobatching.py (1)

493-493: LGTM: Reduced iterations to expose concatenation issue

The reduction from 10 to 5 iterations aligns with the PR objective of exposing the concatenation issue that occurs when merging states at different processing stages. This change ensures the test validates the NaN-based velocity initialization fix.

@kianpu34593 (Contributor)
I just encountered the same issue. This fix is what I was trying to do as well. Glad you jumped on it. I added an improvement that is related to this ( #222 ).

@curtischong (Collaborator)
Thank you for the contribution! I'm personally not too familiar with this part of the code, but I'll bring this up with the rest of the team.

@curtischong (Collaborator)
I read your code and I think it works, but I'm slightly hesitant to merge it in because it overrides the meaning of NaN (and makes it fail silently if a user were to inadvertently pass in NaN tensors). I feel like the real problem is that we didn't properly architect the code to handle these cases, which I'm trying to fix in my free time.

@orionarcher (Collaborator) left a comment

Thanks for the contribution @t-reents, good catch and nice change on the test. It seems like it would be simpler to just change the fire_init to initialize to zeros instead of nan since that's what's being set in the _vv_fire_step. Do the different step functions initialize things differently? Am I missing something here?

Comment on lines -865 to +869:

-    velocities=None,
+    velocities=torch.full(
+        state.positions.shape, torch.nan, device=device, dtype=dtype
+    ),

Why not just initialize to zeros?

@t-reents (Contributor, Author)
TL;DR:

  • My implementation needs to be adapted anyway, but we should agree on how to do it first
  • We have to rethink the logic of the ASE fire step when dealing with states that get merged (i.e., systems are at different iterations) --> systems at iteration 0 have to be treated differently
  • Agree on a way to indicate iteration 0

Thanks for your comments!

I'm slightly hesitant to merge it in because it overrides the meaning of NaN (and makes it fail silently if a user were to inadvertently pass in NaN tensors

I totally agree with this, and I also wasn't 100% happy with it when implementing it (however, it seemed to be the best option within the current setup). I'm happy to discuss it together to come up with a better, more consistent solution.

It seems like it would be simpler to just change the fire_init to initialize to zeros instead of nan since that's what's being set in the _vv_fire_step. Do the different step functions initialize things differently? Am I missing something here?

This was the case in the previous version. The initialization as None instead of 0 was one of the changes we made to achieve agreement with the ASE implementation, see #203.
It is necessary to skip the else block during the first iteration (see the ASE implementation for reference: https://gitlab.com/ase/ase/-/blob/master/ase/optimize/fire.py?ref_type=heads#L170-205):
https://github.com/Radical-AI/torch-sim/blob/317985c731170aad578673ebe69a9334f5abe5be/torch_sim/optimizers.py#L1476

In any case, while writing this explanation I realize that my change is not fully correct either. If we merge states (some at more advanced iterations, some at the first iteration), we will skip the else block for all systems and therefore won't perform the correct update for the systems at more advanced iterations. That being said, the logic has to be split into two branches, one for the systems at iteration 0 and one for the others.

Just a spontaneous idea without thinking too much about it, so there might be other disadvantages: what about having an optimization_step attribute on the different optimization states (i.e., a tensor containing a counter for each system)? In that case, one could just check whether this counter is 0 for a given system. This could also help to address the following comment:
https://github.com/Radical-AI/torch-sim/blob/317985c731170aad578673ebe69a9334f5abe5be/torch_sim/runners.py#L457-L460
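
A possible shape for that idea, purely as a sketch (optimization_step and the batched layout below are hypothetical, not existing torch-sim attributes; real states store atoms flat with a system index):

```python
import torch

n_systems, n_atoms = 4, 8                                     # example sizes only
optimization_step = torch.zeros(n_systems, dtype=torch.long)  # hypothetical per-system counter

velocities = torch.zeros(n_systems, n_atoms, 3)
forces = torch.randn(n_systems, n_atoms, 3)

# Branch per system instead of globally: systems at iteration 0 get the
# ASE-style first-step treatment, the others the regular mixing/update rule.
first_step = (optimization_step == 0).view(-1, 1, 1)
velocities = torch.where(
    first_step,
    torch.zeros_like(velocities),   # placeholder for the first-step branch
    velocities + 0.01 * forces,     # placeholder for the regular update
)
optimization_step += 1
```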

@CompRhys (Collaborator) commented Jul 24, 2025

Is the answer here not to use a zero-dim tensor? @t-reents

>>> zero_dim_tensor = torch.empty((10, 3, 0), dtype=torch.float32)
>>> zero_dim_tensor
tensor([], size=(10, 3, 0))

That makes the most sense to me, rather than worrying about a default value. The question would then be whether the zero_dim_tensor causes any logic issues elsewhere?

@t-reents (Contributor, Author) commented Jul 24, 2025

@CompRhys Thanks! I was actually thinking about torch.empty as well (maybe I even tried it); I don't remember why I decided not to use it.

EDIT:

The question would then be does the zero_dim_tensor cause any logic issues elsewhere?

I think that it would fail again when trying to concatenate with other states:
https://github.com/Radical-AI/torch-sim/blob/317985c731170aad578673ebe69a9334f5abe5be/torch_sim/state.py#L902-L909

Moreover, I still think the issue remains that the current logic wouldn't work for those "mixed" states, would it? I think this is independent of how we initialize.
So _ase_fire_step probably needs a small refactoring of its logic in any case, if I'm not mistaken.
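
A toy example of the concern, with made-up shapes rather than the actual concatenate_states code:

```python
import torch

processed = torch.zeros(8, 3)        # velocities of a state already optimized for a few steps
placeholder = torch.empty(10, 3, 0)  # zero-size placeholder as suggested above

torch.cat([processed, placeholder], dim=0)
# raises RuntimeError: the tensors don't even have the same number of dimensions
```

Even a (0, 3)-shaped placeholder would concatenate without error, but the merged velocities would then have fewer rows than the merged positions, so the per-atom alignment would be lost.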

@orionarcher (Collaborator)
what about having a optimization_step attribute on the different optimization states

I think this makes sense if the optimization has different behavior at different optimization steps, as it sounds like the ase_fire optimization does. It's a good solution for handling the branched logic; otherwise, having None or an empty tensor act as a special marker for the first step is less explicit.

I wouldn't advocate including an optimization step for other states unless required though.

@CompRhys CompRhys dismissed orionarcher’s stale review August 8, 2025 13:57

Decided to stay with nan

@CompRhys CompRhys merged commit e90b272 into TorchSim:main Aug 8, 2025
176 of 182 checks passed