Skip to content

Commit

Permalink
Merge pull request #3277 from AayushSabharwal/as/optional-inbounds
Browse files Browse the repository at this point in the history
feat: better inbounds handling and propagation for generated functions
  • Loading branch information
ChrisRackauckas authored Jan 9, 2025
2 parents 113aec7 + ae5841a commit 961e25a
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.7"
Symbolics = "6.22"
Symbolics = "6.22.1"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand Down
18 changes: 12 additions & 6 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ end
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true

function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -808,7 +808,8 @@ function SymbolicIndexingInterface.observed(
end
end
end
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
_fn = build_explicit_observed_function(
sys, sym; eval_expression, eval_module, checkbounds)

if is_time_dependent(sys)
return _fn
Expand Down Expand Up @@ -1671,11 +1672,14 @@ struct ObservedFunctionCache{S}
steady_state::Bool
eval_expression::Bool
eval_module::Module
checkbounds::Bool
end

function ObservedFunctionCache(
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
sys; steady_state = false, eval_expression = false,
eval_module = @__MODULE__, checkbounds = true)
return ObservedFunctionCache(
sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
end

# This is hit because ensemble problems do a deepcopy
Expand All @@ -1685,7 +1689,9 @@ function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
steady_state = ofc.steady_state
eval_expression = ofc.eval_expression
eval_module = ofc.eval_module
newofc = ObservedFunctionCache(sys, dict, steady_state, eval_expression, eval_module)
checkbounds = ofc.checkbounds
newofc = ObservedFunctionCache(
sys, dict, steady_state, eval_expression, eval_module, checkbounds)
stackdict[ofc] = newofc
return newofc
end
Expand All @@ -1694,7 +1700,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
end
if ofc.steady_state
obs = let fn = obs
Expand Down
6 changes: 4 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
ArrayInterface.restructure(u0 .* u0', M)
end

observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; steady_state, eval_expression, eval_module, checkbounds)

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down Expand Up @@ -522,7 +523,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,8 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
M = calculate_massmatrix(sys)
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

SDEFunction{iip, specialize}(f, g;
sys = sys,
Expand Down
3 changes: 2 additions & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
end

observedfun = ObservedFunctionCache(sys)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

DiscreteFunction{iip, specialize}(f;
sys = sys,
Expand Down
6 changes: 4 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
DiscreteProblem(df, u0, tspan, p; kwargs...)
Expand Down Expand Up @@ -527,7 +528,8 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module,
checkbounds = get(kwargs, :checkbounds, false))
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; kwargs...)
end
Expand Down
6 changes: 4 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

if length(dvs) == length(equations(sys))
resid_prototype = nothing
Expand Down Expand Up @@ -411,7 +412,8 @@ function SciMLBase.IntervalNonlinearFunction(
f(u, p) = f_oop(u, p)
f(u, p::MTKParameters) = f_oop(u, p...)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

IntervalNonlinearFunction{false}(
f; observed = observedfun, sys = sys, initialization_data)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
hess_prototype = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)

if length(cstr) > 0
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)
Expand Down

0 comments on commit 961e25a

Please sign in to comment.