Skip to content

Commit

Permalink
Merge pull request #3151 from AayushSabharwal/as/homotopy-rational-poly
Browse files Browse the repository at this point in the history
feat: support rational functions in `HomotopyContinuationProblem`
  • Loading branch information
ChrisRackauckas authored Oct 26, 2024
2 parents 7ff50d1 + 7aae63d commit 1266976
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ REPL = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.56.1"
SciMLBase = "2.57.1"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
100 changes: 87 additions & 13 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module MTKHomotopyContinuationExt

using ModelingToolkit
using ModelingToolkit.SciMLBase
using ModelingToolkit.Symbolics: unwrap, symtype
using ModelingToolkit.Symbolics: unwrap, symtype, BasicSymbolic, simplify_fractions
using ModelingToolkit.SymbolicIndexingInterface
using ModelingToolkit.DocStringExtensions
using HomotopyContinuation
Expand All @@ -27,7 +27,7 @@ function is_polynomial(x, wrt)
contains_variable(x, wrt) || return true
any(isequal(x), wrt) && return true

if operation(x) in (*, +, -)
if operation(x) in (*, +, -, /)
return all(y -> is_polynomial(y, wrt), arguments(x))
end
if operation(x) == (^)
Expand Down Expand Up @@ -57,6 +57,57 @@ end
"""
$(TYPEDSIGNATURES)
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
express `x` as a single rational function with polynomial `num` and denominator `den`.
Return `(num, den)`.
"""
function handle_rational_polynomials(x, wrt)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return x, 1
iscall(x) || return x, 1
contains_variable(x, wrt) || return x, 1
any(isequal(x), wrt) && return x, 1

# simplify_fractions cancels out some common factors
# and expands (a / b)^c to a^c / b^c, so we only need
# to handle these cases
x = simplify_fractions(x)
op = operation(x)
args = arguments(x)

if op == /
# numerator and denominator are trivial
num, den = args
# but also search for rational functions in numerator
n, d = handle_rational_polynomials(num, wrt)
num, den = n, den * d
elseif op == +
num = 0
den = 1

# we don't need to do common denominator
# because we don't care about cases where denominator
# is zero. The expression is zero when all the numerators
# are zero.
for arg in args
n, d = handle_rational_polynomials(arg, wrt)
num += n
den *= d
end
else
return x, 1
end
# if the denominator isn't a polynomial in `wrt`, better to not include it
# to reduce the size of the gcd polynomial
if !contains_variable(den, wrt)
return num / den, 1
end
return num, den
end

"""
$(TYPEDSIGNATURES)
Convert `expr` from a symbolics expression to one that uses `HomotopyContinuation.ModelKit`.
"""
function symbolics_to_hc(expr)
Expand Down Expand Up @@ -139,51 +190,74 @@ function MTK.HomotopyContinuationProblem(
dvs = unknowns(sys)
eqs = equations(sys)

for eq in eqs
denoms = []
eqs2 = map(eqs) do eq
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
end
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)
push!(denoms, den)
return 0 ~ num
end

nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap;
sys2 = MTK.@set sys.eqs = eqs2

nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
jac = true, eval_expression, eval_module)

denominator = MTK.build_explicit_observed_function(sys, denoms)

hvars = symbolics_to_hc.(dvs)
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))

obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)

return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn)
return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn)
end

"""
$(TYPEDSIGNATURES)
Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
uses `HomotopyContinuation.jl`. All keyword arguments are forwarded to
`HomotopyContinuation.solve`. The original solution as returned by `HomotopyContinuation.jl`
will be available in the `.original` field of the returned `NonlinearSolution`.
uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are
forwarded to `HomotopyContinuation.solve`. The original solution as returned by
`HomotopyContinuation.jl` will be available in the `.original` field of the returned
`NonlinearSolution`.
All keyword arguments have their default values in HomotopyContinuation.jl, except
`show_progress` which defaults to `false`.
Extra keyword arguments:
- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
the denominator to be below `denominator_abstol` will be discarded.
"""
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
alg = nothing; show_progress = false, kwargs...)
alg = nothing; show_progress = false, denominator_abstol = 1e-8, kwargs...)
sol = HomotopyContinuation.solve(
prob.homotopy_continuation_system; show_progress, kwargs...)
realsols = HomotopyContinuation.results(sol; only_real = true)
if isempty(realsols)
u = state_values(prob)
resid = prob.homotopy_continuation_system(u)
retcode = SciMLBase.ReturnCode.ConvergenceFailure
else
T = eltype(state_values(prob))
distance, idx = findmin(realsols) do result
if any(<=(denominator_abstol),
prob.denominator(real.(result.solution), parameter_values(prob)))
return T(Inf)
end
norm(result.solution - state_values(prob))
end
u = real.(realsols[idx].solution)
resid = prob.homotopy_continuation_system(u)
retcode = SciMLBase.ReturnCode.Success
# all roots cause denominator to be zero
if isinf(distance)
u = state_values(prob)
retcode = SciMLBase.ReturnCode.Infeasible
else
u = real.(realsols[idx].solution)
retcode = SciMLBase.ReturnCode.Success
end
end
resid = prob.homotopy_continuation_system(u)

return SciMLBase.build_solution(
prob, :HomotopyContinuation, u, resid; retcode, original = sol)
Expand Down
8 changes: 7 additions & 1 deletion src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
create and solve.
"""
struct HomotopyContinuationProblem{uType, H, O} <:
struct HomotopyContinuationProblem{uType, H, D, O} <:
SciMLBase.AbstractNonlinearProblem{uType, true}
"""
The initial values of states in the system. If there are multiple real roots of
Expand All @@ -586,6 +586,12 @@ struct HomotopyContinuationProblem{uType, H, O} <:
"""
homotopy_continuation_system::H
"""
A function with signature `(u, p) -> resid`. In case of rational functions, this
is used to rule out roots of the system which would cause the denominator to be
zero.
"""
denominator::D
"""
The `NonlinearSystem` used to create this problem. Used for symbolic indexing.
"""
sys::NonlinearSystem
Expand Down
44 changes: 42 additions & 2 deletions test/extensions/homotopy_continuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,47 @@ end
@mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
@test_warn ["Exponent", "unknowns"] @test_throws "not a polynomial" HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([((x^2) / (x + 3))^2 + x ~ 0])
@test_warn ["Base", "not a polynomial", "Unrecognized operation", "/"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
@test_warn ["Unrecognized", "sin"] @test_throws "not a polynomial" HomotopyContinuationProblem(
sys, [])
end

@testset "Rational functions" begin
@variables x=2.0 y=2.0
@parameters n = 4
@mtkbuild sys = NonlinearSystem([
0 ~ (x^2 - n * x + n) * (x - 1) / (x - 2) / (x - 3)
])
prob = HomotopyContinuationProblem(sys, [])
sol = solve(prob; threading = false)
@test sol[x] 1.0
p = parameter_values(prob)
for invalid in [2.0, 3.0]
@test prob.denominator([invalid], p)[1] <= 1e-8
end

@named sys = NonlinearSystem(
[
0 ~ (x - 2) / (x - 4) + ((x - 3) / (y - 7)) / ((x^2 - 4x + y) / (x - 2.5)),
0 ~ ((y - 3) / (y - 4)) * (n / (y - 5)) + ((x - 1.5) / (x - 5.5))^2
],
[x, y],
[n])
sys = complete(sys)
prob = HomotopyContinuationProblem(sys, [])
sol = solve(prob; threading = false)
disallowed_x = [4, 5.5]
disallowed_y = [7, 5, 4]
@test all(!isapprox(sol[x]; atol = 1e-8), disallowed_x)
@test all(!isapprox(sol[y]; atol = 1e-8), disallowed_y)
@test sol[x^2 - 4x + y] >= 1e-8

p = parameter_values(prob)
for val in disallowed_x
@test any(<=(1e-8), prob.denominator([val, 2.0], p))
end
for val in disallowed_y
@test any(<=(1e-8), prob.denominator([2.0, val], p))
end
@test prob.denominator([2.0, 4.0], p)[1] <= 1e-8
end

0 comments on commit 1266976

Please sign in to comment.