Skip to content

Commit

Permalink
Inference improvements for timeevolution (#396)
Browse files Browse the repository at this point in the history
* Specialize on function arguments in timeevolution
* Stabilize mcwf
* Use let blocks around inner functions

---------

Co-authored-by: Ashley Milsted <[email protected]>
  • Loading branch information
amilsted and Ashley Milsted authored Jun 15, 2024
1 parent b515a2b commit b9a88b7
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 78 deletions.
4 changes: 2 additions & 2 deletions src/bloch_redfield_master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ end

# Integrate with given fout
function integrate_br(tspan, dmaster_br, rho,
transf_op, inv_transf_op, fout::Function;
kwargs...)
transf_op, inv_transf_op, fout::F;
kwargs...) where {F}
# Pre-allocate for in-place back-transformation from eigenbasis
rho_out = copy(transf_op)
tmp = copy(transf_op)
Expand Down
8 changes: 5 additions & 3 deletions src/master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ function master_dynamic(tspan, rho0::Operator, f;
fout=nothing,
kwargs...)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_h_dynamic!(drho, f, rates, rho, tmp, t)
dmaster_ = let f = f, tmp = tmp
dmaster_(t, rho, drho) = dmaster_h_dynamic!(drho, f, rates, rho, tmp, t)
end
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
end

Expand Down Expand Up @@ -395,7 +397,7 @@ returned from `f`.
See also: [`master_dynamic`](@ref), [`dmaster_h!`](@ref), [`dmaster_nh!`](@ref),
[`dmaster_nh_dynamic!`](@ref)
"""
function dmaster_h_dynamic!(drho, f, rates, rho, drho_cache, t)
function dmaster_h_dynamic!(drho, f::F, rates, rho, drho_cache, t) where {F}
result = f(t, rho)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
Expand All @@ -418,7 +420,7 @@ equation. Optionally, rates can also be returned from `f`.
See also: [`master_dynamic`](@ref), [`dmaster_h!`](@ref), [`dmaster_nh!`](@ref),
[`dmaster_h_dynamic!`](@ref)
"""
function dmaster_nh_dynamic!(drho, f, rates, rho, drho_cache, t)
function dmaster_nh_dynamic!(drho, f::F, rates, rho, drho_cache, t) where {F}
result = f(t, rho)
QO_CHECKS[] && @assert 4 <= length(result) <= 5
if length(result) == 4
Expand Down
156 changes: 103 additions & 53 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ function mcwf_h(tspan, psi0::Ket, H::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_mcwf(psi0, H, J, Jdagger, rates)
f(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
f = let H = H, J = J, Jdagger = Jdagger, rates = rates, tmp = tmp
f(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
end
probs = zeros(real(eltype(psi0)), length(J))
j = let J = J, probs = probs, rates = rates
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, rates)
end
integrate_mcwf(f, j, tspan, psi0, seed, fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
Expand All @@ -48,8 +53,13 @@ function mcwf_nh(tspan, psi0::Ket, Hnh::AbstractOperator, J;
_check_const(Hnh)
_check_const.(J)
check_mcwf(psi0, Hnh, J, J, nothing)
f(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, nothing)
f = let Hnh = Hnh
f(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
end
probs = zeros(real(eltype(psi0)), length(J))
j = let J = J, probs = probs
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, nothing)
end
integrate_mcwf(f, j, tspan, psi0, seed, fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
Expand Down Expand Up @@ -107,8 +117,13 @@ function mcwf(tspan, psi0::Ket, H::AbstractOperator, J;
isreducible = check_mcwf(psi0, H, J, Jdagger, rates)
if !isreducible
tmp = copy(psi0)
dmcwf_h_(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
j_h(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
dmcwf_h_ = let H = H, J = J, Jdagger = Jdagger, rates = rates, tmp = tmp
dmcwf_h_(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
end
probs = zeros(real(eltype(psi0)), length(J))
j_h = let J = J, probs = probs, rates = rates
j_h(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_h_, j_h, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand All @@ -125,8 +140,13 @@ function mcwf(tspan, psi0::Ket, H::AbstractOperator, J;
Hnh -= complex(float(eltype(H)))(0.5im*rates[i])*Jdagger[i]*J[i]
end
end
dmcwf_nh_(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
j_nh(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
dmcwf_nh_ = let Hnh = Hnh # Hnh type often not inferrable
dmcwf_nh_(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
end
probs = zeros(real(eltype(psi0)), length(J))
j_nh = let J = J, probs = probs, rates = rates
j_nh(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_nh_, j_nh, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand Down Expand Up @@ -177,8 +197,14 @@ function mcwf_dynamic(tspan, psi0::Ket, f;
fout=nothing, display_beforeevent=false, display_afterevent=false,
kwargs...)
tmp = copy(psi0)
dmcwf_(t, psi, dpsi) = dmcwf_h_dynamic!(dpsi, f, rates, psi, tmp, t)
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates)
dmcwf_ = let f = f, tmp = tmp, rates = rates
dmcwf_(t, psi, dpsi) = dmcwf_h_dynamic!(dpsi, f, rates, psi, tmp, t)
end
J = f(first(tspan), psi0)[2]
probs = zeros(real(eltype(psi0)), length(J))
j_ = let f = f, probs = probs, rates = rates
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand All @@ -203,8 +229,14 @@ function mcwf_nh_dynamic(tspan, psi0::Ket, f;
seed=rand(UInt), rates=nothing,
fout=nothing, display_beforeevent=false, display_afterevent=false,
kwargs...)
dmcwf_(t, psi, dpsi) = dmcwf_nh_dynamic!(dpsi, f, psi, t)
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates)
dmcwf_ = let f = f
dmcwf_(t, psi, dpsi) = dmcwf_nh_dynamic!(dpsi, f, psi, t)
end
J = f(first(tspan), psi0)[2]
probs = zeros(real(eltype(psi0)), length(J))
j_ = let f = f, probs = probs, rates = rates
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand All @@ -225,7 +257,7 @@ update `dpsi` according to a non-Hermitian Schrödinger equation.
See also: [`mcwf_dynamic`](@ref), [`dmcwf_h!`](@ref), [`dmcwf_nh_dynamic`](@ref)
"""
function dmcwf_h_dynamic!(dpsi, f, rates, psi, dpsi_cache, t)
function dmcwf_h_dynamic!(dpsi, f::F, rates, psi, dpsi_cache, t) where {F}
result = f(t, psi)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
Expand All @@ -246,15 +278,15 @@ and update `dpsi` according to a Schrödinger equation.
See also: [`mcwf_nh_dynamic`](@ref), [`dmcwf_nh!`](@ref), [`dschroedinger!`](@ref)
"""
function dmcwf_nh_dynamic!(dpsi, f, psi, t)
function dmcwf_nh_dynamic!(dpsi, f::F, psi, t) where {F}
result = f(t, psi)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
H, J, Jdagger = result[1:3]
QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, nothing)
dschroedinger!(dpsi, H, psi)
end

function jump_dynamic(rng, t, psi, f, psi_new, rates)
function jump_dynamic(rng, t, psi, f::F, psi_new, probs_tmp, rates) where {F}
result = f(t, psi)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
J = result[2]
Expand All @@ -263,7 +295,7 @@ function jump_dynamic(rng, t, psi, f, psi_new, rates)
else
rates_ = result[4]
end
jump(rng, t, psi, J, psi_new, rates_)
jump(rng, t, psi, J, psi_new, probs_tmp, rates_)
end

"""
Expand All @@ -289,15 +321,15 @@ Integrate a single Monte Carlo wave function trajectory.
an initial jump threshold. If provided, `seed` is ignored.
* `kwargs`: Further arguments are passed on to the ode solver.
"""
function integrate_mcwf(dmcwf, jumpfun, tspan,
psi0, seed, fout::Function;
function integrate_mcwf(dmcwf::T, jumpfun::J, tspan,
psi0, seed, fout;
display_beforeevent=false, display_afterevent=false,
display_jumps=false,
rng_state=nothing,
save_everystep=false, callback=nothing,
saveat=tspan,
alg=OrdinaryDiffEq.DP5(),
kwargs...)
kwargs...) where {T, J}

tspan_ = convert(Vector{float(eltype(tspan))}, tspan)
# Display before or after events
Expand All @@ -308,29 +340,33 @@ function integrate_mcwf(dmcwf, jumpfun, tspan,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
return nothing
end
save_before! = display_beforeevent ? save_func! : (affect!,integrator)->nothing
save_after! = display_afterevent ? save_func! : (affect!,integrator)->nothing
no_save_func!(affect!,integrator) = nothing
save_before! = display_beforeevent ? save_func! : no_save_func!
save_after! = display_afterevent ? save_func! : no_save_func!

# Display jump operator index and times
jump_t = eltype(tspan_)[]
jump_index = Int[]
save_t_index = if display_jumps
function(t,i)
push!(jump_t,t)
push!(jump_index,i)
return nothing
end
else
(t,i)->nothing
end

function fout_(x, t, integrator)
recast!(state,x)
fout(t, state)
function jump_saver(t, i)
push!(jump_t,t)
push!(jump_index,i)
return nothing
end
no_jump_saver(t, i) = nothing

save_t_index = display_jumps ? jump_saver : no_jump_saver

state = copy(psi0)
dstate = copy(psi0)

fout_ = let state = state, fout = fout
function fout_(x, t, integrator)
recast!(state,x)
fout(t, state)
end
end

out_type = pure_inference(fout, Tuple{eltype(tspan_),typeof(state)})
out = DiffEqCallbacks.SavedValues(eltype(tspan_),out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan_,
Expand All @@ -340,11 +376,14 @@ function integrate_mcwf(dmcwf, jumpfun, tspan,
cb = jump_callback(jumpfun, seed, scb, save_before!, save_after!, save_t_index, psi0, rng_state)
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx, x, p, t)
recast!(state,x)
recast!(dstate,dx)
dmcwf(t, state, dstate)
recast!(dx,dstate)
df_ = let state = state, dstate = dstate # help inference along
function df_(dx, x, p, t)
recast!(state,x)
recast!(dstate,dx)
dmcwf(t, state, dstate)
recast!(dx,dstate)
return nothing
end
end

prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0), (tspan_[1],tspan_[end]))
Expand Down Expand Up @@ -396,8 +435,8 @@ end
roll!(s::JumpRNGState{T}) where T = (s.threshold = rand(s.rng, T))
threshold(s::JumpRNGState) = s.threshold

function jump_callback(jumpfun, seed, scb, save_before!,
save_after!, save_t_index, psi0, rng_state::JumpRNGState)
function jump_callback(jumpfun::F, seed, scb, save_before!::G,
save_after!::H, save_t_index::I, psi0, rng_state::JumpRNGState) where {F,G,H,I}

tmp = copy(psi0)
psi_tmp = copy(psi0)
Expand Down Expand Up @@ -431,7 +470,7 @@ jump_callback(jumpfun, seed, scb, save_before!,
as_vector(psi::StateVector) = psi.data

"""
jump(rng, t, psi, J, psi_new)
jump(rng, t, psi, J, psi_new, probs_tmp)
Default jump function.
Expand All @@ -441,41 +480,52 @@ Default jump function.
* `psi`: State vector before the jump.
* `J`: List of jump operators.
* `psi_new`: Result of jump.
* `probs_tmp`: Temporary array for holding jump probailities.
"""
function jump(rng, t, psi, J, psi_new, rates::Nothing)
function jump(rng, t, psi, J, psi_new, probs_tmp, rates::Nothing)
if length(J)==1
QuantumOpticsBase.mul!(psi_new,J[1],psi,true,false)
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(real(eltype(psi)), length(J))
for i=1:length(J)
QuantumOpticsBase.mul!(psi_new,J[i],psi,true,false)
probs[i] = real(dot(psi_new.data, psi_new.data))
probs_tmp[i] = real(dot(psi_new.data, psi_new.data))
end
cumprobs = cumsum(probs./sum(probs))
r = rand(rng)
i = findfirst(cumprobs.>r)
QuantumOpticsBase.mul!(psi_new,J[i],psi,one(eltype(psi))/sqrt(probs[i]),zero(eltype(psi)))
total = sum(probs_tmp)
cumulative_prob = 0.0
i = 0
for p in probs_tmp
i += 1
cumulative_prob += p / total
cumulative_prob > r && break
end
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(1/sqrt(probs_tmp[i])),zero(eltype(psi)))
end
return i
end

function jump(rng, t, psi, J, psi_new, rates::AbstractVector)
function jump(rng, t, psi, J, psi_new, probs_tmp, rates::AbstractVector)
if length(J)==1
QuantumOpticsBase.mul!(psi_new,J[1],psi,eltype(psi)(sqrt(rates[1])),zero(eltype(psi)))
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(real(eltype(psi)), length(J))
for i=1:length(J)
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i])),zero(eltype(psi)))
probs[i] = real(dot(psi_new.data, psi_new.data))
probs_tmp[i] = real(dot(psi_new.data, psi_new.data))
end
cumprobs = cumsum(probs./sum(probs))
r = rand(rng)
i = findfirst(cumprobs.>r)
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i]/probs[i])),zero(eltype(psi)))
total = sum(probs_tmp)
cumulative_prob = 0.0
i = 0
for p in probs_tmp
i += 1
cumulative_prob += p / total
cumulative_prob > r && break
end
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i]/probs_tmp[i])),zero(eltype(psi)))
end
return i
end
Expand Down
10 changes: 6 additions & 4 deletions src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Integrate Schroedinger equation to evolve states or compute propagators.
therefore must not be changed.
"""
function schroedinger(tspan, psi0::T, H::AbstractOperator{B,B};
fout::Union{Function,Nothing}=nothing,
fout=nothing,
kwargs...) where {B,Bo,T<:Union{AbstractOperator{B,Bo},StateVector{B}}}
_check_const(H)
dschroedinger_(t, psi, dpsi) = dschroedinger!(dpsi, H, psi)
Expand Down Expand Up @@ -44,9 +44,11 @@ Integrate time-dependent Schroedinger equation to evolve states or compute propa
Instead of a function `f`, this takes a time-dependent operator `H`.
"""
function schroedinger_dynamic(tspan, psi0, f;
fout::Union{Function,Nothing}=nothing,
fout=nothing,
kwargs...)
dschroedinger_(t, psi, dpsi) = dschroedinger_dynamic!(dpsi, f, psi, t)
dschroedinger_ = let f = f
dschroedinger_(t, psi, dpsi) = dschroedinger_dynamic!(dpsi, f, psi, t)
end
tspan, psi0 = _promote_time_and_state(psi0, f, tspan) # promote only if ForwardDiff.Dual
x0 = psi0.data
state = copy(psi0)
Expand Down Expand Up @@ -105,7 +107,7 @@ Schrödinger equation as `-im*H*psi`.
See also: [`dschroedinger!`](@ref)
"""
function dschroedinger_dynamic!(dpsi, f, psi, t)
function dschroedinger_dynamic!(dpsi, f::F, psi, t) where {F}
H = f(t, psi)
dschroedinger!(dpsi, H, psi)
end
Expand Down
Loading

0 comments on commit b9a88b7

Please sign in to comment.