1
1
import pytest
2
2
import torch
3
3
from autoemulate .experimental .simulations .base import Simulator
4
- from autoemulate .experimental .types import TensorLike
4
+ from autoemulate .experimental .types import TensorLike , TorchDefaultDType
5
5
from torch import Tensor
6
6
7
7
8
8
class MockSimulator (Simulator ):
9
9
"""Mock implementation of Simulator for testing purposes"""
10
10
11
11
def __init__ (
12
- self ,
13
- parameters_range : dict [str , tuple [float , float ]],
14
- output_names : list [str ],
12
+ self , parameters_range : dict [str , tuple [float , float ]], output_names : list [str ]
15
13
):
16
14
# Call parent constructor
17
15
super ().__init__ (parameters_range , output_names )
@@ -90,7 +88,7 @@ def test_sample_inputs(mock_simulator):
90
88
def test_forward (mock_simulator ):
91
89
"""Test that forward method works correctly"""
92
90
# Create test input
93
- test_input = torch .tensor ([[0.5 , 0.0 , 7.5 ]], dtype = torch . float32 )
91
+ test_input = torch .tensor ([[0.5 , 0.0 , 7.5 ]], dtype = TorchDefaultDType )
94
92
95
93
# Get output
96
94
output = mock_simulator .forward (test_input )
@@ -109,7 +107,7 @@ def test_forward_batch(mock_simulator):
109
107
# Create test batch
110
108
n_samples = 3
111
109
batch = torch .tensor (
112
- [[0.5 , 0.0 , 7.5 ], [0.2 , 0.3 , 6.0 ], [0.8 , - 0.5 , 9.0 ]], dtype = torch . float32
110
+ [[0.5 , 0.0 , 7.5 ], [0.2 , 0.3 , 6.0 ], [0.8 , - 0.5 , 9.0 ]], dtype = TorchDefaultDType
113
111
)
114
112
115
113
# Process batch
@@ -188,7 +186,7 @@ def _forward(self, x: TensorLike) -> TensorLike:
188
186
[0.1 , 0.5 , 0.5 ], # Below threshold
189
187
[0.7 , 0.5 , 0.5 ], # Above threshold
190
188
],
191
- dtype = torch . float32 ,
189
+ dtype = TorchDefaultDType ,
192
190
)
193
191
194
192
# This should process all samples without errors
0 commit comments