```diff
@@ -4,7 +4,7 @@
 from .. import backend as bkd
 from .. import config
 from ..backend import backend_name
-from ..utils import get_num_args, run_if_all_none
+from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0


 class PDE(Data):
@@ -186,22 +186,7 @@ def train_next_batch(self, batch_size=None):
             config.comm.Bcast(self.train_x_bc, root=0)
             self.train_x = self.train_x_bc
             if config.parallel_scaling == "strong":
-                # Split the training points over each rank.
-                # We drop last points in order to have the same number of points per rank
-                if len(self.train_x_all) < config.world_size:
-                    raise ValueError(
-                        "The number of training points is smaller than the number of processes. Please use more points."
-                    )
-                train_x_all_shape = list(
-                    self.train_x_all.shape
-                )  # We transform to list to support item assignment
-                num_split = train_x_all_shape[0] // config.world_size
-                train_x_all_shape[0] = num_split
-                train_x_all_split = np.empty(
-                    train_x_all_shape, dtype=self.train_x_all.dtype
-                )
-                config.comm.Scatter(self.train_x_all, train_x_all_split, root=0)
-                self.train_x_all = train_x_all_split
+                self.train_x_all = mpi_scatter_from_rank0(self.train_x_all)
         if self.pde is not None:
             self.train_x = np.vstack((self.train_x, self.train_x_all))
         self.train_y = self.soln(self.train_x) if self.soln else None
```
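The commit replaces the inline `comm.Scatter` logic with a single call to the new `mpi_scatter_from_rank0` helper imported from `..utils`. The helper's implementation is not part of this diff; the sketch below is only an illustration of what it could look like, reconstructed from the removed inline code and assuming an mpi4py-style communicator exposed as `config.comm` together with `config.world_size`, as used above. The single-process early return and the explicit trim of the send buffer are additions for robustness, not taken from the diff.

```python
import numpy as np

from .. import config  # assumed to expose comm (mpi4py communicator) and world_size


def mpi_scatter_from_rank0(array):
    """Scatter the rows of `array` from rank 0 to all ranks (illustrative sketch).

    Mirrors the inline code removed in this commit; the actual helper in
    deepxde.utils may differ.
    """
    if config.world_size == 1:
        return array
    if len(array) < config.world_size:
        raise ValueError(
            "The number of training points is smaller than the number of "
            "processes. Please use more points."
        )
    # Every rank receives the same number of rows; trailing rows that do not
    # divide evenly across ranks are dropped, as in the removed code's comment.
    num_split = array.shape[0] // config.world_size
    recvbuf = np.empty((num_split,) + array.shape[1:], dtype=array.dtype)
    # Trim the send buffer so its row count is exactly world_size * num_split;
    # Scatter splits the root's send buffer into equal contiguous chunks.
    sendbuf = np.ascontiguousarray(array[: num_split * config.world_size])
    config.comm.Scatter(sendbuf, recvbuf, root=0)
    return recvbuf
```

Factoring the scatter out of `train_next_batch` keeps that method focused on assembling the training set, and it lets other data classes reuse the same strong-scaling split instead of duplicating the buffer bookkeeping.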