diff --git a/src/external.jl b/src/external.jl index f6592dc..4c87777 100644 --- a/src/external.jl +++ b/src/external.jl @@ -196,5 +196,5 @@ function ChainRulesCore.rrule(::typeof(_provide_rule), func, x, p, mode, jacobia return y, pullback end -ReverseDiff.@grad_from_chainrules _provide_rule(func, x::TrackedArray, p, mode, jacobian, jvp, vjp) -ReverseDiff.@grad_from_chainrules _provide_rule(func, x::AbstractArray{<:TrackedReal}, p, mode, jacobian, jvp, vjp) +ReverseDiff.@grad_from_chainrules _provide_rule(func, x::ReverseDiff.TrackedArray, p, mode, jacobian, jvp, vjp) +ReverseDiff.@grad_from_chainrules _provide_rule(func, x::AbstractArray{<:ReverseDiff.TrackedReal}, p, mode, jacobian, jvp, vjp) diff --git a/src/linear.jl b/src/linear.jl index 4ba4f16..080f67b 100644 --- a/src/linear.jl +++ b/src/linear.jl @@ -199,9 +199,9 @@ function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, Af) end # register above rule for ReverseDiff -ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{TrackedArray, AbstractArray{<:TrackedReal}}, b, lsolve, Af) -ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{TrackedArray, AbstractArray{<:TrackedReal}}, lsolve, Af) -ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{TrackedArray, AbstractArray{<:TrackedReal}}, b::Union{TrackedArray, AbstractVector{<:TrackedReal}}, lsolve, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, Af) # function implicit_linear_inplace(A, b, y, Af)