Skip to content

Commit 9602c6b

Browse files
Merge pull request #1205 from devmotion/dw/notimplemented
Handle `ChainRulesCore.NotImplemented`
2 parents 1eb80c5 + b15eff1 commit 9602c6b

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.37"
3+
version = "0.6.38"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/compiler/chainrules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
106106
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
107107
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
108108
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
109+
@inline wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing
109110
for T_outer in (:Tuple, :NamedTuple)
110111
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
111112
# branch that changes output type, because nested AD on that kinda thing makes Zygote less

test/chainrules.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,18 @@ using Zygote: ZygoteRuleConfig
263263
@test (1.0,) == Zygote.gradient(oout_id_outer, π)
264264
@test oout_id_rrule_hitcount[] == 0
265265
end
266+
267+
# issue #1204
268+
@testset "NotImplemented" begin
269+
f_notimplemented(x) = x
270+
@scalar_rule f_notimplemented(x) @not_implemented("not implemented :(")
271+
@test Zygote.gradient(f_notimplemented, 0.1) === (nothing,)
272+
@test Zygote.gradient(x -> f_notimplemented(x[1]), 0.1) === (nothing,)
273+
if isdefined(Base, :only)
274+
@test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === (nothing,)
275+
@test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
276+
end
277+
end
266278
end
267279

268280
@testset "ChainRulesCore.rrule_via_ad" begin

0 commit comments

Comments
 (0)