Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

drop scalarizing #1052

Merged
merged 14 commits into from
Sep 21, 2024
7 changes: 3 additions & 4 deletions src/network_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,16 +657,15 @@ function cache_conservationlaw_eqs!(rn::ReactionSystem, N::AbstractMatrix, col_o
indepspecs = sts[indepidxs]
depidxs = col_order[(r + 1):end]
depspecs = sts[depidxs]
constants = MT.unwrap.(MT.scalarize(only(
@parameters $(CONSERVED_CONSTANT_SYMBOL)[1:nullity] [conserved = true])))
constants = MT.unwrap(only(
@parameters $(CONSERVED_CONSTANT_SYMBOL)[1:nullity] [conserved = true]))

conservedeqs = Equation[]
constantdefs = Equation[]
for (i, depidx) in enumerate(depidxs)
scaleby = (N[i, depidx] != 1) ? N[i, depidx] : one(eltype(N))
(scaleby != 0) || error("Error, found a zero in the conservation law matrix where "
*
"one was not expected.")
* "one was not expected.")
coefs = @view N[i, indepidxs]
terms = sum(p -> p[1] / scaleby * p[2], zip(coefs, indepspecs))
eq = depspecs[i] ~ constants[i] - terms
Expand Down
27 changes: 21 additions & 6 deletions src/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,11 @@ function ReactionSystem(eqs, iv, unknowns, ps;
sivs′ = if spatial_ivs === nothing
Vector{typeof(iv′)}()
else
value.(MT.scalarize(spatial_ivs))
value.(spatial_ivs)
end
unknowns′ = sort!(value.(MT.scalarize(unknowns)), by = !isspecies)
unknowns′ = sort!(value.(unknowns), by = !isspecies)
spcs = filter(isspecies, unknowns′)
ps′ = value.(MT.scalarize(ps))
ps′ = value.(ps)

# Checks that no (by Catalyst) forbidden symbols are used.
allsyms = Iterators.flatten((ps′, unknowns′))
Expand Down Expand Up @@ -467,7 +467,7 @@ end
# Two-argument constructor (reactions/equations and time variable).
# Calls the `make_ReactionSystem_internal`, which in turn calls the four-argument constructor.
function ReactionSystem(rxs::Vector, iv = Catalyst.DEFAULT_IV; kwargs...)
make_ReactionSystem_internal(rxs, iv, Vector{Num}(), Vector{Num}(); kwargs...)
make_ReactionSystem_internal(rxs, iv, [], []; kwargs...)
end

# One-argument constructor. Creates an emtoy `ReactionSystem` from a time independent variable only.
Expand All @@ -494,7 +494,7 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
t = value(iv)
ivs = Set([t])
if (spatial_ivs !== nothing)
for siv in (MT.scalarize(spatial_ivs))
for siv in (spatial_ivs)
push!(ivs, value(siv))
end
end
Expand Down Expand Up @@ -548,7 +548,22 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;

# Converts the found unknowns and parameters to vectors.
usv = collect(us)
psv = collect(ps)

new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
par = arguments(p)[begin]
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
all(par[i] in ps for i in eachindex(par))
push!(new_ps, par)
else
push!(new_ps, p)
end
else
push!(new_ps, p)
end
end
psv = collect(new_ps)

# Passes the processed input into the next `ReactionSystem` call.
ReactionSystem(fulleqs, t, usv, psv; spatial_ivs, continuous_events,
Expand Down
23 changes: 9 additions & 14 deletions test/reactionsystem_core/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ rxs = [Reaction(k[1], nothing, [A]), # 0 -> A
Reaction(k[19] * t, [A], [B]), # A -> B with non constant rate.
Reaction(k[20] * t * A, [B, C], [D], [2, 1], [2]), # 2A +B -> 2C with non constant rate.
]
@named rs = ReactionSystem(rxs, t, [A, B, C, D], k)
@named rs = ReactionSystem(rxs, t, [A, B, C, D], [k])
rs = complete(rs)
odesys = complete(convert(ODESystem, rs))
sdesys = complete(convert(SDESystem, rs))
Expand Down Expand Up @@ -109,11 +109,12 @@ end

# Defaults test.
let
def_p = [ki => float(i) for (i, ki) in enumerate(k)]
kvals = Float64.(1:length(k))
def_p = [k => kvals]
def_u0 = [A => 0.5, B => 1.0, C => 1.5, D => 2.0]
defs = merge(Dict(def_p), Dict(def_u0))

@named rs = ReactionSystem(rxs, t, [A, B, C, D], k; defaults = defs)
@named rs = ReactionSystem(rxs, t, [A, B, C, D], [k]; defaults = defs)
rs = complete(rs)
odesys = complete(convert(ODESystem, rs))
sdesys = complete(convert(SDESystem, rs))
Expand All @@ -126,15 +127,11 @@ let
defs

u0map = [A => 5.0]
pmap = [k[1] => 5.0]
kvals[1] = 5.0
pmap = [k => kvals]
prob = ODEProblem(rs, u0map, (0, 10.0), pmap)
@test prob.ps[k[1]] == 5.0
@test prob.u0[1] == 5.0
u0 = [10.0, 11.0, 12.0, 13.0]
ps = [float(x) for x in 100:119]
prob = ODEProblem(rs, u0, (0, 10.0), ps)
@test [prob.ps[k[i]] for i in 1:20] == ps
@test prob.u0 == u0
Comment on lines -133 to -137
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not appropriate inputs since they aren't mappings, hence I removed them.

end

### Check ODE, SDE, and Jump Functions ###
Expand Down Expand Up @@ -868,11 +865,9 @@ end
let
@species (A(t))[1:20]
using ModelingToolkit: value
@test isspecies(value(A))
@test isspecies(value(A[2]))
Av = value.(ModelingToolkit.scalarize(A))
@test isspecies(Av[2])
@test isequal(value(Av[2]), value(A[2]))
Av = value(A)
@test isspecies(Av)
@test all(i -> isspecies(Av[i]), 1:length(Av))
end

# Test mixed models are formulated correctly.
Expand Down
Loading