Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: better inbounds handling and propagation for generated functions #3277

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.7"
Symbolics = "6.22"
Symbolics = "6.22.1"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand Down
18 changes: 12 additions & 6 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ end
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true

function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -808,7 +808,8 @@ function SymbolicIndexingInterface.observed(
end
end
end
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
_fn = build_explicit_observed_function(
sys, sym; eval_expression, eval_module, checkbounds)

if is_time_dependent(sys)
return _fn
Expand Down Expand Up @@ -1671,11 +1672,14 @@ struct ObservedFunctionCache{S}
steady_state::Bool
eval_expression::Bool
eval_module::Module
checkbounds::Bool
end

function ObservedFunctionCache(
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
sys; steady_state = false, eval_expression = false,
eval_module = @__MODULE__, checkbounds = true)
return ObservedFunctionCache(
sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
end

# This is hit because ensemble problems do a deepcopy
Expand All @@ -1685,7 +1689,9 @@ function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
steady_state = ofc.steady_state
eval_expression = ofc.eval_expression
eval_module = ofc.eval_module
newofc = ObservedFunctionCache(sys, dict, steady_state, eval_expression, eval_module)
checkbounds = ofc.checkbounds
newofc = ObservedFunctionCache(
sys, dict, steady_state, eval_expression, eval_module, checkbounds)
stackdict[ofc] = newofc
return newofc
end
Expand All @@ -1694,7 +1700,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
end
if ofc.steady_state
obs = let fn = obs
Expand Down
6 changes: 4 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
ArrayInterface.restructure(u0 .* u0', M)
end

observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; steady_state, eval_expression, eval_module, checkbounds)

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down Expand Up @@ -522,7 +523,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,8 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
M = calculate_massmatrix(sys)
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

SDEFunction{iip, specialize}(f, g;
sys = sys,
Expand Down
3 changes: 2 additions & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
end

observedfun = ObservedFunctionCache(sys)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

DiscreteFunction{iip, specialize}(f;
sys = sys,
Expand Down
6 changes: 4 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
DiscreteProblem(df, u0, tspan, p; kwargs...)
Expand Down Expand Up @@ -527,7 +528,8 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module,
checkbounds = get(kwargs, :checkbounds, false))
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; kwargs...)
end
Expand Down
6 changes: 4 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

if length(dvs) == length(equations(sys))
resid_prototype = nothing
Expand Down Expand Up @@ -411,7 +412,8 @@ function SciMLBase.IntervalNonlinearFunction(
f(u, p) = f_oop(u, p)
f(u, p::MTKParameters) = f_oop(u, p...)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

IntervalNonlinearFunction{false}(
f; observed = observedfun, sys = sys, initialization_data)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
hess_prototype = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)

if length(cstr) > 0
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)
Expand Down
Loading