Skip to content

Commit 41fcdf2

Browse files
authored
Fixes for Julia nightly (#300)
* Adapt to JuliaLang/julia#56509 * Adapt to JuliaLang/julia#54734 * Use StmtRange explicitly * Adapt to JuliaLang/julia#57230 * Reuse Cthulhu code structure for Compiler cache/finish overrides * Adapt to JuliaLang/julia#57475 * Adapt to JuliaLang/julia#55976 * Adapt to JuliaLang/julia#54734 * Use CC instead of .Compiler * Implement ir.argtypes[1] fix from JuliaLang/julia#54458 * Comment out failing tests To highlight which are broken, should probably be fixed before merging * Treat `getproperty(::Module, ::Symbol)` like GlobalRefs * Uncomment passing tests, explicitly mark others as broken * Evaluate GlobalRef only if binding is defined * Use `rrule` for getproperty(::Module, ::Symbol) * Bump compat bound for StructArrays * Raise compat bound for Cthulhu * Revert `isconst` change now that it is fixed * Adapt to `finishinfer!` signature change --------- Co-authored-by: Cédric Belmant <[email protected]>
1 parent 21747f8 commit 41fcdf2

File tree

14 files changed

+108
-52
lines changed

14 files changed

+108
-52
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ ChainRules = "1.44.6"
2222
ChainRulesCore = "1.20"
2323
Combinatorics = "1"
2424
Compiler = "~0"
25-
Cthulhu = "2.10.1"
25+
Cthulhu = "2.16.3"
2626
OffsetArrays = "1"
2727
PrecompileTools = "1"
2828
StaticArrays = "1"
29-
StructArrays = "0.6"
29+
StructArrays = "0.6, 0.7"
3030
julia = "1.10"
3131

3232
[extras]

src/analysis/forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
3434
# discover what they are. frules should be written in such a way that
3535
# whether or not they return `nothing`, only depends on the non-tangent arguments
3636
frule_arginfo = ArgInfo(nothing, frule_argtypes)
37-
frule_si = StmtInfo(true)
37+
frule_si = StmtInfo(true, false)
3838
# turn off frule analysis in the frule to avoid cycling
3939
interp′ = disable_forward(interp)
4040
frule_call = CC.abstract_call_gf_by_type(interp′,

src/codegen/forward_demand.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,11 +352,11 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
352352
end
353353
end
354354

355-
method_info = CC.MethodInfo(src)
355+
info = @static VERSION v"1.12.0-DEV.1293" ? CC.SpecInfo(src) : CC.MethodInfo(src)
356356
argtypes = ir.argtypes[1:mi.def.nargs]
357357
world = get_inference_world(interp)
358-
irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world)
359-
rt = CC._ir_abstract_constant_propagation(interp, irsv)
358+
irsv = IRInterpretationState(interp, info, ir, mi, argtypes, world, src.min_world, src.max_world)
359+
rt = CC.ir_abstract_constant_propagation(interp, irsv)
360360

361361
ir = compact!(ir)
362362

src/codegen/reverse.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci,
1414
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
1515
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
1616
end
17-
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)
1817
else
1918
oc_nargs = Int64(meth_nargs)
20-
Expr(:new_opaque_closure, typ, Union{}, Any,
21-
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
19+
ocm = Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci)
2220
end
21+
oc = Expr(:new_opaque_closure, typ, Union{}, Any, true, ocm, revs...)
22+
@static VERSION < v"1.12.0-DEV.691" ? deleteat!(oc.args, 4) : nothing
23+
oc
2324
end
2425

2526
function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::Int, interp=nothing, curs=nothing)

src/extra_rules.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,12 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
268268
val, Δ->(NoTangent(), NoTangent(), Δ)
269269
end
270270

271+
# XXX: We should instead skip differentiation in the IR.
272+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getproperty), mod::Module, name::Symbol)
273+
val = getproperty(mod, name)
274+
val, Δ->(NoTangent(), NoTangent(), NoTangent())
275+
end
276+
271277
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
272278

273279
# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495

src/stage1/compiler_utils.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Utilities that should probably go into CC
2-
using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter
2+
using .CC: IRCode, CFG, BasicBlock, BBIdxIter
33

44
function Base.push!(cfg::CFG, bb::BasicBlock)
55
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
@@ -30,10 +30,6 @@ if VERSION < v"1.12.0-DEV.1268"
3030

3131
Base.copy(ir::IRCode) = CC.copy(ir)
3232

33-
CC.BasicBlock(x::UnitRange) =
34-
BasicBlock(StmtRange(first(x), last(x)))
35-
CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
36-
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
3733
Base.length(c::CC.NewNodeStream) = CC.length(c)
3834
Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...)
3935
Base.size(x::CC.UnitRange) = CC.size(x)

src/stage1/generated.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ struct ∂⃖recurse{N}; end
66

77
include("recurse.jl")
88

9-
function generate_lambda_ex(world::UInt, source::LineNumberNode,
9+
# source is a Method starting from https://github.com/JuliaLang/julia/pull/57230
10+
function generate_lambda_ex(world::UInt, source::Union{Method,LineNumberNode},
1011
args::Core.SimpleVector, sparams::Core.SimpleVector, body::Expr)
1112
stub = Core.GeneratedFunctionStub(identity, args, sparams)
1213
return stub(world, source, body)
@@ -16,7 +17,7 @@ struct NonTransformableError
1617
args
1718
end
1819

19-
function perform_optic_transform(world::UInt, source::LineNumberNode,
20+
function perform_optic_transform(world::UInt, source::Union{Method,LineNumberNode},
2021
@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
2122
@assert N >= 1
2223

src/stage1/recurse.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ function split_critical_edges!(ir)
183183
bb = ir.stmts[i][:inst].args[1]
184184
ir.stmts[i][:inst] = nothing
185185
bbnew = bb + ninserted
186-
insert!(cfg.blocks, bbnew, BasicBlock(i:i))
186+
insert!(cfg.blocks, bbnew, BasicBlock(StmtRange(i:i)))
187187
bb_rename_offset[bb] += 1
188188
bblock = cfg.blocks[bbnew+1]
189-
cfg.blocks[bbnew+1] = BasicBlock((i+1):last(bblock.stmts),
189+
cfg.blocks[bbnew+1] = BasicBlock(StmtRange((i+1):last(bblock.stmts)),
190190
bblock.preds, bblock.succs)
191191
i += 1
192192
while i <= last(bblock.stmts)

src/stage1/recurse_fwd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
222222
return ci
223223
end
224224

225-
function perform_fwd_transform(world::UInt, source::LineNumberNode,
225+
function perform_fwd_transform(world::UInt, source::Union{Method,LineNumberNode},
226226
@nospecialize(ff::Type{∂☆recurse{N,E}}), @nospecialize(args)) where {N,E}
227227
if all(x->x <: ZeroBundle, args)
228228
return generate_lambda_ex(world, source,

src/stage2/forward.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@ end
2121
# unlikely to be the actual interface. For now, it is used for testing.
2222
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = false)
2323
interp = ADInterpreter(; forward=true, backward=false)
24-
match = Base._which(tt)
25-
frame = CC.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
26-
mi = frame.linfo
24+
mi = @ccall jl_method_lookup_by_tt(tt::Any, Base.tls_world_age()::Csize_t, #= method table =# nothing::Any)::Ref{MethodInstance}
25+
ci = CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI)
2726

2827
src = CC.copy(interp.unopt[0][mi].src)
29-
ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode)
28+
ir = CC.copy((@atomic :monotonic ci.inferred).ir::IRCode)
3029

3130
# Find all Return Nodes
3231
vals = Pair{SSAValue, Int}[]
@@ -83,6 +82,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = fa
8382
end
8483

8584
ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!, eras_mode)
85+
ir.argtypes[1] = Tuple{}
8686

8787
return OpaqueClosure(ir)
8888
end

0 commit comments

Comments
 (0)