Skip to content

Commit d0688a6

Browse files
committed
Merge pull request #4054 from amitmurthy/amitm/pmapfix
Changed iterator task in pmap to a function. Fixes #4035 and #4034
2 parents 286d166 + 9420bbd commit d0688a6

File tree

2 files changed

+72
-36
lines changed

2 files changed

+72
-36
lines changed

base/multi.jl

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,70 +1313,66 @@ pmap(f) = f()
13131313
function pmap(f, lsts...; err_retry=true, err_stop=false)
13141314
len = length(lsts)
13151315
np = nprocs()
1316-
retrycond = Condition()
13171316

13181317
results = Dict{Int,Any}()
1319-
function setresult(idx,v)
1320-
results[idx] = v
1321-
notify(retrycond)
1322-
end
13231318

13241319
retryqueue = {}
1325-
function retry(idx,v,ex)
1326-
push!(retryqueue, (idx,v,ex))
1327-
notify(retrycond)
1328-
end
1329-
13301320
task_in_err = false
13311321
is_task_in_error() = task_in_err
13321322
set_task_in_error() = (task_in_err = true)
13331323

13341324
nextidx = 0
13351325
getnextidx() = (nextidx += 1)
1336-
getcurridx() = nextidx
13371326

13381327
states = [start(lsts[idx]) for idx in 1:len]
1339-
function producer()
1340-
while true
1341-
if (is_task_in_error() && err_stop)
1342-
break
1343-
elseif !isempty(retryqueue)
1344-
produce(shift!(retryqueue)[1:2])
1345-
elseif all([!done(lsts[idx],states[idx]) for idx in 1:len])
1346-
nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
1347-
map(idx->states[idx]=nxts[idx][2], 1:len)
1348-
nxtvals = [x[1] for x in nxts]
1349-
produce((getnextidx(), nxtvals))
1350-
elseif (length(results) == getcurridx())
1351-
break
1352-
else
1353-
wait(retrycond)
1354-
end
1328+
function getnext_tasklet()
1329+
if is_task_in_error() && err_stop
1330+
return nothing
1331+
elseif all([!done(lsts[idx],states[idx]) for idx in 1:len])
1332+
nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
1333+
map(idx->states[idx]=nxts[idx][2], 1:len)
1334+
nxtvals = [x[1] for x in nxts]
1335+
return (getnextidx(), nxtvals)
1336+
1337+
elseif !isempty(retryqueue)
1338+
return shift!(retryqueue)
1339+
else
1340+
return nothing
13551341
end
13561342
end
13571343

1358-
pt = Task(producer)
13591344
@sync begin
13601345
for p=1:np
13611346
wpid = PGRP.workers[p].id
13621347
if wpid != myid() || np == 1
13631348
@async begin
1364-
for (idx,nxtvals) in pt
1349+
tasklet = getnext_tasklet()
1350+
while (tasklet != nothing)
1351+
(idx, fvals) = tasklet
13651352
try
1366-
result = remotecall_fetch(wpid, f, nxtvals...)
1367-
isa(result, Exception) ? ((wpid == myid()) ? rethrow(result) : throw(result)) : setresult(idx, result)
1353+
result = remotecall_fetch(wpid, f, fvals...)
1354+
if isa(result, Exception)
1355+
((wpid == myid()) ? rethrow(result) : throw(result))
1356+
else
1357+
results[idx] = result
1358+
end
13681359
catch ex
1369-
err_retry ? retry(idx,nxtvals,ex) : setresult(idx, ex)
1360+
if err_retry
1361+
push!(retryqueue, (idx,fvals, ex))
1362+
else
1363+
results[idx] = ex
1364+
end
13701365
set_task_in_error()
13711366
break # remove this worker from accepting any more tasks
13721367
end
1368+
1369+
tasklet = getnext_tasklet()
13731370
end
13741371
end
13751372
end
13761373
end
13771374
end
13781375

1379-
!istaskdone(pt) && throwto(pt, InterruptException())
13801376
for failure in retryqueue
13811377
results[failure[1]] = failure[3]
13821378
end

test/parallel.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,46 @@ et=toq()
5151
@test isready(rr1)
5252
@test !isready(rr3)
5353

54-
# make sure exceptions propagate when waiting on Tasks
55-
# TODO: should be enabled but the error is printed by the event loop
56-
#@test_throws (@sync (@async error("oops")))
54+
55+
# TODO: The block below should always be enabled, but the error is printed by the event loop
56+
57+
# Hence in the event of any relevant changes to the parallel codebase,
58+
# please define an ENV variable PTEST_FULL and ensure that the below block is
59+
# executed successfully before committing/merging
60+
61+
if haskey(ENV, "PTEST_FULL")
62+
println("START of parallel tests that print errors")
63+
64+
# make sure exceptions propagate when waiting on Tasks
65+
@test_throws (@sync (@async error("oops")))
66+
67+
# pmap tests
68+
# needs at least 4 processors (which are being created above for the @parallel tests)
69+
s = "a"*"bcdefghijklmnopqrstuvwxyz"^100;
70+
ups = "A"*"BCDEFGHIJKLMNOPQRSTUVWXYZ"^100;
71+
@test ups == bytestring(Uint8[uint8(c) for c in pmap(x->uppercase(x), s)])
72+
@test ups == bytestring(Uint8[uint8(c) for c in pmap(x->uppercase(char(x)), s.data)])
73+
74+
# retry, on error exit
75+
res = pmap(x->(x=='a') ? error("test error. don't panic.") : uppercase(x), s; err_retry=true, err_stop=true);
76+
@test length(res) < length(ups)
77+
@test isa(res[1], Exception)
78+
79+
# no retry, on error exit
80+
res = pmap(x->(x=='a') ? error("test error. don't panic.") : uppercase(x), s; err_retry=false, err_stop=true);
81+
@test length(res) < length(ups)
82+
@test isa(res[1], Exception)
83+
84+
# retry, on error continue
85+
res = pmap(x->iseven(myid()) ? error("test error. don't panic.") : uppercase(x), s; err_retry=true, err_stop=false);
86+
@test length(res) == length(ups)
87+
@test ups == bytestring(Uint8[uint8(c) for c in res])
88+
89+
# no retry, on error continue
90+
res = pmap(x->(x=='a') ? error("test error. don't panic.") : uppercase(x), s; err_retry=false, err_stop=false);
91+
@test length(res) == length(ups)
92+
@test isa(res[1], Exception)
93+
94+
println("END of parallel tests that print errors")
95+
end
96+

0 commit comments

Comments
 (0)