diff --git a/.gitignore b/.gitignore index f3fc3a6b4..8b8e929b3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ *.jl.*.cov *.jl.mem Manifest.toml -/docs/build/ \ No newline at end of file +/docs/build/ +.vscode diff --git a/.vscode/settings.json b/.vscode/settings.json index 9e26dfeeb..3f293ca70 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1 +1,3 @@ -{} \ No newline at end of file +{ + "julia.environmentPath": "/Users/alexcohen/.julia/dev/SciMLSensitivity" +} \ No newline at end of file diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 57c542c34..2e1974640 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -364,7 +364,6 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::TrackerVJP, dgrad, W) where {TS <: SensitivityFunction} @unpack sensealg = S f = unwrapped_f(S.f) - if inplace_sensitivity(S) if W === nothing _dy, back = Tracker.forward(y, p) do u, p @@ -422,7 +421,6 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg @unpack sensealg = S prob = getprob(S) f = unwrapped_f(S.f) - if p isa DiffEqBase.NullParameters _p = similar(y, (0,)) else @@ -539,7 +537,6 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, @unpack sensealg = S prob = getprob(S) f = unwrapped_f(S.f) - if inplace_sensitivity(S) if W === nothing _dy, back = Zygote.pullback(y, p) do u, p @@ -604,7 +601,6 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, @unpack sensealg = S prob = getprob(S) f = unwrapped_f(S.f) - if W === nothing _dy, back = Zygote.pullback(y, p) do u, p vec(f(u, p, t)) @@ -639,7 +635,6 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, W) where {TS <: SensitivityFunction} @unpack sensealg = S f = unwrapped_f(S.f) - prob = getprob(S) _tmp1, tmp2, _tmp3, _tmp4, _tmp5 = S.diffcache.paramjac_config diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index c9417a4e1..14942c551 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -25,6 +25,7 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, checkpoint_sol::CPS prob::pType f::fType + noiseterm::Bool GaussInt::GaussIntegrand end @@ -41,6 +42,7 @@ end function ODEGaussAdjointSensitivityFunction(g, sensealg, gaussint, discrete, sol, dgdu, dgdp, f, alg, checkpoints, tols, tstops = nothing; + noiseterm = false, tspan = reverse(sol.prob.tspan)) checkpointing = ischeckpointing(sensealg, sol) @@ -52,28 +54,57 @@ function ODEGaussAdjointSensitivityFunction(g, sensealg, gaussint, discrete, sol tspan[1] > interval_end && push!(intervals, (interval_end, tspan[1])) cursor = lastindex(intervals) interval = intervals[cursor] - if tstops === nothing - cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), - sol.alg; dense=true, tols...) + if typeof(sol.prob) <: Union{SDEProblem, RODEProblem} + # replicated noise + _sol = deepcopy(sol) + idx1 = searchsortedfirst(_sol.W.t, interval[1] - 1000eps(interval[1])) + if typeof(sol.W) <: DiffEqNoiseProcess.NoiseProcess + sol.W.save_everystep = false + _sol.W.save_everystep = false + forwardnoise = DiffEqNoiseProcess.NoiseWrapper(_sol.W, indx = idx1) + elseif typeof(sol.W) <: DiffEqNoiseProcess.NoiseGrid + #idx2 = searchsortedfirst(_sol.W.t, interval[2]+1000eps(interval[1])) + forwardnoise = DiffEqNoiseProcess.NoiseGrid(_sol.W.t[idx1:end], + _sol.W.W[idx1:end]) + else + error("NoiseProcess type not implemented.") + end + dt = choose_dt((_sol.W.t[idx1] - _sol.W.t[idx1 + 1]), _sol.W.t, interval) + + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1]), + noise = forwardnoise), + sol.alg, save_noise = false; dt = dt, tstops = _sol.t[idx1:end], + tols...) + + #cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1]), + # noise = forwardnoise), + # sol.alg, save_noise = false; dt = dt, dense=true, + # tols...) gaussint.sol = cpsol else - if maximum(interval[1] .< tstops .< interval[2]) - # callback might have changed p - _p = Gaussreset_p(sol.prob.kwargs[:callback], interval) - #cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), - # tstops = tstops, - # p = _p, sol.alg; tols...) - - cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), - dense=true, - p = _p, sol.alg; tols...) - gaussint.sol = cpsol - else - #cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), - # tstops = tstops, sol.alg; tols...) + if tstops === nothing cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), sol.alg; dense=true, tols...) gaussint.sol = cpsol + else + if maximum(interval[1] .< tstops .< interval[2]) + # callback might have changed p + _p = Gaussreset_p(sol.prob.kwargs[:callback], interval) + #cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), + # tstops = tstops, + # p = _p, sol.alg; tols...) + + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), + dense=true, + p = _p, sol.alg; tols...) + gaussint.sol = cpsol + else + #cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), + # tstops = tstops, sol.alg; tols...) + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1])), + sol.alg; dense=true, tols...) + gaussint.sol = cpsol + end end end GaussCheckpointSolution(cpsol, intervals, cursor, tols, tstops) @@ -81,9 +112,9 @@ function ODEGaussAdjointSensitivityFunction(g, sensealg, gaussint, discrete, sol nothing end diffcache, y = adjointdiffcache(g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg; - quad = true) + quad = true, noiseterm = noiseterm) return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete, - y, sol, checkpoint_sol, sol.prob, f, gaussint) + y, sol, checkpoint_sol, sol.prob, f, noiseterm, gaussint) end function Gaussfindcursor(intervals, t) @@ -98,7 +129,20 @@ function (S::ODEGaussAdjointSensitivityFunction)(du, u, p, t) #f = sol.prob.f λ, grad, y, dλ, dgrad, dy = split_states(du, u, t, S) - vecjacobian!(dλ, y, λ, p, t, S) + #vecjacobian!(dλ, y, λ, p, t, S) + if S.noiseterm + if length(u) == length(du) + vecjacobian!(dλ, y, λ, p, t, S) + elseif length(u) != length(du) && StochasticDiffEq.is_diagonal_noise(prob) && + !isnoisemixing(S.sensealg) + vecjacobian!(dλ, y, λ, p, t, S) + jacNoise!(λ, y, p, t, S) + else + jacNoise!(λ, y, p, t, S, dλ = dλ) + end + else + vecjacobian!(dλ, y, λ, p, t, S) + end dλ .*= -one(eltype(λ)) discrete || accumulate_cost!(dλ, y, p, t, S) @@ -139,7 +183,7 @@ function split_states(du, u, t, S::ODEGaussAdjointSensitivityFunction; update = @unpack sol, y, checkpoint_sol, discrete, prob, f, GaussInt = S if update if checkpoint_sol === nothing - if t isa ForwardDiff.Dual && eltype(S.y) <: AbstractFloat + if typeof(t) <: ForwardDiff.Dual && eltype(S.y) <: AbstractFloat y = sol(t, continuity = :right) else sol(y, t, continuity = :right) @@ -151,31 +195,57 @@ function split_states(du, u, t, S::ODEGaussAdjointSensitivityFunction; update = cursor′ = Gaussfindcursor(intervals, t) interval = intervals[cursor′] cpsol_t = checkpoint_sol.cpsol.t - if t isa ForwardDiff.Dual && eltype(S.y) <: AbstractFloat + if typeof(t) <: ForwardDiff.Dual && eltype(S.y) <: AbstractFloat y = sol(interval[1]) else sol(y, interval[1]) end - if checkpoint_sol.tstops === nothing - prob′ = remake(prob, tspan = intervals[cursor′], u0 = y) - cpsol′ = solve(prob′, sol.alg; - dt = abs(cpsol_t[end] - cpsol_t[end - 1]), - checkpoint_sol.tols...) - else - if maximum(interval[1] .< checkpoint_sol.tstops .< interval[2]) - # callback might have changed p - _p = reset_p(prob.kwargs[:callback], interval) - prob′ = remake(prob, tspan = intervals[cursor′], u0 = y, p = _p) - cpsol′ = solve(prob′, sol.alg; - dt = abs(cpsol_t[end] - cpsol_t[end - 1]), - tstops = checkpoint_sol.tstops, - checkpoint_sol.tols...) + if typeof(sol.prob) <: Union{SDEProblem, RODEProblem} + #idx1 = searchsortedfirst(sol.t, interval[1]) + _sol = deepcopy(sol) + idx1 = searchsortedfirst(_sol.t, interval[1] - 100eps(interval[1])) + idx2 = searchsortedfirst(_sol.t, interval[2] + 100eps(interval[2])) + idx_noise = searchsortedfirst(_sol.W.t, + interval[1] - 100eps(interval[1])) + if typeof(sol.W) <: DiffEqNoiseProcess.NoiseProcess + _sol.W.save_everystep = false + forwardnoise = DiffEqNoiseProcess.NoiseWrapper(_sol.W, + indx = idx_noise) + elseif typeof(sol.W) <: DiffEqNoiseProcess.NoiseGrid + forwardnoise = DiffEqNoiseProcess.NoiseGrid(_sol.W.t[idx_noise:end], + _sol.W.W[idx_noise:end]) else + error("NoiseProcess type not implemented.") + end + prob′ = remake(prob, tspan = intervals[cursor′], u0 = y, + noise = forwardnoise) + dt = choose_dt(abs(cpsol_t[1] - cpsol_t[2]), cpsol_t, interval) + cpsol′ = solve(prob′, sol.alg, save_noise = false; dt = dt, + tstops = _sol.t[idx1:idx2], checkpoint_sol.tols...) + #cpsol′ = solve(prob′, sol.alg, save_noise = false; dense=true, dt = dt, + # checkpoint_sol.tols...) + else + if checkpoint_sol.tstops === nothing prob′ = remake(prob, tspan = intervals[cursor′], u0 = y) cpsol′ = solve(prob′, sol.alg; dt = abs(cpsol_t[end] - cpsol_t[end - 1]), - tstops = checkpoint_sol.tstops, checkpoint_sol.tols...) + else + if maximum(interval[1] .< checkpoint_sol.tstops .< interval[2]) + # callback might have changed p + _p = reset_p(prob.kwargs[:callback], interval) + prob′ = remake(prob, tspan = intervals[cursor′], u0 = y, p = _p) + cpsol′ = solve(prob′, sol.alg; + dt = abs(cpsol_t[end] - cpsol_t[end - 1]), + tstops = checkpoint_sol.tstops, + checkpoint_sol.tols...) + else + prob′ = remake(prob, tspan = intervals[cursor′], u0 = y) + cpsol′ = solve(prob′, sol.alg; + dt = abs(cpsol_t[end] - cpsol_t[end - 1]), + tstops = checkpoint_sol.tstops, + checkpoint_sol.tols...) + end end end checkpoint_sol.cpsol = cpsol′ @@ -306,6 +376,255 @@ end end end +@noinline function SDEAdjointProblem(sol, sensealg::GaussAdjoint, alg, + GaussInt::GaussIntegrand, + t = nothing, + dgdu_discrete::DG1 = nothing, + dgdp_discrete::DG2 = nothing, + dgdu_continuous::DG3 = nothing, + dgdp_continuous::DG4 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + reltol = nothing, abstol = nothing, + diffusion_jac = nothing, diffusion_paramjac = nothing, + kwargs...) where {DG1, DG2, DG3, DG4, G} + dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing && + error("Either `dgdu_discrete`, `dgdu_continuous`, or `g` must be specified.") + t !== nothing && dgdu_discrete === nothing && dgdp_discrete === nothing && + error("It looks like you're using the direct `adjoint_sensitivities` interface + with a discrete cost function but no specified `dgdu_discrete` or `dgdp_discrete`. + Please use the higher level `solve` interface or specify these two contributions.") + @unpack f, p, u0, tspan = sol.prob + + # check if solution was terminated, then use reduced time span + terminated = false + if hasfield(typeof(sol), :retcode) + if sol.retcode == ReturnCode.Terminated + tspan = (tspan[1], sol.t[end]) + terminated = true + end + end + tspan = reverse(tspan) + discrete = (t !== nothing && + (dgdu_continuous === nothing && dgdp_continuous === nothing || + g !== nothing)) + # remove duplicates from checkpoints + if ischeckpointing(sensealg, sol) && + (length(unique(checkpoints)) != length(checkpoints)) + _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) + tstops = duplicate_iterator_times[1] + checkpoints = filter(x -> x ∉ tstops, _checkpoints) + # check if start is in checkpoints. Otherwise first interval is missed. + if checkpoints[1] != tspan[2] + pushfirst!(checkpoints, tspan[2]) + end + + if haskey(kwargs, :tstops) + (tstops !== kwargs[:tstops]) && unique!(push!(tstops, kwargs[:tstops]...)) + end + + # check if end is in checkpoints. + if checkpoints[end] != tspan[1] + push!(checkpoints, tspan[1]) + end + else + tstops = nothing + end + numstates = length(u0) + numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) + + len = numstates + + λ = one(eltype(u0)) .* similar(p, len) + λ .= false + + sense_drift = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol, + dgdu_continuous, + dgdp_continuous, sol.prob.f, + alg, checkpoints, + (reltol = reltol, + abstol = abstol), + tspan = tspan) + diffusion_function = ODEFunction{isinplace(sol.prob), true}(sol.prob.g, + jac = diffusion_jac, + paramjac = diffusion_paramjac) + sense_diffusion = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol, + dgdu_continuous, + dgdp_continuous, + diffusion_function, + alg, checkpoints, + (reltol = reltol, + abstol = abstol); + tspan = tspan, + noiseterm = true) + + init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end] + cb, _, duplicate_iterator_times = generate_callbacks(sense_drift, dgdu_discrete, + dgdp_discrete, λ, t, + tspan[2], callback, init_cb, + terminated) + z0 = vec(zero(λ)) + original_mm = sol.prob.f.mass_matrix + if original_mm === I || original_mm === (I, I) + mm = I + else + adjmm = copy(sol.prob.f.mass_matrix') + zzz = similar(adjmm, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + # using concrate I is slightly more efficient + II = Diagonal(I, numparams) + mm = [adjmm zzz + copy(zzz') II] + end + jac_prototype = sol.prob.f.jac_prototype + if !sense_drift.discrete || jac_prototype === nothing + adjoint_jac_prototype = nothing + else + _adjoint_jac_prototype = copy(jac_prototype') + zzz = similar(_adjoint_jac_prototype, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + II = Diagonal(I, numparams) + adjoint_jac_prototype = [_adjoint_jac_prototype zzz + copy(zzz') II] + end + sdefun = SDEFunction(sense_drift, sense_diffusion, mass_matrix = mm, + jac_prototype = adjoint_jac_prototype) + # replicated noise + _sol = deepcopy(sol) + backwardnoise = reverse(_sol.W) + + if StochasticDiffEq.is_diagonal_noise(sol.prob) && typeof(sol.W[end]) <: Number + # scalar noise case + noise_matrix = nothing + else + m = sol.prob.noise_rate_prototype === nothing ? numstates : + size(sol.prob.noise_rate_prototype)[2] + noise_matrix = similar(z0, length(z0), m) + noise_matrix .= false + end + + return SDEProblem(sdefun, sense_diffusion, z0, tspan, p, + noise = backwardnoise, + noise_rate_prototype = noise_matrix), cb, nothing +end + +@noinline function RODEAdjointProblem(sol, sensealg::GaussAdjoint, alg, + GaussInt::GaussIntegrand, + t = nothing, + dgdu_discrete::DG1 = nothing, + dgdp_discrete::DG2 = nothing, + dgdu_continuous::DG3 = nothing, + dgdp_continuous::DG4 = nothing, + g::G = nothing; + checkpoints = sol.t, + callback = CallbackSet(), + reltol = nothing, abstol = nothing, + kwargs...) where {DG1, DG2, DG3, DG4, G} + dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing && + error("Either `dgdu_discrete`, `dgdu_continuous`, or `g` must be specified.") + t !== nothing && dgdu_discrete === nothing && dgdp_discrete === nothing && + error("It looks like you're using the direct `adjoint_sensitivities` interface + with a discrete cost function but no specified `dgdu_discrete` or `dgdp_discrete`. + Please use the higher level `solve` interface or specify these two contributions.") + @unpack f, p, u0, tspan = sol.prob + + # check if solution was terminated, then use reduced time span + terminated = false + if hasfield(typeof(sol), :retcode) + if sol.retcode == ReturnCode.Terminated + tspan = (tspan[1], sol.t[end]) + terminated = true + end + end + tspan = reverse(tspan) + + discrete = (t !== nothing && + (dgdu_continuous === nothing && dgdp_continuous === nothing || + g !== nothing)) + + # remove duplicates from checkpoints + if ischeckpointing(sensealg, sol) && + (length(unique(checkpoints)) != length(checkpoints)) + _checkpoints, duplicate_iterator_times = separate_nonunique(checkpoints) + tstops = duplicate_iterator_times[1] + checkpoints = filter(x -> x ∉ tstops, _checkpoints) + # check if start is in checkpoints. Otherwise first interval is missed. + if checkpoints[1] != tspan[2] + pushfirst!(checkpoints, tspan[2]) + end + + if haskey(kwargs, :tstops) + (tstops !== kwargs[:tstops]) && unique!(push!(tstops, kwargs[:tstops]...)) + end + + # check if end is in checkpoints. + if checkpoints[end] != tspan[1] + push!(checkpoints, tspan[1]) + end + else + tstops = nothing + end + + numstates = length(u0) + numparams = p === nothing || p === DiffEqBase.NullParameters() ? 0 : length(p) + + len = numstates + + λ = p === nothing || p === DiffEqBase.NullParameters() ? similar(u0) : + one(eltype(u0)) .* similar(p, len) + λ .= false + + sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol, + dgdu_continuous, dgdp_continuous, f, + alg, checkpoints, + (reltol = reltol, abstol = abstol), + tstops, tspan = tspan) + + init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end] + cb, _, duplicate_iterator_times = generate_callbacks(sense, dgdu_discrete, + dgdp_discrete, + λ, t, tspan[2], + callback, init_cb, terminated) + z0 = vec(zero(λ)) + original_mm = sol.prob.f.mass_matrix + if original_mm === I || original_mm === (I, I) + mm = I + else + adjmm = copy(sol.prob.f.mass_matrix') + zzz = similar(adjmm, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + # using concrate I is slightly more efficient + II = Diagonal(I, numparams) + mm = [adjmm zzz + copy(zzz') II] + end + + jac_prototype = sol.prob.f.jac_prototype + if !sense.discrete || jac_prototype === nothing + adjoint_jac_prototype = nothing + else + _adjoint_jac_prototype = copy(jac_prototype') + zzz = similar(_adjoint_jac_prototype, numstates, numparams) + fill!(zzz, zero(eltype(zzz))) + II = Diagonal(I, numparams) + adjoint_jac_prototype = [_adjoint_jac_prototype zzz + copy(zzz') II] + end + + rodefun = RODEFunction(sense, mass_matrix = mm, jac_prototype = adjoint_jac_prototype) + + # replicated noise + _sol = deepcopy(sol) + backwardnoise = reverse(_sol.W) + # make sure noise grid starts at correct time values, e.g., if sol.W.t is longer than sol.t + tspan[1] != backwardnoise.t[1] && + reinit!(backwardnoise, backwardnoise.t[2] - backwardnoise.t[1], t0 = tspan[1]) + + return RODEProblem(rodefun, z0, tspan, p, + noise = backwardnoise), cb, nothing +end + function Gaussreset_p(CBS, interval) # check which events are close to tspan[1] if !isempty(CBS.discrete_callbacks) @@ -364,7 +683,7 @@ function Gaussreset_p(CBS, interval) return p end - + function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) prob = sol.prob @unpack f, p, tspan, u0 = prob @@ -402,11 +721,21 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) elseif sensealg.autojacvec isa EnzymeVJP paramjac_config = zero(y), zero(y) pf = let f = unwrappedf - if DiffEqBase.isinplace(prob) + if DiffEqBase.isinplace(prob) && prob isa RODEProblem + function (out, u, _p, t, W) + f(out, u, _p, t, W) + nothing + end + elseif DiffEqBase.isinplace(prob) function (out, u, _p, t) f(out, u, _p, t) nothing end + elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem + function (out, u, _p, t, W) + out .= f(u, _p, t, W) + nothing + end else !DiffEqBase.isinplace(prob) function (out, u, _p, t) @@ -427,6 +756,34 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) end cpsol = sol + checkpointing = ischeckpointing(sensealg, sol) + if checkpointing && typeof(sol.prob) <: Union{SDEProblem, RODEProblem} + intervals = map(tuple, @view(checkpoints[1:(end - 1)]), @view(checkpoints[2:end])) + interval_end = intervals[end][end] + tspan[1] > interval_end && push!(intervals, (interval_end, tspan[1])) + cursor = lastindex(intervals) + interval = intervals[cursor] + _sol = deepcopy(sol) + idx1 = searchsortedfirst(_sol.W.t, interval[1] - 1000eps(interval[1])) + if typeof(sol.W) <: DiffEqNoiseProcess.NoiseProcess + sol.W.save_everystep = false + _sol.W.save_everystep = false + forwardnoise = DiffEqNoiseProcess.NoiseWrapper(_sol.W, indx = idx1) + elseif typeof(sol.W) <: DiffEqNoiseProcess.NoiseGrid + forwardnoise = DiffEqNoiseProcess.NoiseGrid(_sol.W.t[idx1:end], + _sol.W.W[idx1:end]) + else + error("NoiseProcess type not implemented.") + end + dt = choose_dt((_sol.W.t[idx1] - _sol.W.t[idx1 + 1]), _sol.W.t, interval) + #cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1]), + # noise = forwardnoise), + # sol.alg, save_noise = false; dt = dt, dense=true) + cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1]), + noise = forwardnoise), + sol.alg, save_noise = false; dt = dt, tstops = _sol.t[idx1:end], + abstol=1e-14, reltol=1e-14) + end GaussIntegrand(cpsol, p, y, λ, pf, f_cache, pJ, paramjac_config, sensealg, dgdp_cache, dgdp) @@ -436,6 +793,7 @@ end function vec_pjac!(out, λ, y, t, S::GaussIntegrand) @unpack pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol = S f = sol.prob.f + isautojacvec = get_jacvec(sensealg) # y is aliased @@ -466,7 +824,74 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) vec(f(y, p, t)) end tmp = back(λ) - recursive_copyto!(out, tmp[1]) + #out[:] .= vec(tmp[1]) + recursive_copyto!(out,tmp[1]) + elseif sensealg.autojacvec isa EnzymeVJP + tmp3, tmp4 = paramjac_config + tmp4 .= λ + out .= 0 + Enzyme.autodiff(Enzyme.Reverse, pf, Enzyme.Duplicated(tmp3, tmp4), + y, Enzyme.Duplicated(p, out), t) + else + error("autojacvec choice $(sensealg.autojacvec) is not supported by GaussAdjoint") + end + # TODO: Add tracker? + + + return out +end + +function vec_pjac_diffusion!(out, λ, y, t, S::GaussIntegrand, W = nothing) + @unpack pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol = S + f = sol.prob.f + g = sol.prob.g + + Wtmp = sol.W(t) + + if sensealg.autojacvec isa ZygoteVJP + if W === nothing + _dy, back = Zygote.pullback(p) do p + vec(g(y, p, t).*Wtmp) + end + else + _dy, back = Zygote.pullback(p) do p + vec(g(y, p, t, W)) + end + end + tmp, = back(λ) + recursive_add!(out, tmp) + end + + #= + if !isautojacvec + if DiffEqBase.has_paramjac(f) + f.paramjac(pJ, y, p, t) # Calculate the parameter Jacobian into pJ + else + pf.t = t + jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + end + mul!(out', λ', pJ) + elseif sensealg.autojacvec isa ReverseDiffVJP + tape = paramjac_config + tu, tp, tt = ReverseDiff.input_hook(tape) + output = ReverseDiff.output_hook(tape) + ReverseDiff.unseed!(tu) # clear any "leftover" derivatives from previous calls + ReverseDiff.unseed!(tp) + ReverseDiff.unseed!(tt) + ReverseDiff.value!(tu, y) + ReverseDiff.value!(tp, p) + ReverseDiff.value!(tt, [t]) + ReverseDiff.forward_pass!(tape) + ReverseDiff.increment_deriv!(output, λ) + ReverseDiff.reverse_pass!(tape) + copyto!(vec(out), ReverseDiff.deriv(tp)) + elseif sensealg.autojacvec isa ZygoteVJP + _dy, back = Zygote.pullback(p) do p + vec(f(y, p, t)) + end + tmp = back(λ) + #out[:] .= vec(tmp[1]) + recursive_copyto!(out,tmp[1]) elseif sensealg.autojacvec isa EnzymeVJP tmp3, tmp4 = paramjac_config tmp4 .= λ @@ -477,7 +902,9 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) error("autojacvec choice $(sensealg.autojacvec) is not supported by GaussAdjoint") end # TODO: Add tracker? + =# return out + end function (S::GaussIntegrand)(out, t, λ) @@ -488,16 +915,21 @@ function (S::GaussIntegrand)(out, t, λ) y = sol(t) end vec_pjac!(out, λ, y, t, S) + if typeof(sol.prob) <: Union{SDEProblem, RODEProblem} + vec_pjac_diffusion!(out, λ, y, t, S) + end if S.dgdp !== nothing S.dgdp(dgdp_cache, y, p, t) @show typeof(dgdp_cache) out .+= dgdp_cache end + #out' out end function (S::GaussIntegrand)(t, λ) - out = allocate_zeros(S.p) + #out = similar(S.p) + out = DiffEqCallbacks.allocate_zeros(S.p) S(out, t, λ) end @@ -507,20 +939,19 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing, dgdu_continuous = nothing, dgdp_continuous = nothing, g = nothing, - abstol = 1e-6, reltol = 1e-3, + abstol = sensealg.abstol, reltol = sensealg.reltol, checkpoints = sol.t, corfunc_analytical = false, callback = CallbackSet(), kwargs...) - + integrand = GaussIntegrand(sol, sensealg, checkpoints, dgdp_continuous) - integrand_values = IntegrandValuesSum(allocate_zeros(sol.prob.p)) - cb = IntegratingSumCallback((out, u, t, integrator) -> integrand(out, t, u), - integrand_values, allocate_vjp(sol.prob.p)) + integrand_values = IntegrandValues(Float64, typeof(sol.prob.p)) + cb = IntegratingCallback((out, u, t, integrator) -> integrand(out, t, u), integrand_values, DiffEqCallbacks.allocate_zeros(sol.prob.p))#similar(sol.prob.p)) rcb = nothing cb2 = nothing adj_prob = nothing - + if sol.prob isa ODEProblem adj_prob, cb2, rcb = ODEAdjointProblem(sol, sensealg, alg, integrand, t, dgdu_discrete, dgdp_discrete, @@ -528,36 +959,72 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing, checkpoints = checkpoints, callback = callback, abstol = abstol, reltol = reltol, kwargs...) + elseif sol.prob isa SDEProblem + adj_prob, cb2, rcb = SDEAdjointProblem(sol, sensealg, alg, integrand, t, dgdu_discrete, dgdp_discrete, + dgdu_continuous, dgdp_continuous, g; + checkpoints = checkpoints, + callback = callback, + abstol = abstol, reltol = reltol, + corfunc_analytical = corfunc_analytical) + elseif sol.prob isa RODEProblem + adj_prob, cb2, rcb = RODEAdjointProblem(sol, sensealg, alg, integrand, t, dgdu_discrete, dgdp_discrete, + dgdu_continuous, dgdp_continuous, g; + checkpoints = checkpoints, + callback = callback, + abstol = abstol, reltol = reltol, + corfunc_analytical = corfunc_analytical) else - error("Continuous adjoint sensitivities are only supported for ODE problems.") + error("Continuous adjoint sensitivities are only supported for ODE/SDE/RODE problems.") end tstops = ischeckpointing(sensealg, sol) ? checkpoints : similar(sol.t, 0) - - adj_sol = solve(adj_prob, alg; abstol = abstol, reltol = reltol, save_everystep = false, + + adj_sol = solve(adj_prob, alg; abstol = abstol, reltol = reltol, save_everystep = false, save_start = false, save_end = true, saveat = eltype(sol[1])[], tstops = tstops, - callback = CallbackSet(cb,cb2), kwargs...) - res = integrand_values.integrand + callback = CallbackSet(cb,cb2), + kwargs...) + res = compute_dGdp(integrand_values) + #println("adj_sol.t = ", adj_sol.t) + #println(tstops) if rcb !== nothing && !isempty(rcb.Δλas) iλ = zero(rcb.λ) - out = zero(res) + out = zero(res') yy = similar(rcb.y) for (Δλa, tt) in rcb.Δλas @unpack algevar_idxs = rcb.diffcache iλ[algevar_idxs] .= Δλa sol(yy, tt) vec_pjac!(out, iλ, yy, tt, integrand) - res .+= out + res .+= out' iλ .= zero(eltype(iλ)) end end - return adj_sol[end], __maybe_adjoint(res) + return adj_sol[end], res' end -__maybe_adjoint(x::AbstractArray) = x' -__maybe_adjoint(x) = x +recursive_add!(x::AbstractArray, y::AbstractArray) = x .+= y +recursive_add!(x::Tuple, y::Tuple) = recursive_add!.(x, y) +function recursive_add!(x::NamedTuple{F}, y::NamedTuple{F}) where {F} + return NamedTuple{F}(recursive_add!(values(x), values(y))) +end +function compute_dGdp(integrand::IntegrandValues) + res = DiffEqCallbacks.allocate_zeros(integrand.integrand[1]) + for (i, j) in enumerate(integrand.integrand) + recursive_add!(res, j) + end + return res +end +#= +function compute_dGdp(integrand::IntegrandValues) + res = zeros(length(integrand.integrand[1])) + for (i, j) in enumerate(integrand.integrand) + res .+= j + end + return res +end +=# function update_p_integrand(integrand::GaussIntegrand, p) @unpack sol, y, λ, pf, f_cache, pJ, paramjac_config, sensealg, dgdp_cache, dgdp = integrand @@ -642,3 +1109,4 @@ function _update_integrand_and_dgrad(res, sensealg::GaussAdjoint, cb, integrand, res .-= dgrad return integrand end + diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 407044e34..23eb6b476 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -569,18 +569,22 @@ struct GaussAdjoint{CS, AD, FDT, VJP} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP checkpointing::Bool + abstol::Float64 + reltol::Float64 end Base.@pure function GaussAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, autojacvec = nothing, - checkpointing=false) - GaussAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec)}(autojacvec, checkpointing) + checkpointing=false, + abstol = 1e-6, + reltol = 1e-3) + GaussAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec)}(autojacvec, checkpointing, abstol, reltol) end TruncatedStacktraces.@truncate_stacktrace GaussAdjoint function setvjp(sensealg::GaussAdjoint{CS, AD, FDT, Nothing}, vjp) where {CS, AD, FDT} - GaussAdjoint{CS, AD, FDT, typeof(vjp)}(vjp, sensealg.checkpointing) + GaussAdjoint{CS, AD, FDT, typeof(vjp)}(vjp, sensealg.checkpointing, sensealg.abstol, sensealg.reltol) end """ diff --git a/test/adjoint.jl b/test/adjoint.jl index 6fb54a5fa..903762ab5 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -168,7 +168,7 @@ integrand = AdjointSensitivityIntegrand(sol, adj_sol, autojacvec = SciMLSensitivity.ReverseDiffVJP())) res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-12) -@test isapprox(res, easy_res, rtol = 1e-10) +@test isapprox(res, easy_res, rtol= 1e-10) @test isapprox(res, easy_res2, rtol = 1e-10) @test isapprox(res, easy_res22, rtol = 1e-10) @test isapprox(res, easy_res23, rtol = 1e-10) diff --git a/test/sde_scalar_stratonovich.jl b/test/sde_scalar_stratonovich.jl index 50fc51956..5713e6cd2 100644 --- a/test/sde_scalar_stratonovich.jl +++ b/test/sde_scalar_stratonovich.jl @@ -1,6 +1,7 @@ using Test, LinearAlgebra using SciMLSensitivity, StochasticDiffEq using Random +using Interpolations @info "SDE Adjoints" @@ -22,8 +23,75 @@ function dg!(out, u, p, t, i) (out .= u) end +function dg(u, p, t, i) + return u +end + p2 = [1.01, 0.87] + +using DiffEqNoiseProcess + +dtscalar = tend / 1e3 + +f!(du, u, p, t) = (du .= p[1] * u) +σ!(du, u, p, t) = (du .= p[2] * u) + +function foop(u, p, t) + return p[1] * u +end +function σoop(u, p, t) + return p[2] * u +end + +@info "scalar SDE" + + +Random.seed!(seed) +W = WienerProcess(0.0, 0.0, 0.0) +u0 = rand(2) + +linear_analytic_strat(u0, p, t, W) = @.(u0*exp(p[1] * t + p[2] * W)) + +prob = SDEProblem(SDEFunction(f!, σ!, analytic = linear_analytic_strat), σ!, u0, trange, + p2, + noise = W) +prob_oop = SDEProblem(SDEFunction(foop, σoop, analytic = linear_analytic_strat), σoop, u0, trange, + p2, + noise = W) +sol = solve(prob, EulerHeun(), dt = dtscalar, save_noise = true) + +sol_oop = solve(prob_oop, EulerHeun(), dt = dtscalar, save_noise = true) + +@test isapprox(sol.u_analytic, sol.u, atol = 1e-4) +@test isapprox(sol_oop.u_analytic, sol_oop.u, atol = 1e-4) + + +res_sde_u0, res_sde_p = adjoint_sensitivities(sol_oop, EulerHeun(), t = Array(t), + dgdu_discrete = dg!, + dt = dtscalar, adaptive = false, + sensealg = BacksolveAdjoint()) + +@show res_sde_u0, res_sde_p + +res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop, EulerHeun(), t = Array(t), + dgdu_discrete = dg!, + dt = tend / 1e2, adaptive = false, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())) + +@test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) +@test isapprox(res_sde_p, res_sde_p2, rtol = 1e-4) + +res_sde_u02, res_sde_p2 = adjoint_sensitivities(sol_oop, EulerHeun(), t=Array(t), + dgdu_discrete = dg!, + adaptive = false, + sensealg = GaussAdjoint(autojacvec = ZygoteVJP())) + + +@test isapprox(res_sde_u0, res_sde_u02, rtol = 1e-4) +@test isapprox(res_sde_p, res_sde_p2, rtol = 1e-4) + +#= # scalar noise @testset "SDE inplace scalar noise tests" begin using DiffEqNoiseProcess @@ -233,3 +301,4 @@ end @test isapprox(true_grads[2], res_sde_p2', atol = 1e-4) @test isapprox(true_grads[1], res_sde_u02, rtol = 1e-4) end +=# \ No newline at end of file