Skip to content

Commit

Permalink
Merge branch 'master' into as/symarray-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Oct 24, 2024
2 parents 1025363 + ba519bf commit 3f7d45a
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ MTKLabelledArraysExt = "LabelledArrays"
[compat]
AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.3"
BifurcationKit = "0.4"
BlockArrays = "1.1"
ChainRulesCore = "1"
Combinatorics = "1"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
BenchmarkTools = "1.3"
BifurcationKit = "0.3"
BifurcationKit = "0.4"
DataInterpolations = "6.5"
DifferentialEquations = "7.6"
Distributions = "0.25"
Expand Down
4 changes: 2 additions & 2 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
# If the plot var is a normal state.
if any(isequal(plot_var, var) for var in unknowns(nsys))
plot_idx = findfirst(isequal(plot_var), unknowns(nsys))
record_from_solution = (x, p) -> x[plot_idx]
record_from_solution = (x, p; k...) -> x[plot_idx]

# If the plot var is an observed state.
elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
Expand All @@ -132,7 +132,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
return BifurcationKit.BifurcationProblem(F,
u0_bif_vals,
p_vals,
(@lens _[bif_idx]),
(BifurcationKit.@optic _[bif_idx]),
args...;
record_from_solution = record_from_solution,
J = J,
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,8 @@ function build_explicit_observed_function(sys, ts;
iip_fn = build_function(ts,
args...;
postprocess_fbody = pre,
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs) .∘
mtkparams_wrapper,
wrap_code = mtkparams_wrapper .∘ array_wrapper .∘
wrap_assignments(isscalar, obsexprs),
expression = Val{true})[2]
if !expression
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)
Expand Down
30 changes: 26 additions & 4 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,24 @@ function to_varmap(vals, varlist::Vector)
check_eqs_u0(varlist, varlist, vals)
vals = vec(varlist) .=> vec(vals)
end
return anydict(unwrap(k) => unwrap(v) for (k, v) in anydict(vals))
return recursive_unwrap(anydict(vals))
end

"""
$(TYPEDSIGNATURES)
Recursively call `Symbolics.unwrap` on `x`. Useful when `x` is an array of (potentially)
symbolic values, all of which need to be unwrapped. Specializes when `x isa AbstractDict`
to unwrap keys and values, returning an `AnyDict`.
"""
function recursive_unwrap(x::AbstractArray)
symbolic_type(x) == ArraySymbolic() ? unwrap(x) : recursive_unwrap.(x)
end

recursive_unwrap(x) = unwrap(x)

function recursive_unwrap(x::AbstractDict)
return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x)
end

"""
Expand Down Expand Up @@ -262,7 +279,7 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
end
vals = map(x -> varmap[x], vars)

if container_type <: Union{AbstractDict, Tuple, Nothing}
if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
container_type = Array
end

Expand Down Expand Up @@ -410,7 +427,7 @@ function process_SciMLProblem(
u0map = to_varmap(u0map, dvs)
_pmap = pmap
pmap = to_varmap(pmap, ps)
defs = add_toterms(defaults(sys))
defs = add_toterms(recursive_unwrap(defaults(sys)))
cmap, cs = get_cmap(sys)
kwargs = NamedTuple(kwargs)

Expand All @@ -433,9 +450,14 @@ function process_SciMLProblem(
solvablepars = [p
for p in parameters(sys)
if is_parameter_solvable(p, pmap, defs, guesses)]
has_dependent_unknowns = any(unknowns(sys)) do sym
val = get(op, sym, nothing)
val === nothing && return false
return symbolic_type(val) != NotSymbolic() || is_array_of_symbolics(val)
end
if build_initializeprob &&
(((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) ||
!isempty(solvablepars)) &&
!isempty(solvablepars) || has_dependent_unknowns) &&
get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys))) && t !== nothing
initializeprob = ModelingToolkit.InitializationProblem(
Expand Down
4 changes: 2 additions & 2 deletions test/extensions/bifurcationkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ let
bprob_BK = BifurcationProblem(f_BK,
[1.0, 1.0],
[-1.0, 1.0],
(@lens _[1]);
record_from_solution = (x, p) -> x[1])
(BifurcationKit.@optic _[1]);
record_from_solution = (x, p; k...) -> x[1])
bif_dia_BK = bifurcationdiagram(bprob_BK,
PALC(),
2,
Expand Down
17 changes: 17 additions & 0 deletions test/initial_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,20 @@ end
prob = ODEProblem(sys, [], (1.0, 2.0), [])
@test prob[x] == 1.0
@test prob.ps[p] == 2.0

@testset "Array of symbolics is unwrapped" begin
@variables x(t)[1:2] y(t)
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t], t; defaults = [x => [y, 3.0]])
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0))
@test eltype(prob.u0) <: Float64
prob = ODEProblem(sys, [x => [y, 4.0], y => 2.0], (0.0, 1.0))
@test eltype(prob.u0) <: Float64
end

@testset "split=false systems with all parameter defaults" begin
@variables x(t) = 1.0
@parameters p=1.0 q=2.0 r=3.0
@mtkbuild sys=ODESystem(D(x) ~ p * x + q * t + r, t) split=false
prob = @test_nowarn ODEProblem(sys, [], (0.0, 1.0))
@test prob.p isa Vector{Float64}
end
16 changes: 16 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,19 @@ end
isys = ModelingToolkit.generate_initializesystem(sys)
@test isequal(defaults(isys)[y], 2x + 1)
end

@testset "Create initializeprob when unknown has dependent value" begin
@variables x(t) y(t)
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t * y], t; defaults = [x => 2y])
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0))
@test prob.f.initializeprob !== nothing
integ = init(prob)
@test integ[x] 2.0

@variables x(t)[1:2] y(t)
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ t], t; defaults = [x => [y, 3.0]])
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0))
@test prob.f.initializeprob !== nothing
integ = init(prob)
@test integ[x] [1.0, 3.0]
end
12 changes: 12 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1455,3 +1455,15 @@ end
@test length(unknowns(sys)) == 2
@test any(isequal(y), unknowns(sys))
end

@testset "Inplace observed" begin
@variables x(t)
@parameters p[1:2] q
@mtkbuild sys = ODESystem(D(x) ~ sum(p) * x + q * t, t)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => ones(2), q => 2])
obsfn = ModelingToolkit.build_explicit_observed_function(
sys, [p..., q], return_inplace = true)[2]
buf = zeros(3)
obsfn(buf, prob.u0, prob.p, 0.0)
@test buf [1.0, 1.0, 2.0]
end

0 comments on commit 3f7d45a

Please sign in to comment.