Skip to content

Commit

Permalink
Merge pull request #3123 from isaacsas/dispatch_collect_vars
Browse files Browse the repository at this point in the history
add trait and dispatch for collect_vars!
  • Loading branch information
ChrisRackauckas authored Oct 15, 2024
2 parents 28a5af3 + a8c0930 commit 1f53f6a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 26 deletions.
9 changes: 3 additions & 6 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,7 @@ function ODESystem(eqs, iv; kwargs...)
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
for eq in eqs
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
collect_vars!(allunknowns, ps, eq.lhs, iv)
collect_vars!(allunknowns, ps, eq.rhs, iv)
collect_vars!(allunknowns, ps, eq, iv)
if isdiffeq(eq)
diffvar, _ = var_from_nested_derivative(eq.lhs)
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
Expand All @@ -337,11 +336,9 @@ function ODESystem(eqs, iv; kwargs...)
end
for eq in get(kwargs, :parameter_dependencies, Equation[])
if eq isa Pair
collect_vars!(allunknowns, ps, eq[1], iv)
collect_vars!(allunknowns, ps, eq[2], iv)
collect_vars!(allunknowns, ps, eq, iv)
else
collect_vars!(allunknowns, ps, eq.lhs, iv)
collect_vars!(allunknowns, ps, eq.rhs, iv)
collect_vars!(allunknowns, ps, eq, iv)
end
end
for ssys in get(kwargs, :systems, ODESystem[])
Expand Down
9 changes: 3 additions & 6 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,7 @@ function DiscreteSystem(eqs, iv; kwargs...)
ps = OrderedSet()
iv = value(iv)
for eq in eqs
collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift)
collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift)
collect_vars!(allunknowns, ps, eq, iv; op = Shift)
if iscall(eq.lhs) && operation(eq.lhs) isa Shift
isequal(iv, operation(eq.lhs).t) ||
throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
Expand All @@ -187,11 +186,9 @@ function DiscreteSystem(eqs, iv; kwargs...)
end
for eq in get(kwargs, :parameter_dependencies, Equation[])
if eq isa Pair
collect_vars!(allunknowns, ps, eq[1], iv)
collect_vars!(allunknowns, ps, eq[2], iv)
collect_vars!(allunknowns, ps, eq, iv)
else
collect_vars!(allunknowns, ps, eq.lhs, iv)
collect_vars!(allunknowns, ps, eq.rhs, iv)
collect_vars!(allunknowns, ps, eq, iv)
end
end
new_ps = OrderedSet()
Expand Down
9 changes: 3 additions & 6 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,13 @@ function NonlinearSystem(eqs; kwargs...)
allunknowns = OrderedSet()
ps = OrderedSet()
for eq in eqs
collect_vars!(allunknowns, ps, eq.lhs, nothing)
collect_vars!(allunknowns, ps, eq.rhs, nothing)
collect_vars!(allunknowns, ps, eq, nothing)
end
for eq in get(kwargs, :parameter_dependencies, Equation[])
if eq isa Pair
collect_vars!(allunknowns, ps, eq[1], nothing)
collect_vars!(allunknowns, ps, eq[2], nothing)
collect_vars!(allunknowns, ps, eq, nothing)
else
collect_vars!(allunknowns, ps, eq.lhs, nothing)
collect_vars!(allunknowns, ps, eq.rhs, nothing)
collect_vars!(allunknowns, ps, eq, nothing)
end
end
new_ps = OrderedSet()
Expand Down
38 changes: 30 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,20 +492,19 @@ recursively searches through all subsystems of `sys`, increasing the depth if it
function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential)
if has_eqs(sys)
for eq in get_eqs(sys)
eq isa Equation || continue
eq.lhs isa Union{Symbolic, Number} || continue
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
eqtype_supports_collect_vars(eq) || continue
if eq isa Equation
eq.lhs isa Union{Symbolic, Number} || continue
end
collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
end
if has_parameter_dependencies(sys)
for eq in get_parameter_dependencies(sys)
if eq isa Pair
collect_vars!(unknowns, parameters, eq[1], iv; depth, op)
collect_vars!(unknowns, parameters, eq[2], iv; depth, op)
collect_vars!(unknowns, parameters, eq, iv; depth, op)
else
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
end
end
Expand All @@ -529,6 +528,29 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Different
return nothing
end

"""
$(TYPEDSIGNATURES)
Indicate whether the given equation type (Equation, Pair, etc) supports `collect_vars!`.
Can be dispatched by higher-level libraries to indicate support.
"""
eqtype_supports_collect_vars(eq) = false
eqtype_supports_collect_vars(eq::Equation) = true
eqtype_supports_collect_vars(eq::Pair) = true

function collect_vars!(unknowns, parameters, eq::Equation, iv;
depth = 0, op = Differential)
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
return nothing
end

function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differential)
collect_vars!(unknowns, parameters, p[1], iv; depth, op)
collect_vars!(unknowns, parameters, p[2], iv; depth, op)
return nothing
end

function collect_var!(unknowns, parameters, var, iv; depth = 0)
isequal(var, iv) && return nothing
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
Expand Down

0 comments on commit 1f53f6a

Please sign in to comment.