Skip to content

Commit

Permalink
refactor: use process_SciMLProblem in jumpsystem.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 16, 2024
1 parent 03e294f commit 1861727
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2914,7 +2914,7 @@ function Base.eltype(::Type{<:TreeIterator{ModelingToolkit.AbstractSystem}})
end

function check_array_equations_unknowns(eqs, dvs)
if any(eq -> Symbolics.isarraysymbolic(eq.lhs), eqs)
if any(eq -> eq isa Equation && Symbolics.isarraysymbolic(eq.lhs), eqs)
throw(ArgumentError("The system has array equations. Call `structural_simplify` to handle such equations or scalarize them manually."))
end
if any(x -> Symbolics.isarraysymbolic(x), dvs)
Expand Down
41 changes: 6 additions & 35 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,20 +348,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end
dvs = unknowns(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
defs = mergedefaults(defs, u0map, dvs)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
Expand Down Expand Up @@ -399,16 +387,9 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
end
dvs = unknowns(sys)
ps = parameters(sys)
defs = defaults(sys)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
# identity function to make syms works
quote
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
Expand Down Expand Up @@ -454,19 +435,9 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end
dvs = unknowns(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
defs = mergedefaults(defs, u0map, dvs)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

Expand Down
12 changes: 12 additions & 0 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,18 @@ function get_temporary_value(p)
end
end

"""
$(TYPEDEF)
A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in
case constructing a SciMLFunction is not required.
"""
struct EmptySciMLFunction end

function EmptySciMLFunction(args...; kwargs...)
return nothing
end

"""
$(TYPEDSIGNATURES)
Expand Down
11 changes: 10 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,15 @@ function collect_constants!(constants, expr::Symbolic)
end
end

function collect_constants!(constants, expr::Union{ConstantRateJump, VariableRateJump})
collect_constants!(constants, expr.rate)
collect_constants!(constants, expr.affect!)
end

function collect_constants!(constants, ::MassActionJump)
return constants
end

"""
Replace symbolic constants with their literal values
"""
Expand Down Expand Up @@ -667,7 +676,7 @@ end

function get_cmap(sys, exprs = nothing)
#Inject substitutions for constants => values
cs = collect_constants([get_eqs(sys); get_observed(sys)]) #ctrls? what else?
cs = collect_constants([collect(get_eqs(sys)); get_observed(sys)]) #ctrls? what else?
if !empty_substitutions(sys)
cs = [cs; collect_constants(get_substitutions(sys).subs)]
end
Expand Down

0 comments on commit 1861727

Please sign in to comment.