Skip to content

Commit aad9be5

Browse files
committed
update model interface to have __call__ method
1 parent ef19777 commit aad9be5

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torch_sim/models/interface.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens
169169
```
170170
"""
171171

172+
@abstractmethod
173+
def __call__(*args, **kwargs) -> dict[str, torch.Tensor]:
174+
"""Where the input is fed into the model. This is to help typecheckers."""
175+
172176

173177
def validate_model_outputs( # noqa: C901, PLR0915
174178
model: ModelInterface, device: torch.device, dtype: torch.dtype

0 commit comments

Comments
 (0)