Skip to content

Commit 47127ac

Browse files
authored
Add parallel scaling parameter dde.config.parallel_scaling (#1279)
1 parent 50cbf84 commit 47127ac

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

deepxde/config.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from .backend import backend_name, tf, torch, paddle
88
from .real import Real
99

10+
# Data parallel
11+
parallel_scaling = None
1012
# Data parallel via Horovod
1113
hvd = None
1214
comm = None
@@ -21,9 +23,10 @@
2123
if world_size > 1:
2224
from mpi4py import MPI
2325

26+
parallel_scaling = "weak"
2427
comm = MPI.COMM_WORLD
2528
tf.compat.v1.disable_eager_execution() # Without this line, Horovod broadcasting fails.
26-
rank = hvd.rank() # Only single node acceleration supported so far.
29+
rank = hvd.rank() # Only single node acceleration supported so far.
2730
if rank == 0:
2831
print(f"\nParallel training with {world_size} processes.\n")
2932
else:
@@ -198,3 +201,15 @@ def disable_xla_jit():
198201
This is equivalent with ``enable_xla_jit(False)``.
199202
"""
200203
enable_xla_jit(False)
204+
205+
206+
def set_parallel_scaling(scaling_mode):
207+
"""Sets the scaling mode for data parallel acceleration.
208+
Weak scaling involves increasing the problem size proportionally with the number of processors,
209+
while strong scaling involves keeping the problem size fixed and increasing the number of processors.
210+
211+
Args:
212+
scaling_mode (str): Whether 'weak' or 'strong'
213+
"""
214+
global parallel_scaling
215+
parallel_scaling = scaling_mode

deepxde/data/pde.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ def __init__(
100100
raise ValueError(
101101
"Parallel training via Horovod only supports pseudo train distribution."
102102
)
103+
if config.parallel_scaling == "strong":
104+
raise ValueError(
105+
"Strong scaling is not supported with tensorflow.compat.v1. Please use weak scaling."
106+
)
103107
self.anchors = None if anchors is None else anchors.astype(config.real(np))
104108
self.exclusions = exclusions
105109

0 commit comments

Comments
 (0)