Skip to content

Commit

Permalink
Merge pull request #3079 from SebastianM-C/smc/pdeps
Browse files Browse the repository at this point in the history
Relax type constraints to allow callable parameters in pdeps
  • Loading branch information
ChrisRackauckas authored Oct 5, 2024
2 parents 0916a02 + 6f96622 commit 2497d9b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3090,7 +3090,7 @@ function process_parameter_dependencies(pdeps, ps)
end
for p in pdeps]
end
lhss = BasicSymbolic[]
lhss = []
for p in pdeps
if !isparameter(p.lhs)
error("LHS of parameter dependency must be a single parameter. Found $(p.lhs).")
Expand All @@ -3101,6 +3101,7 @@ function process_parameter_dependencies(pdeps, ps)
end
push!(lhss, p.lhs)
end
lhss = map(identity, lhss)
pdeps = topsort_equations(pdeps, union(ps, lhss))
ps = filter(ps) do p
!any(isequal(p), lhss)
Expand Down
5 changes: 3 additions & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct IndexCache
constant_idx::ParamIndexMap
nonnumeric_idx::NonnumericMap
observed_syms::Set{BasicSymbolic}
dependent_pars::Set{BasicSymbolic}
dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
Expand Down Expand Up @@ -275,7 +275,8 @@ function IndexCache(sys::AbstractSystem)
end
end

dependent_pars = Set{BasicSymbolic}()
dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}()

for eq in parameter_dependencies(sys)
sym = eq.lhs
ttsym = default_toterm(sym)
Expand Down
23 changes: 23 additions & 0 deletions test/parameter_dependencies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,29 @@ end
@test SciMLBase.successful_retcode(sol)
end

struct CallableFoo
p::Any
end

@register_symbolic CallableFoo(x)

(f::CallableFoo)(x) = f.p + x

@testset "callable parameters" begin
@variables y(t) = 1
@parameters p=2 (i::CallableFoo)(..)

eqs = [D(y) ~ i(t) + p]
@named model = ODESystem(eqs, t, [y], [p, i];
parameter_dependencies = [i ~ CallableFoo(p)])
sys = structural_simplify(model)

prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Tsit5())

@test SciMLBase.successful_retcode(sol)
end

@testset "Clock system" begin
dt = 0.1
@variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t)
Expand Down

0 comments on commit 2497d9b

Please sign in to comment.