add rrule frule; adjust how parameters are passed (#326)
* add rrule frule; adjust how parameters are passed

* update docs

* relax versioning

* version proof
jverzani authored Sep 28, 2022
1 parent 89e2325 commit 1afc30c
Showing 9 changed files with 192 additions and 13 deletions.
7 changes: 5 additions & 2 deletions Project.toml
@@ -1,13 +1,15 @@
name = "Roots"
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
-version = "2.0.4"
+version = "2.0.5"

[deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[compat]
+ChainRulesCore = "1"
CommonSolve = "0.1, 0.2"
Setfield = "0.7, 0.8, 1"
julia = "1.0"
@@ -21,6 +23,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
-test = ["JSON", "SpecialFunctions", "Statistics", "Test", "BenchmarkTools", "ForwardDiff", "Polynomials", "Unitful"]
+test = ["JSON", "SpecialFunctions", "Statistics", "Test", "BenchmarkTools", "ForwardDiff", "Polynomials", "Unitful", "Zygote"]
70 changes: 70 additions & 0 deletions docs/src/roots.md
@@ -490,6 +490,76 @@ plot!(x -> flight(x, tstar), 0, howfar(tstar))
show(current()) # hide
```

## Sensitivity

For functions with parameters, $f(x,p)$, derivatives with respect to the $p$ variable(s) may be of interest, for example to see how sensitive the solution is to variations in the parameter.

Using the implicit function theorem and following these [notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf) or this [paper](https://arxiv.org/pdf/2105.15183.pdf) on the adjoint method, we can auto-differentiate without pushing that machinery through `find_zero`.

The solution, $x^*(p)$, provided by `find_zero` depends on the parameter(s), $p$. Notationally,

$$
f(x^*(p), p) = 0
$$

The implicit function theorem gives conditions guaranteeing the
existence and differentiability of $x^*(p)$. Assuming these hold, taking the gradient (derivative) in $p$ of both sides gives, by the chain rule:

$$
\frac{\partial}{\partial x}f(x^*(p),p)
\frac{\partial}{\partial p}x^*(p) +
\frac{\partial}{\partial p}f(x^*(p),p) I = 0.
$$

Since the partial in $x$ is a scalar quantity, we can divide to solve (provided it is nonzero at $x^*(p)$):

$$
\frac{\partial}{\partial p}x^*(p) =
-\frac{
\frac{\partial}{\partial p}f(x^*(p),p)
}{
\frac{\partial}{\partial x}f(x^*(p),p)
}
$$


For example, using `ForwardDiff`, we have:

```@example roots
f(x, p) = x^2 - p # p a scalar
p = 2
xᵅ = find_zero(f, 1, Order1(), p)
fₓ = ForwardDiff.derivative(x -> f(x, p), xᵅ)
fₚ = ForwardDiff.derivative(p -> f(xᵅ, p), p)
- fₚ / fₓ
```

This problem can be solved analytically, of course, giving $x^*(p) = \sqrt{p}$, so we can easily compare:

```@example roots
ForwardDiff.derivative(sqrt, 2)
```
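
As a check by hand (a worked instance of the formula above, not part of the package): for $f(x,p) = x^2 - p$ the partial derivatives are $\partial f/\partial x = 2x$ and $\partial f/\partial p = -1$, so at the positive root $x^*(p) = \sqrt{p}$

$$
\frac{\partial x^*}{\partial p}
= -\frac{-1}{2x^*(p)}
= \frac{1}{2\sqrt{p}},
$$

which at $p = 2$ is $1/(2\sqrt{2}) \approx 0.3536$, the derivative of $\sqrt{p}$ at that point.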


The use with a vector of parameters is similar, only `derivative` is replaced by `gradient` for the `p` variable:

```@example roots
f(x, p) = x^2 - p[1]*x + p[2]
p = [3.0, 1.0]
x₀ = 1.0
xᵅ = find_zero(f, x₀, Order1(), p)
fₓ = ForwardDiff.derivative(x -> f(x, p), xᵅ)
fₚ = ForwardDiff.gradient(p -> f(xᵅ, p), p)
- fₚ / fₓ
```
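
The same computation can be written out by hand for this example (a worked sketch, assuming the root found is a simple one, so that $2x^* \neq p_1$). With $f(x,p) = x^2 - p_1 x + p_2$, the partial in $x$ is $2x - p_1$ and the gradient in $p$ is $(-x, 1)$, so

$$
\nabla_p x^*(p)
= -\frac{(-x^*,\ 1)}{2x^* - p_1}
= \frac{(x^*,\ -1)}{2x^* - p_1}.
$$

Increasing $p_2$ moves the root in the direction opposite to the sign of $2x^* - p_1$, and the sensitivity grows without bound as the two roots of the quadratic merge.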

The package provides a `ChainRulesCore.rrule` and `ChainRulesCore.frule` implementation that should allow automatic differentiation packages relying on `ChainRulesCore` (e.g., `Zygote`) to differentiate in `p` using the above approach.
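
A minimal sketch of that usage, mirroring the test added in `test/test_chain_rules.jl` in this commit and assuming `Roots` (v2.0.5 or later) and `Zygote` are installed. The function names `f` and `F` are illustrative:

```julia
using Roots, Zygote

f(x, p) = log(x) - p                 # the root is x*(p) = exp(p)
F(p) = find_zero(f, 1, Order1(), p)  # parameter passed positionally

# Gradient of the root with respect to p; the commit's test compares
# this value against exp(1).
dF = first(Zygote.gradient(F, 1))
```

Note that the parameter is passed as the trailing positional argument, since the rules in this commit are defined on `solve(ZP, M, p)` with `p` positional.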



## Potential issues
3 changes: 2 additions & 1 deletion src/Roots.jl
@@ -8,6 +8,7 @@ using Printf
import CommonSolve
import CommonSolve: solve, solve!, init
using Setfield
+import ChainRulesCore

export fzero, fzeros, secant_method

@@ -37,7 +38,7 @@ include("functions.jl")
include("trace.jl")
include("find_zero.jl")
include("hybrid.jl")

+include("chain_rules.jl")

include("Bracketing/bracketing.jl")
include("Bracketing/bisection.jl")
70 changes: 70 additions & 0 deletions src/chain_rules.jl
@@ -0,0 +1,70 @@
# View find_zero as solving `f(x, p) = 0` for `xᵅ(p)`.
# This is implicitly defined. By the implicit function theorem, we have:
# ∇f = 0 => ∂/∂ₓ f(xᵅ, p) ⋅ ∂xᵅ/∂ₚ + ∂/∂ₚf(xᵅ, p) ⋅ I = 0
# or ∂xᵅ/∂ₚ = - ∂/∂ₚf(xᵅ, p) / ∂/∂ₓ f(xᵅ, p)

# does this work?
# It doesn't pass a few of the tests of ChainRulesTestUtils
function ChainRulesCore.frule(
config::ChainRulesCore.RuleConfig,
(Δself, Δp),
::typeof(solve),
ZP::ZeroProblem,
M::AbstractUnivariateZeroMethod,
p;
kwargs...)


xᵅ = solve(ZP, M, p; kwargs...)

F = p -> Callable_Function(M, ZP.F, p)
fₓ(x) = first(F(p)(x))
fₚ(p) = - first(F(p)(xᵅ))

function pushforward_find_zero(fₓ, fₚ, xᵅ, p, Δp)
# is scalar?
o = typeof(p) == eltype(p) ? one(p) : ones(eltype(p), size(p))
fx = ChainRulesCore.frule_via_ad(config,
(ChainRulesCore.NoTangent(), o),
fₓ, xᵅ)[2]
fp = ChainRulesCore.frule_via_ad(config,
(ChainRulesCore.NoTangent(), o),
fₚ, p)[2]

dp = ChainRulesCore.unthunk(Δp)
δ = - (fp * dp) / fx
δ
end

xᵅ, pushforward_find_zero(fₓ, fₚ, xᵅ, p, Δp)

end

## modified from
## https://github.com/gdalle/ImplicitDifferentiation.jl/blob/main/src/implicit_function.jl
function ChainRulesCore.rrule(
rc::ChainRulesCore.RuleConfig,
::typeof(solve),
ZP::ZeroProblem,
M::AbstractUnivariateZeroMethod,
p;
kwargs...)

xᵅ = solve(ZP, M, p; kwargs...)
F = p -> Callable_Function(M, ZP.F, p)
fₓ(x) = first(F(p)(x))
fₚ(p) = - first(F(p)(xᵅ))

pullback_Aᵀ = last ∘ ChainRulesCore.rrule_via_ad(rc, fₓ, xᵅ)[2]
pullback_Bᵀ = last ∘ ChainRulesCore.rrule_via_ad(rc, fₚ, p)[2]

function pullback_find_zero(dy)
dy = ChainRulesCore.unthunk(dy)
u = inv(pullback_Aᵀ(1/dy))
dx = pullback_Bᵀ(u)
return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(),
ChainRulesCore.NoTangent(), dx)
end

return xᵅ, pullback_find_zero
end
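
The formula these rules rely on can be sanity-checked without any AD package. The following is a standalone sketch, not code from this commit: `secant_root` and `central_diff` are hypothetical stand-ins (a plain secant iteration and central finite differences), used only to compare the implicit-function-theorem expression `-fₚ/fₓ` against differentiating through the solver directly.

```julia
# Hypothetical stand-in for a root finder; this is NOT Roots.find_zero.
function secant_root(f, x0, x1; atol=1e-12, maxiter=100)
    f0, f1 = f(x0), f(x1)
    for _ in 1:maxiter
        abs(f1) <= atol && return x1      # residual small enough: done
        denom = f1 - f0
        denom == 0 && break               # secant step undefined: give up
        x2 = x1 - f1 * (x1 - x0) / denom  # secant update
        x0, f0 = x1, f1
        x1, f1 = x2, f(x2)
    end
    return x1
end

# Central finite difference, standing in for an AD derivative.
central_diff(g, t; h=1e-6) = (g(t + h) - g(t - h)) / (2h)

f(x, p) = x^2 - p
root(p) = secant_root(x -> f(x, p), 1.0, 2.0)

p = 2.0
xstar = root(p)
fx = central_diff(x -> f(x, p), xstar)  # ∂f/∂x at (x*, p)
fp = central_diff(q -> f(xstar, q), p)  # ∂f/∂p at (x*, p)

implicit = -fp / fx             # implicit function theorem
direct = central_diff(root, p)  # differentiate through the solver itself
```

Both `implicit` and `direct` approximate $1/(2\sqrt{2})$ for this problem. The implicit route needs only one solve plus two derivative evaluations at the root, which is the saving the `frule` and `rrule` above aim for.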
13 changes: 9 additions & 4 deletions src/find_zero.jl
@@ -1,6 +1,6 @@
"""
-find_zero(f, x0, M, [N::AbstractBracketingMethod]; kwargs...)
+find_zero(f, x0, M, [N::AbstractBracketingMethod], p′=nothing; kwargs...)
Interface to one of several methods for finding zeros of a univariate function, e.g. solving ``f(x)=0``.
@@ -206,13 +206,14 @@ the algorithm.
function find_zero(
f,
x0,
-M::AbstractUnivariateZeroMethod;
+M::AbstractUnivariateZeroMethod,
+p′=nothing;
p=nothing,
verbose=false,
tracks::AbstractTracks=NullTracks(),
kwargs...,
)
-xstar = solve(ZeroProblem(f, x0), M; p=p, verbose=verbose, tracks=tracks, kwargs...)
+xstar = solve(ZeroProblem(f, x0), M, p′ === nothing ? p : p′; verbose=verbose, tracks=tracks, kwargs...)

isnan(xstar) && throw(ConvergenceFailed("Algorithm failed to converge"))

@@ -227,8 +228,12 @@ function find_zero_default_method(x0)
T = eltype(float.(_extrema(x0)))
T <: Union{Float16, Float32, Float64} ? Bisection() : A42()
end

find_zero(f, x0; kwargs...) = find_zero(f, x0, find_zero_default_method(x0); kwargs...)

+find_zero(f, x0, p; kwargs...) = find_zero(f, x0, find_zero_default_method(x0), p; kwargs...)


## ---------------

## Create an Iterator interface
@@ -279,7 +284,7 @@ function init(
tracks=NullTracks(),
kwargs...,
)
-F = Callable_Function(M, 𝑭𝑿.F, p === nothing ? p′ : p)
+F = Callable_Function(M, 𝑭𝑿.F, something(p′, p, missing))
state = init_state(M, F, 𝑭𝑿.x₀)
options = init_options(M, state; kwargs...)
l = Tracks(verbose, tracks, state)
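
A short usage sketch of the two ways to pass a parameter after this change, assuming Roots v2.0.5 or later is installed. The function `f` and the values are illustrative, taken from the documentation example added in this commit:

```julia
using Roots

f(x, p) = x^2 - p

# Parameter as a trailing positional argument (the form added in this commit).
x_positional = find_zero(f, 1, Order1(), 2)

# Parameter as the `p` keyword, which the signature still accepts.
x_keyword = find_zero(f, 1, Order1(); p=2)
```

Both calls should find the same root. When both forms are supplied, the positional value is the one used, per the `p′ === nothing ? p : p′` expression in `find_zero`.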
5 changes: 3 additions & 2 deletions src/functions.jl
@@ -12,8 +12,9 @@ struct Callable_Function{Single,Tup,F,P}
Single = Val{fn_argout(M)}
Tup = Val{isa(f, Tuple)}
F = typeof(f)
-P = typeof(p)
-new{Single,Tup,F,P}(f, p)
+p′ = ismissing(p) ? nothing : p
+P′ = typeof(p′)
+new{Single,Tup,F,P′}(f, p′)
end
end
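
The `something(p′, p, missing)` call in `init` and the `ismissing` check in this constructor can be read together as a two-step fallback. A minimal sketch using only Base; `resolve_parameter` is a hypothetical helper written for illustration, not a function in Roots:

```julia
# Hypothetical helper mirroring the fallback in the diff:
# positional p′ wins, then keyword p, then "no parameter".
function resolve_parameter(p′, p)
    q = something(p′, p, missing)      # first argument that is not `nothing`
    return ismissing(q) ? nothing : q  # map the sentinel back to `nothing`
end

a = resolve_parameter(2, nothing)        # positional value is used
b = resolve_parameter(nothing, 3)        # keyword value is used
c = resolve_parameter(2, 3)              # positional takes precedence
d = resolve_parameter(nothing, nothing)  # no parameter: stays `nothing`
```

The `missing` sentinel matters because `something` throws an `ArgumentError` when every argument is `nothing`; the trailing `missing` gives it a value to return in the no-parameter case.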

10 changes: 6 additions & 4 deletions src/hybrid.jl
@@ -5,13 +5,14 @@
function init(
𝑭𝑿::ZeroProblem,
M::AbstractNonBracketingMethod,
-N::AbstractBracketingMethod;
+N::AbstractBracketingMethod,
+p′=nothing;
p=nothing,
verbose::Bool=false,
tracks=NullTracks(),
kwargs...,
)
-F = Callable_Function(M, 𝑭𝑿.F, p)
+F = Callable_Function(M, 𝑭𝑿.F, something(p′, p, missing))
state = init_state(M, F, 𝑭𝑿.x₀)
options = init_options(M, state; kwargs...)
l = Tracks(verbose, tracks, state)
@@ -167,10 +168,11 @@ function find_zero(
fs,
x0,
M::AbstractUnivariateZeroMethod,
-N::AbstractBracketingMethod;
+N::AbstractBracketingMethod,
+p′=nothing;
verbose=false,
kwargs...,
)
𝐏 = ZeroProblem(fs, x0)
-solve!(init(𝐏, M, N; verbose=verbose, kwargs...), verbose=verbose)
+solve!(init(𝐏, M, N, p′; verbose=verbose, kwargs...), verbose=verbose)
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -16,6 +16,7 @@ include("./test_simple.jl")
include("./test_find_zeros.jl")
include("./test_fzero.jl")
include("./test_newton.jl")
+include("./test_chain_rules.jl")
include("./test_simple.jl")

include("./test_composable.jl")
26 changes: 26 additions & 0 deletions test/test_chain_rules.jl
@@ -0,0 +1,26 @@
using Roots
using Zygote
using Test

# issue #325 add frule, rrule

@testset "Test rrule" begin

# single function
f(x, p) = log(x) - p
F(p) = find_zero(f, 1, Order1(), p)
@test first(Zygote.gradient(F, 1)) ≈ exp(1)

g(x, p) = x^2 - p[1]*x - p[2]
G(p) = find_zero(g, 1, Order1(), p)
@test first(Zygote.gradient(G, [0,4])) ≈ [1/2, 1/4]

fx(x,p) = 1/x
F2(p) = find_zero((f,fx), 1, Roots.Newton(), p)
@test first(Zygote.gradient(F2, 1)) ≈ exp(1)

gp(x, p) = 2x - p[1]
G2(p) = find_zero((g, gp), 1, Roots.Newton(), p)
@test first(Zygote.gradient(G2, [0,4])) ≈ [1/2, 1/4]

end

2 comments on commit 1afc30c

@jverzani
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/69167

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.0.5 -m "<description of version>" 1afc30c3daca427aae5843ac1ce3ecf56f43d4d7
git push origin v2.0.5
