Skip to content

Commit 0245126

Browse files
committed
Add torch default dtype constant
1 parent 8a9ee75 commit 0245126

File tree

9 files changed

+38
-32
lines changed

9 files changed

+38
-32
lines changed

autoemulate/experimental/data/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
InputLike,
77
OutputLike,
88
TensorLike,
9+
TorchDefaultDType,
910
TorchScalarDType,
1011
)
1112
from sklearn.utils.validation import check_X_y
@@ -28,9 +29,9 @@ def _convert_to_dataset(
2829
"""
2930
# Convert input to Dataset if not already
3031
if isinstance(x, np.ndarray):
31-
x = torch.tensor(x, dtype=torch.float32)
32+
x = torch.tensor(x, dtype=TorchDefaultDType)
3233
if isinstance(y, np.ndarray):
33-
y = torch.tensor(y, dtype=torch.float32)
34+
y = torch.tensor(y, dtype=TorchDefaultDType)
3435

3536
if isinstance(x, torch.Tensor | np.ndarray) and isinstance(
3637
y, torch.Tensor | np.ndarray
@@ -77,7 +78,7 @@ def _convert_to_tensors(
7778
cls,
7879
x: InputLike,
7980
y: InputLike | None = None,
80-
dtype: torch.dtype = torch.float32,
81+
dtype: torch.dtype = TorchDefaultDType,
8182
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
8283
"""
8384
Convert InputLike x, y to Tensor or tuple of Tensors.

autoemulate/experimental/emulators/polynomials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5656
# Transform input using the fitted PolynomialFeatures
5757
x_np, _ = self._convert_to_numpy(x)
5858
x_poly = self.poly.transform(x_np)
59-
x_poly_tensor = torch.tensor(x_poly, dtype=torch.float32, device=self.device)
59+
x_poly_tensor = torch.tensor(x_poly, dtype=x.dtype, device=self.device)
6060
return self.linear(x_poly_tensor)
6161

6262
@staticmethod

autoemulate/experimental/learners/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from autoemulate.experimental.emulators.base import Emulator
1212
from autoemulate.experimental.simulations.base import Simulator
1313

14-
from ..types import GaussianLike, TensorLike
14+
from ..types import GaussianLike, TensorLike, TorchDefaultDType
1515

1616

1717
@dataclass(kw_only=True)
@@ -40,6 +40,7 @@ class Learner(ValidationMixin, ABC):
4040
y_train: TensorLike
4141
in_dim: int = field(init=False)
4242
out_dim: int = field(init=False)
43+
dtype: torch.dtype = field(default=TorchDefaultDType)
4344

4445
def __post_init__(self):
4546
"""
@@ -199,8 +200,8 @@ def summary(self):
199200
cumulative number of queries.
200201
"""
201202
# pull histories into float tensors
202-
mse = torch.tensor(self.metrics["mse"], dtype=torch.float32)
203-
q = torch.tensor(self.metrics["n_queries"], dtype=torch.float32)
203+
mse = torch.tensor(self.metrics["mse"], dtype=self.dtype)
204+
q = torch.tensor(self.metrics["n_queries"], dtype=self.dtype)
204205

205206
# build per-query ratios safely (avoid zero division)
206207
d = {}

autoemulate/experimental/simulations/base.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from autoemulate.experimental.data.utils import ValidationMixin, set_random_seed
88
from autoemulate.experimental.logging_config import configure_logging
9-
from autoemulate.experimental.types import TensorLike
9+
from autoemulate.experimental.types import TensorLike, TorchDefaultDType
1010
from autoemulate.experimental_design import LatinHypercube
1111

1212
logger = logging.getLogger("autoemulate")
@@ -101,7 +101,10 @@ def out_dim(self) -> int:
101101
return self._out_dim
102102

103103
def sample_inputs(
104-
self, n_samples: int, random_seed: int | None = None
104+
self,
105+
n_samples: int,
106+
random_seed: int | None = None,
107+
dtype: torch.dtype = TorchDefaultDType,
105108
) -> TensorLike:
106109
"""
107110
Generate random samples using Latin Hypercube Sampling.
@@ -123,10 +126,8 @@ def sample_inputs(
123126
set_random_seed(random_seed) # type: ignore PGH003
124127
lhd = LatinHypercube(self.param_bounds)
125128
sample_array = lhd.sample(n_samples)
126-
# TODO: have option to set dtype and ensure consistency throughout codebase?
127-
# added here as method was returning float64 and elsewhere had tensors of
128-
# float32 and this caused issues
129-
return torch.tensor(sample_array, dtype=torch.float32)
129+
130+
return torch.tensor(sample_array, dtype=dtype)
130131

131132
@abstractmethod
132133
def _forward(self, x: TensorLike) -> TensorLike:

autoemulate/experimental/simulations/epidemic.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from autoemulate.experimental.simulations.base import Simulator
4-
from autoemulate.experimental.types import TensorLike
4+
from autoemulate.experimental.types import TensorLike, TorchDefaultDType
55
from autoemulate.simulations.epidemic import simulate_epidemic
66

77

@@ -10,11 +10,7 @@ class Epidemic(Simulator):
1010
Simulator of infectious disease spread (SIR).
1111
"""
1212

13-
def __init__(
14-
self,
15-
param_ranges=None,
16-
output_names=None,
17-
):
13+
def __init__(self, param_ranges=None, output_names=None):
1814
if param_ranges is None:
1915
param_ranges = {"beta": (0.1, 0.5), "gamma": (0.01, 0.2)}
2016
if output_names is None:
@@ -36,5 +32,4 @@ def _forward(self, x: TensorLike) -> TensorLike:
3632
Peak infection rate.
3733
"""
3834
y = simulate_epidemic(x.cpu().numpy()[0])
39-
# TODO (#537): update with default dtype
40-
return torch.tensor([y], dtype=torch.float32).view(-1, 1)
35+
return torch.tensor([y], dtype=TorchDefaultDType).view(-1, 1)

autoemulate/experimental/types.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,12 @@
2020
DeviceLike: TypeAlias = str | torch.device
2121

2222
# Torch dtype's
23-
TorchScalarDType = (torch.float32, torch.float64, torch.int32, torch.int64)
23+
TorchScalarDType: tuple[torch.dtype, ...] = (
24+
torch.float32,
25+
torch.float64,
26+
torch.int32,
27+
torch.int64,
28+
)
29+
30+
# Default torch dtype (float32)
31+
TorchDefaultDType: torch.dtype = torch.float32

tests/experimental/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33
import torch
4+
from autoemulate.experimental.types import TorchDefaultDType
45
from sklearn.datasets import make_regression
56

67
N_S = 20
@@ -115,7 +116,7 @@ def noisy_data():
115116
Generate a highly noisy dataset to test stochasticity effects.
116117
"""
117118
rng = np.random.RandomState(0)
118-
x = torch.tensor(rng.normal(size=(100, 2)), dtype=torch.float32)
119-
y = torch.tensor(rng.normal(size=(100,)), dtype=torch.float32)
119+
x = torch.tensor(rng.normal(size=(100, 2)), dtype=TorchDefaultDType)
120+
y = torch.tensor(rng.normal(size=(100,)), dtype=TorchDefaultDType)
120121
x2 = x[:4].clone()
121122
return x, y, x2

tests/experimental/test_experimental_base_simulator.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
import pytest
22
import torch
33
from autoemulate.experimental.simulations.base import Simulator
4-
from autoemulate.experimental.types import TensorLike
4+
from autoemulate.experimental.types import TensorLike, TorchDefaultDType
55
from torch import Tensor
66

77

88
class MockSimulator(Simulator):
99
"""Mock implementation of Simulator for testing purposes"""
1010

1111
def __init__(
12-
self,
13-
parameters_range: dict[str, tuple[float, float]],
14-
output_names: list[str],
12+
self, parameters_range: dict[str, tuple[float, float]], output_names: list[str]
1513
):
1614
# Call parent constructor
1715
super().__init__(parameters_range, output_names)
@@ -90,7 +88,7 @@ def test_sample_inputs(mock_simulator):
9088
def test_forward(mock_simulator):
9189
"""Test that forward method works correctly"""
9290
# Create test input
93-
test_input = torch.tensor([[0.5, 0.0, 7.5]], dtype=torch.float32)
91+
test_input = torch.tensor([[0.5, 0.0, 7.5]], dtype=TorchDefaultDType)
9492

9593
# Get output
9694
output = mock_simulator.forward(test_input)
@@ -109,7 +107,7 @@ def test_forward_batch(mock_simulator):
109107
# Create test batch
110108
n_samples = 3
111109
batch = torch.tensor(
112-
[[0.5, 0.0, 7.5], [0.2, 0.3, 6.0], [0.8, -0.5, 9.0]], dtype=torch.float32
110+
[[0.5, 0.0, 7.5], [0.2, 0.3, 6.0], [0.8, -0.5, 9.0]], dtype=TorchDefaultDType
113111
)
114112

115113
# Process batch
@@ -188,7 +186,7 @@ def _forward(self, x: TensorLike) -> TensorLike:
188186
[0.1, 0.5, 0.5], # Below threshold
189187
[0.7, 0.5, 0.5], # Above threshold
190188
],
191-
dtype=torch.float32,
189+
dtype=TorchDefaultDType,
192190
)
193191

194192
# This should process all samples without errors

tests/experimental/test_experimental_preprocessors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import torch
33
from autoemulate.experimental.data.preprocessors import Standardizer
4+
from autoemulate.experimental.types import TorchDefaultDType
45

56

67
class TestStandardizer:
@@ -37,7 +38,7 @@ def test_small_std_replaced(self):
3738
Test that small std values are replaced with 1.0.
3839
"""
3940
mean = torch.zeros((1, 2))
40-
std = torch.tensor([[1e-40, 2.0]], dtype=torch.float32)
41+
std = torch.tensor([[1e-40, 2.0]], dtype=TorchDefaultDType)
4142
s = Standardizer(mean, std)
4243
assert s.std[0, 0] == 1.0
4344
assert s.std[0, 1] == 2.0

0 commit comments

Comments
 (0)