Skip to content

Commit

Permalink
Merge pull request #98 from ikirill/gh-93
Browse files Browse the repository at this point in the history
Fix Alefeld-Potra-Shi to cache function values properly.
  • Loading branch information
jverzani authored Jan 4, 2018
2 parents 4769213 + 665508b commit ec3880e
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 87 deletions.
166 changes: 90 additions & 76 deletions src/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,23 +317,23 @@ By John Travers
"""
function a42(f, a, b;
xtol=zero(float(a)),
maxeval::Int=15,
verbose::Bool=false
)
xtol=zero(float(a)),
maxeval::Int=15,
verbose::Bool=false)
if a > b
a,b = b,a
end
fa, fb = f(a), f(b)

if a >= b || sign(f(a))*sign(f(b)) >= 0
if a >= b || sign(fa)*sign(fb) >= 0
error("on input a < b and f(a)f(b) < 0 must both hold")
end
if xtol < 0.0
error("tolerance must be >= 0.0")
end

c = secant(f, a, b)
a42a(f, a, b, c,
c, fc = secant(f, a, fa, b, fb)
a42a(f, a, fa, b, fb, c, fc,
xtol=xtol, maxeval=maxeval, verbose=verbose)
end

Expand All @@ -345,51 +345,61 @@ where `c` is passed in.
Solve f(x) = 0 over bracketing interval [a,b] starting at c, with a < c < b
"""
function a42a(f, a, b, c=(0.5)*(a+b);
xtol=zero(float(a)),
maxeval::Int=15,
verbose::Bool=false
)
a42a(f, a, b, c=(a+b)/2; args...) = a42a(f, a, f(a), b, f(b), c, f(c); args...)

function a42a(f, a, fa, b, fb, c, fc;
xtol=zero(float(a)),
maxeval::Int=15,
verbose::Bool=false)

try
# re-bracket and check termination
a, b, d = bracket(f, a, b, c, xtol)
ee = d
a, fa, b, fb, d, fd = bracket(f, a, fa, b, fb, c, fc, xtol)
ee, fee = d, fd
for n = 2:maxeval
# use either a cubic (if possible) or quadratic interpolation
if n > 2 && distinct(f, a, b, d, ee)
c = ipzero(f, a, b, d, ee)
if n > 2 && distinct(f, a, fa, b, fb, d, fd, ee, fee)
c, fc = ipzero(f, a, fa, b, fb, d, fd, ee, fee)
else
c = newton_quadratic(f, a, b, d, 2)
c, fc = newton_quadratic(f, a, fa, b, fb, d, fd, 2)
end
# re-bracket and check termination
ab, bb, db = bracket(f, a, b, c, xtol)
eb = d
ab, fab, bb, fbb, db, fdb = bracket(f, a, fa, b, fb, c, fc, xtol)
eb, feb = d, fd
# use another cubic (if possible) or quadratic interpolation
if distinct(f, ab, bb, db, eb)
cb = ipzero(f, ab, bb, db, eb)
if distinct(f, ab, fab, bb, fbb, db, fdb, eb, feb)
cb, fcb = ipzero(f, ab, fab, bb, fbb, db, fdb, eb, feb)
else
cb = newton_quadratic(f, ab, bb, db, 3)
cb, fcb = newton_quadratic(f, ab, fab, bb, fbb, db, fdb, 3)
end
# re-bracket and check termination
ab, bb, db = bracket(f, ab, bb, cb, xtol)
ab, fab, bb, fbb, db, fdb = bracket(f, ab, fab, bb, fbb, cb, fcb, xtol)
# double length secant step; if we fail, use bisection
u = abs(f(ab)) < abs(f(bb)) ? ab : bb
cb = u - 2*f(u)/(f(bb) - f(ab))*(bb - ab)
ch = abs(cb - u) > (bb - ab)/2 ? ab + (bb - ab)/2 : cb
if abs(fab) < abs(fbb)
u, fu = ab, fab
else
u, fu = bb, fbb
end
# u = abs(fab) < abs(fbb) ? ab : bb
cb = u - 2*fu/(fbb - fab)*(bb - ab)
fcb = f(cb)
if abs(cb - u) > (bb - ab)/2
ch, fch = ab+(bb-ab)/2, f(ab+(bb-ab)/2)
else
ch, fch = cb, fcb
end
# ch = abs(cb - u) > (bb - ab)/2 ? ab + (bb - ab)/2 : cb
# re-bracket and check termination
ah, bh, dh = bracket(f, ab, bb, ch, xtol)
ah, fah, bh, fbh, dh, fdh = bracket(f, ab, fab, bb, fbb, ch, fch, xtol)
# if not converging fast enough bracket again on a bisection
if bh - ah < 0.5*(b - a)
a = ah
b = bh
d = dh
ee = db
a, fa = ah, fah
b, fb = bh, fbh
d, fd = dh, fdh
ee, fee = db, fdb
else
ee = dh
a, b, d = bracket(f, ah, bh, ah + (bh - ah)/2,
xtol)
ee, fee = dh, fdh
a, fa, b, fb, d, fd = bracket(f, ah, fah, bh, fbh, ah + (bh - ah)/2, f(ah+(bh-ah)/2), xtol)
end

verbose && println(@sprintf("a=%18.15f, n=%s", float(a), n))
Expand Down Expand Up @@ -441,87 +451,90 @@ end
# root.
#
# based on algorithm on page 341 of [1]
function bracket(f, a, b, c, tol)

fa = f(a)
fb = f(b)

function bracket(f, a, fa, b, fb, c, fc, tol)
# @assert fa == f(a)
# @assert fb == f(b)
# @assert fc == f(c)
if !(a <= c <= b)
error("c must be in (a,b)")
end
delta = 0.7*tole(a, b, fa, fb, tol)
if b - a <= 4delta
c = (a + b)/2
fc = f(c)
elseif c <= a + 2delta
c = a + 2delta
fc = f(c)
elseif c >= b - 2delta
c = b - 2delta
fc = f(c)
end
fc = f(c)
if fc == 0
throw(StateConverged(c))
throw(Roots.StateConverged(c))
elseif sign(fa)*sign(fc) < 0
aa = a
bb = c
db = b
aa, faa = a, fa
bb, fbb = c, fc
db, fdb = b, fb
else
aa = c
bb = b
db = a
aa, faa = c, fc
bb, fbb = b, fb
db, fdb = a, fa
end
faa = f(aa)
fbb = f(bb)
if bb - aa < 2*tole(aa, bb, faa, fbb, tol)
x0 = abs(faa) < abs(fbb) ? aa : bb
throw(StateConverged(x0))
throw(Roots.StateConverged(x0))
end
aa, bb, db
aa, faa, bb, fbb, db, fdb
end


# take a secant step, if the resulting guess is very close to a or b, then
# use bisection instead
function secant{T}(f, a::T, b)
c = a - f(a)/(f(b) - f(a))*(b - a)
function secant{T}(f, a::T, fa, b, fb)
# @assert fa == f(a)
# @assert fb == f(b)
c = a - fa/(fb - fa)*(b - a)
tol = eps(T)*5
if isnan(c) || c <= a + abs(a)*tol || c >= b - abs(b)*tol
return a + (b - a)/2
return a + (b - a)/2, f(a+(b-a)/2)
end
return c
return c, f(c)
end


# approximate zero of f using quadratic interpolation
# if the new guess is outside [a, b] we use a secant step instead
# based on algorithm on page 330 of [1]
function newton_quadratic(f, a, b, d, k::Int)
fa = f(a)
fb = f(b)
fd = f(d)
function newton_quadratic(f, a, fa, b, fb, d, fd, k::Int)
# @assert fa == f(a)
# @assert fb == f(b)
# @assert fd == f(d)
B = (fb - fa)/(b - a)
A = ((fd - fb)/(d - b) - B)/(d - a)
if A == 0
return secant(f, a, b)
return secant(f, a, fa, b, fb)
end
r = A*fa > 0 ? a : b
for i = 1:k
r -= (fa + (B + A*(r - b))*(r - a))/(B + A*(2*r - a - b))
end
if isnan(r) || (r <= a || r >= b)
r = secant(f, a, b)
r, fr = secant(f, a, fa, b, fb)
else
fr = f(r)
end
return r
return r, fr
end


# approximate zero of f using inverse cubic interpolation
# if the new guess is outside [a, b] we use a quadratic step instead
# based on algorithm on page 333 of [1]
function ipzero(f, a, b, c, d)
fa = f(a)
fb = f(b)
fc = f(c)
fd = f(d)
function ipzero(f, a, fa, b, fb, c, fc, d, fd)
# @assert fa == f(a)
# @assert fb == f(b)
# @assert fc == f(c)
# @assert fd == f(d)
Q11 = (c - d)*fc/(fd - fc)
Q21 = (b - c)*fb/(fc - fb)
Q31 = (a - b)*fa/(fb - fa)
Expand All @@ -532,32 +545,33 @@ function ipzero(f, a, b, c, d)
D32 = (D31 - Q21)*fc/(fc - fa)
Q33 = (D32 - Q22)*fa/(fd - fa)
c = a + (Q31 + Q32 + Q33)
fc = f(c)
if (c <= a) || (c >= b)
c = newton_quadratic(f, a, b, d, 3)
c, fc = newton_quadratic(f, a, fa, b, fb, d, fd, 3)
end
return c
return c, fc
end


# floating point comparison function
function almost_equal(x, y)
# FIXME This should be eps(T), why is this Float64?
# FIXME Also, realmin is 1e-308, it's too tiny to be useful here.
const min_diff = realmin(Float64)*32
abs(x - y) < min_diff
end


# check that all interpolation values are distinct
function distinct(f, a, b, d, e)
f1 = f(a)
f2 = f(b)
f3 = f(d)
f4 = f(e)
function distinct(f, a, f1, b, f2, d, f3, e, f4)
# @assert f1 == f(a)
# @assert f2 == f(b)
# @assert f3 == f(d)
# @assert f4 == f(e)
!(almost_equal(f1, f2) || almost_equal(f1, f3) || almost_equal(f1, f4) ||
almost_equal(f2, f3) || almost_equal(f2, f4) || almost_equal(f3, f4))
end



## ----------------------------

"""
Expand Down
23 changes: 12 additions & 11 deletions test/RootTesting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ known_functions = Func[]

## Construct a function object, and check root brackets
macro Func(name)
quote
f = Func($name, val, bracket, params)
for p in params
b = bracket(p)
@assert val(p, b[1]) * f.val(p, b[2]) < 0 "Invalid bracket"
@gensym f p b
esc(quote
$f = Func($name, val, bracket, params)
for $p in params
$b = bracket($p)
@assert val($p, $b[1]) * $f.val($p, $b[2]) < 0 "Invalid bracket"
end
push!(known_functions, f)
f
end
push!(known_functions, $f)
$f
end)
end

func1 = let
Expand Down Expand Up @@ -200,7 +201,7 @@ end
function run_benchmark_tests()
@printf "%s\n" run_tests((f,b) -> Roots.a42(f, b...), name="a42")
@printf "%s\n" run_tests((f,b) -> find_zero(f, b, Bisection()), name="Bisection")


for m in [Order0(), Order1(), Order2(), Order5(), Order8(), Order16()]
@printf "%s\n" run_tests((f, b) -> find_zero(f, mean(b), m), name="$m")
Expand All @@ -211,8 +212,8 @@ function run_benchmark_tests()
@printf "%s\n" run_tests((f, b) -> halley(f, mean(b)), name="halley")

println("---- using BigFloat ----")
@printf "%s\n" run_tests((f,b) -> find_zero(f, big(b), Bisection()), name="a42 (no bisection with Big values)")

@printf "%s\n" run_tests((f,b) -> find_zero(f, big(b), Bisection()), name="a42 (no bisection with Big values)")

for m in [Order0(), Order1(), Order2(), Order5(), Order8(), Order16()]
@printf "%s\n" run_tests((f, b) -> find_zero(f, mean(big(b)), m), name="$m/BigFloat")
Expand Down

0 comments on commit ec3880e

Please sign in to comment.