
Commit 4363618

Add utils mpi_scatter_from_rank0 (#1298)
1 parent: 4e33568 · commit: 4363618

File tree

2 files changed: +33 −17 lines

deepxde/data/pde.py

Lines changed: 2 additions & 17 deletions
@@ -4,7 +4,7 @@
 from .. import backend as bkd
 from .. import config
 from ..backend import backend_name
-from ..utils import get_num_args, run_if_all_none
+from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0


 class PDE(Data):
@@ -186,22 +186,7 @@ def train_next_batch(self, batch_size=None):
             config.comm.Bcast(self.train_x_bc, root=0)
         self.train_x = self.train_x_bc
         if config.parallel_scaling == "strong":
-            # Split the training points over each rank.
-            # We drop last points in order to have the same number of points per rank
-            if len(self.train_x_all) < config.world_size:
-                raise ValueError(
-                    "The number of training points is smaller than the number of processes. Please use more points."
-                )
-            train_x_all_shape = list(
-                self.train_x_all.shape
-            )  # We transform to list to support item assignment
-            num_split = train_x_all_shape[0] // config.world_size
-            train_x_all_shape[0] = num_split
-            train_x_all_split = np.empty(
-                train_x_all_shape, dtype=self.train_x_all.dtype
-            )
-            config.comm.Scatter(self.train_x_all, train_x_all_split, root=0)
-            self.train_x_all = train_x_all_split
+            self.train_x_all = mpi_scatter_from_rank0(self.train_x_all)
         if self.pde is not None:
             self.train_x = np.vstack((self.train_x, self.train_x_all))
         self.train_y = self.soln(self.train_x) if self.soln else None
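
For reference, the behavior that the deleted block (and the new helper in the next file) implements is a plain rank-0 scatter of equal-sized, contiguous row blocks, dropping the remainder. Below is a minimal standalone sketch of that pattern using mpi4py directly rather than DeepXDE's config module; the file name and array sizes are made up for illustration and are not part of this commit.

# Minimal standalone mpi4py sketch (hypothetical file scatter_sketch.py):
# rank 0 holds the full array and each rank receives one contiguous,
# equal-sized block of rows; any remainder rows are dropped.
# Run with e.g.: mpirun -n 4 python scatter_sketch.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
world_size = comm.Get_size()

full = np.arange(10 * 2, dtype=np.float64).reshape(10, 2)  # 10 points in 2D
num_split = len(full) // world_size  # points each rank keeps
local = np.empty((num_split, 2), dtype=full.dtype)
# Truncate the send buffer to num_split * world_size rows so the send count
# divides evenly across ranks (the same trick used in the commit).
comm.Scatter(full[: num_split * world_size], local, root=0)
print(f"rank {comm.Get_rank()}: {local.shape[0]} local points")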

deepxde/utils/internal.py

Lines changed: 31 additions & 0 deletions
@@ -200,3 +200,34 @@ def get_num_args(func):
     # g = dummy(a.f)
     params = inspect.signature(func).parameters
     return len(params) - ("self" in params)
+
+
+def mpi_scatter_from_rank0(array, drop_last=True):
+    """Scatter the given array into continuous subarrays of equal size from rank 0 to all ranks.
+
+    Args:
+        array: Numpy array to be split.
+        drop_last (bool): Whether to discard the remainder samples
+            not divisible by world_size. Default: True.
+
+    Returns:
+        array: Scattered Numpy array.
+    """
+    # TODO: support drop_last=False
+    if config.world_size == 1:
+        return array
+    if not drop_last:
+        raise ValueError("Only support drop_last=True now.")
+    if len(array) < config.world_size:
+        raise ValueError(
+            "The number of training points is smaller than the number of processes. Please use more points."
+        )
+    array_shape = list(array.shape)  # We transform to list to support item assignment
+    num_split = array_shape[0] // config.world_size
+    array_shape[0] = num_split
+    array_split = np.empty(array_shape, dtype=array.dtype)
+    array = array[
+        : num_split * config.world_size
+    ]  # We truncate array size to be a multiple of num_split to prevent a MPI error.
+    config.comm.Scatter(array, array_split, root=0)
+    return array_split
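
Viewed from the caller's side, the helper only needs the full array; the communicator and world size come from DeepXDE's config module, as the code above shows. A rough usage sketch follows, assuming the script is launched under MPI and that config.world_size falls back to 1 in a serial run; the point counts and the mpirun command are illustrative only, not from this commit.

# Illustrative usage (e.g. mpirun -n 4 python this_script.py).
import numpy as np
from deepxde import config
from deepxde.utils import mpi_scatter_from_rank0

train_x_all = np.random.rand(1000, 2)  # 1000 training points in 2D
# Only rank 0's buffer is actually scattered; each rank receives one
# contiguous block of 1000 // config.world_size points (remainder dropped).
local_x = mpi_scatter_from_rank0(train_x_all)  # drop_last=True by default
assert local_x.shape[0] == train_x_all.shape[0] // config.world_size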
