|
71 | 71 |
|
72 | 72 | # %% [markdown]
|
73 | 73 | """
|
74 |
| -SimState attributes fall into three categories: atomwise, batchwise, and global. |
| 74 | +SimState attributes fall into three categories: atomwise, systemwise, and global. |
75 | 75 |
|
76 | 76 | * Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`,
|
77 |
| - `masses`, `atomic_numbers`, and `batch`. Names are plural. |
78 |
| -* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for |
| 77 | + `masses`, `atomic_numbers`, and `system_idx`. Names are plural. |
| 78 | +* Systemwise attributes are tensors with shape (n_systems, ...), this is just `cell` for |
79 | 79 | the base SimState. Names are singular.
|
80 | 80 | * Global attributes have any other shape or type, just `pbc` here. Names are singular.
|
81 | 81 |
|
82 |
| -You can use the `infer_property_scope` function to analyze a state's properties. This |
| 82 | +For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's |
| 83 | +name is explicitly defined in the `_atom_attributes`, `_system_attributes`, and `_global_attributes`: |
| 84 | +
|
| 85 | +_atom_attributes = ("positions", "masses", "atomic_numbers", "system_idx") |
| 86 | +_system_attributes = ("cell",) |
| 87 | +_global_attributes = ("pbc",) |
| 88 | +
|
| 89 | +You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This |
83 | 90 | is mostly used internally but can be useful for debugging.
|
84 | 91 | """
|
85 | 92 |
|
86 | 93 | # %%
|
87 |
| -from torch_sim.state import infer_property_scope |
| 94 | +from torch_sim.state import get_attrs_for_scope |
88 | 95 |
|
89 |
| -scope = infer_property_scope(si_state) |
90 |
| -print(scope) |
| 96 | +# loop through each attribute: |
| 97 | +for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"): |
| 98 | + print(f"per-atom attribute: {attr_name}") |
| 99 | + print(f"value: {attr_value}") |
91 | 100 |
|
| 101 | +# or access the attributes via a dict: |
| 102 | +print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501 |
| 103 | +print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) |
92 | 104 |
|
93 | 105 | # %% [markdown]
|
94 | 106 | """
|
|
112 | 124 | f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems"
|
113 | 125 | )
|
114 | 126 |
|
115 |
| -# we can see how the shapes of batchwise, atomwise, and global properties change |
| 127 | +# we can see how the shapes of atomwise, systemwise, and global properties change |
116 | 128 | print(f"Positions shape: {multi_state.positions.shape}")
|
117 | 129 | print(f"Cell shape: {multi_state.cell.shape}")
|
118 | 130 | print(f"PBC: {multi_state.pbc}")
|
|
142 | 154 |
|
143 | 155 | SimState supports many convenience operations for manipulating batched states. Slicing
|
144 | 156 | is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state
|
145 |
| -containing only the first three batches. The other operations are available through the |
| 157 | +containing only the first three systems. The other operations are available through the |
146 | 158 | `pop`, `split`, `clone`, and `to` methods.
|
147 | 159 | """
|
148 | 160 |
|
|
182 | 194 | # %% [markdown]
|
183 | 195 | """
|
184 | 196 |
|
185 |
| -You can extract specific batches from a batched state using Python's slicing syntax. |
| 197 | +You can extract specific systems from a batched state using Python's slicing syntax. |
186 | 198 | This is extremely useful for analyzing specific systems or for implementing complex
|
187 | 199 | workflows where different systems need separate processing:
|
188 | 200 |
|
189 | 201 | The slicing interface follows Python's standard indexing conventions, making it
|
190 | 202 | intuitive to use. Behind the scenes, TorchSim is creating a new SimState with only the
|
191 |
| -selected batches, maintaining all the necessary properties and relationships. |
| 203 | +selected systems, maintaining all the necessary properties and relationships. |
192 | 204 |
|
193 | 205 | Note the difference between these operations:
|
194 |
| -- `split()` returns all batches as separate states but doesn't modify the original |
195 |
| -- `pop()` removes specified batches from the original state and returns them as |
| 206 | +- `split()` returns all systems as separate states but doesn't modify the original |
| 207 | +- `pop()` removes specified systems from the original state and returns them as |
196 | 208 | separate states
|
197 |
| -- `__getitem__` (slicing) creates a new state with specified batches without modifying |
| 209 | +- `__getitem__` (slicing) creates a new state with specified systems without modifying |
198 | 210 | the original
|
199 | 211 |
|
200 | 212 | This flexibility allows you to structure your simulation workflows in the most
|
|
203 | 215 | ### Splitting and Popping Batches
|
204 | 216 |
|
205 | 217 | SimState provides methods to split a batched state into separate states or to remove
|
206 |
| -specific batches: |
| 218 | +specific systems: |
207 | 219 | """
|
208 | 220 |
|
209 | 221 | # %% [markdown]
|
|
257 | 269 | )
|
258 | 270 |
|
259 | 271 | print("MDState properties:")
|
260 |
| -scope = infer_property_scope(md_state) |
261 |
| -print("Global properties:", scope["global"]) |
262 |
| -print("Per-atom properties:", scope["per_atom"]) |
263 |
| -print("Per-system properties:", scope["per_system"]) |
| 272 | +print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom"))) |
| 273 | +print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) |
| 274 | +print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) |
264 | 275 |
|
265 | 276 |
|
266 | 277 | # %% [markdown]
|
|
0 commit comments