Fix concatenation of states in InFlightAutoBatcher
#219
Conversation
The `velocities` and `cell_velocities` are initialized to `None` in the `(FrechetCell)FIREState`. However, when using the `InFlightAutoBatcher` during an optimization, the current and new states are concatenated in `torch_sim.state.concatenate_states`:

https://github.com/Radical-AI/torch-sim/blob/317985c731170aad578673ebe69a9334f5abe5be/torch_sim/state.py#L902-L909

When trying to merge states that were already processed for a few iterations (i.e., `velocities` are not `None` anymore) and newly initialized ones, an error is raised because the code tries to merge a `Tensor` with a `None`. Here, I initialize the `(cell_)velocities` as tensors full of `nan` instead, so that one can merge already processed and newly initialized states. During the first initialization, the `fire` methods look for `nan` rows and replace them with zeros.

The error would also have been caught by the existing `test_autobatching` test if the number of iterations between swaps had been set to a smaller value (I changed it here). The reason is simply that within the current 10 iterations all states converge, so a completely new set of states is selected in the next iteration and already processed and newly initialized states never get merged. You can either change the test as done here or use the collapsed code snippet below ("Code to reproduce") to reproduce the error with the current version.

Code to reproduce
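As a standalone sketch of the failure mode and the NaN workaround (plain PyTorch only; `concat_velocities` is a hypothetical stand-in for the per-attribute merge in `concatenate_states`, not torch-sim's actual API, and this is not the collapsed reproduction snippet above):

```python
import torch

def concat_velocities(a, b):
    # Hypothetical stand-in for the per-attribute merge in concatenate_states.
    return torch.cat([a, b], dim=0)

running = torch.randn(5, 3)  # state already advanced for a few iterations
fresh = None                 # newly initialized state: velocities=None

# concat_velocities(running, fresh)  # TypeError: expected Tensor, got NoneType

# With the fix, fresh states carry a NaN-filled tensor instead of None ...
fresh = torch.full((4, 3), torch.nan)
merged = concat_velocities(running, fresh)

# ... and on the first FIRE step, NaN rows are detected and zeroed.
nan_rows = merged.isnan().any(dim=1)
merged[nan_rows] = 0.0
```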
Walkthrough (CodeRabbit): The updates modify the initialization and handling of velocity tensors in optimizer state classes by using NaN-filled tensors instead of None, and update corresponding checks to detect NaNs. Additionally, a test was adjusted to reduce the number of simulation iterations. No public APIs or function signatures were changed.
I just encountered the same issue. This fix is what I was trying to do as well. Glad you jumped on it. I added an improvement that is related to this (#222).
Thank you for the contribution! I'm personally not too familiar with this part of the code, but I'll bring this up with the rest of the team.
I read your code and I think it works, but I'm slightly hesitant to merge it in because it overrides the meaning of NaN (and makes it fail silently if a user were to inadvertently pass in NaN tensors). I feel like the real problem is that we didn't properly architect it to handle these cases, which I'm trying to fix in my free time.
Thanks for the contribution @t-reents, good catch and nice change on the test. It seems like it would be simpler to just change the `fire_init` to initialize to zeros instead of NaN, since that's what's being set in `_vv_fire_step`. Do the different `step` functions initialize things differently? Am I missing something here?
```diff
-        velocities=None,
+        velocities=torch.full(
+            state.positions.shape, torch.nan, device=device, dtype=dtype
+        ),
```
Why not just initialize to zeros?
Thanks for your comments!

I totally agree with this and I also wasn't 100% happy with it when implementing it (however, it seemed to be the best option within the current setup). I'm happy to discuss it together to come up with a better/more consistent solution.

This was the case in the previous version; the initialization as `None` served to detect the first iteration. In any case, while writing this explanation I realize that my change is not fully correct as well. If we merge states (some at more advanced iterations, some at the first iteration), we will skip the else block for all systems and therefore don't perform the correct update for systems at more advanced iterations. This being said, the logic has to be split into two branches: one for the systems at iteration 0 and one for the others. Just a spontaneous idea without thinking too much about it, so there might be other disadvantages, but what about having a per-system iteration counter in the state that the update logic can branch on?
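A minimal sketch of that idea, with hypothetical names (`n_iterations`, `fire_velocity_update`, and the per-system row shapes are all illustrative, not torch-sim's actual API):

```python
import torch

def fire_velocity_update(
    velocities: torch.Tensor,    # (n_systems, 3), NaN rows for fresh systems
    forces: torch.Tensor,        # (n_systems, 3)
    dt: float,
    n_iterations: torch.Tensor,  # (n_systems,) hypothetical per-system counter
) -> torch.Tensor:
    """Branch per system instead of globally: freshly merged systems
    (iteration 0) start from zero velocities, while already-running
    systems get the regular velocity update."""
    fresh = n_iterations == 0
    out = velocities.clone()
    out[fresh] = 0.0                    # initialize newly merged systems
    out[~fresh] += dt * forces[~fresh]  # regular update for running systems
    return out

# Usage: system 0 is already running, systems 1 and 2 were just merged in.
v = torch.full((3, 3), torch.nan)
v[0] = torch.randn(3)
print(fire_velocity_update(v, torch.randn(3, 3), 0.01, torch.tensor([5, 0, 0])))
```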
Is the answer here not to use a zero-dim tensor? @t-reents

```python
>>> zero_dim_tensor = torch.empty((10, 3, 0), dtype=torch.float32)
>>> zero_dim_tensor
tensor([], size=(10, 3, 0))
```

That makes the most sense to me rather than worrying about a default value. The question would then be: does the …
@CompRhys Thanks! I was actually thinking about that (and maybe even tried to use it).

EDIT: I think that it would fail again when trying to concatenate with other states. Moreover, I still think that the issue remains that the current logic wouldn't work for those "mixed" states, would it? I think this is independent of how we initialize.
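For illustration, a standalone snippet (plain PyTorch, not torch-sim code) of why a zero-dim placeholder would still break the concatenation with an already-running state's velocities:

```python
import torch

fresh = torch.empty((10, 3, 0))  # zero-dim placeholder for a new state
running = torch.randn((5, 3))    # velocities of an already-running state

# Raises a RuntimeError: the tensors have different numbers of dimensions
# (and even with matching ndim, the trailing 0-sized dim would mismatch).
torch.cat([fresh, running], dim=0)
```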
I think this makes sense if the optimization has different behavior at different optimization steps, as it sounds like the FIRE one does. I wouldn't advocate including an optimization-step counter for other states unless required, though.