Skip to content

Commit

Permalink
Missed a few other TrackedArray, TrackedReal references
Browse files Browse the repository at this point in the history
  • Loading branch information
dingraha committed Jul 20, 2023
1 parent 5b72d81 commit a9f1dcc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/external.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,5 @@ function ChainRulesCore.rrule(::typeof(_provide_rule), func, x, p, mode, jacobia
return y, pullback
end

ReverseDiff.@grad_from_chainrules _provide_rule(func, x::TrackedArray, p, mode, jacobian, jvp, vjp)
ReverseDiff.@grad_from_chainrules _provide_rule(func, x::AbstractArray{<:TrackedReal}, p, mode, jacobian, jvp, vjp)
ReverseDiff.@grad_from_chainrules _provide_rule(func, x::ReverseDiff.TrackedArray, p, mode, jacobian, jvp, vjp)
ReverseDiff.@grad_from_chainrules _provide_rule(func, x::AbstractArray{<:ReverseDiff.TrackedReal}, p, mode, jacobian, jvp, vjp)
6 changes: 3 additions & 3 deletions src/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, Af)
end

# register above rule for ReverseDiff
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{TrackedArray, AbstractArray{<:TrackedReal}}, b, lsolve, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{TrackedArray, AbstractArray{<:TrackedReal}}, lsolve, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{TrackedArray, AbstractArray{<:TrackedReal}}, b::Union{TrackedArray, AbstractVector{<:TrackedReal}}, lsolve, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, Af)


# function implicit_linear_inplace(A, b, y, Af)
Expand Down

0 comments on commit a9f1dcc

Please sign in to comment.