From 2939a6a8917520ab22c4dccaa2f55a8dfca93b83 Mon Sep 17 00:00:00 2001 From: Andrew Kille <68079167+apkille@users.noreply.github.com> Date: Fri, 13 Sep 2024 23:42:06 -0400 Subject: [PATCH] Implement ForwardDiff in master solvers and general DiffEq problems on QO types (#409) --- Project.toml | 3 ++- src/master.jl | 4 +++ src/schroedinger.jl | 19 ------------- src/timeevolution_base.jl | 44 ++++++++++++++++++++++++++++++ test/test_ForwardDiff.jl | 56 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index a20c6559..116955ce 100644 --- a/Project.toml +++ b/Project.toml @@ -39,10 +39,11 @@ WignerSymbols = "1, 2" julia = "1.10" [extras] +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["LinearAlgebra", "SparseArrays", "Random", "Test"] +test = ["FiniteDiff", "LinearAlgebra", "SparseArrays", "Random", "Test"] diff --git a/src/master.jl b/src/master.jl index 0cf63ff6..e402d611 100644 --- a/src/master.jl +++ b/src/master.jl @@ -14,6 +14,7 @@ function master_h(tspan, rho0::Operator, H::AbstractOperator, J; _check_const.(J) _check_const.(Jdagger) check_master(rho0, H, J, Jdagger, rates) + tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan) tmp = copy(rho0) dmaster_(t, rho, drho) = dmaster_h!(drho, H, J, Jdagger, rates, rho, tmp) integrate_master(tspan, dmaster_, rho0, fout; kwargs...) @@ -41,6 +42,7 @@ function master_nh(tspan, rho0::Operator, Hnh::AbstractOperator, J; _check_const.(J) _check_const.(Jdagger) check_master(rho0, Hnh, J, Jdagger, rates) + tspan, rho0 = _promote_time_and_state(rho0, Hnh, J, tspan) tmp = copy(rho0) dmaster_(t, rho, drho) = dmaster_nh!(drho, Hnh, Hnhdagger, J, Jdagger, rates, rho, tmp) integrate_master(tspan, dmaster_, rho0, fout; kwargs...) @@ -86,6 +88,7 @@ function master(tspan, rho0::Operator, H::AbstractOperator, J; _check_const(H) _check_const.(J) _check_const.(Jdagger) + tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan) isreducible = check_master(rho0, H, J, Jdagger, rates) if !isreducible tmp = copy(rho0) @@ -124,6 +127,7 @@ function master(tspan, rho0::Operator, L::SuperOperator; fout=nothing, kwargs... b = GenericBasis(dim) rho_ = Ket(b,reshape(rho0.data, dim)) L_ = Operator(b,b,L.data) + tspan, rho_ = _promote_time_and_state(rho_, L_, tspan) dmaster_(t,rho,drho) = dmaster_liouville!(drho,L_,rho) # Rewrite into density matrix when saving diff --git a/src/schroedinger.jl b/src/schroedinger.jl index 0910066d..eddecc6c 100644 --- a/src/schroedinger.jl +++ b/src/schroedinger.jl @@ -122,22 +122,3 @@ function check_schroedinger(psi::Bra, H) check_multiplicable(psi, H) check_samebases(H) end - - -function _promote_time_and_state(u0, H::AbstractOperator, tspan) - Ts = eltype(H) - Tt = real(Ts) - p = Vector{Tt}(undef,0) - u0data_promote = DiffEqBase.promote_u0(u0.data, p, tspan[1]) - tspan_promote = DiffEqBase.promote_tspan(u0data_promote, p, tspan, nothing, Dict{Symbol, Any}()) - if u0data_promote !== u0.data - u0_promote = rebuild(u0, u0data_promote) - return tspan_promote, u0_promote - end - return tspan_promote, u0 -end -_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan), u0), tspan) - -rebuild(op::Operator, new_data) = Operator(op.basis_l, op.basis_r, new_data) -rebuild(state::Ket, new_data) = Ket(state.basis, new_data) -rebuild(state::Bra, new_data) = Bra(state.basis, new_data) diff --git a/src/timeevolution_base.jl b/src/timeevolution_base.jl index 8949df0c..c34bb0d9 100644 --- a/src/timeevolution_base.jl +++ b/src/timeevolution_base.jl @@ -121,3 +121,47 @@ macro skiptimechecks(ex) end Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T) + +function _promote_time_and_state(u0, H::AbstractOperator, tspan) + Ts = eltype(H) + Tt = real(Ts) + p = Vector{Tt}(undef,0) + u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1]) + tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}()) + return tspan_promote, u0_promote +end +function _promote_time_and_state(u0, H::AbstractOperator, J, tspan) + Ts = DiffEqBase.promote_dual(eltype(H), DiffEqBase.anyeltypedual(J)) + Tt = real(Ts) + p = Vector{Tt}(undef,0) + u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1]) + tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}()) + return tspan_promote, u0_promote +end + +_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan)..., u0), tspan) + +@inline function DiffEqBase.promote_u0(u0::Ket, p, t0) + u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) + if u0data_promote !== u0.data + u0_promote = Ket(u0.basis, u0data_promote) + return u0_promote + end + return u0 +end +@inline function DiffEqBase.promote_u0(u0::Bra, p, t0) + u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) + if u0data_promote !== u0.data + u0_promote = Bra(u0.basis, u0data_promote) + return u0_promote + end + return u0 +end +@inline function DiffEqBase.promote_u0(u0::Operator, p, t0) + u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) + if u0data_promote !== u0.data + u0_promote = Operator(u0.basis_l, u0.basis_r, u0data_promote) + return u0_promote + end + return u0 +end \ No newline at end of file diff --git a/test/test_ForwardDiff.jl b/test/test_ForwardDiff.jl index 91b889fc..e3ccba9c 100644 --- a/test/test_ForwardDiff.jl +++ b/test/test_ForwardDiff.jl @@ -1,6 +1,7 @@ using Test using OrdinaryDiffEq, QuantumOptics import ForwardDiff +import FiniteDiff # for some caese ForwardDiff.jl returns NaN due to issue with DiffEq.jl. see https://github.com/SciML/DiffEqBase.jl/issues/861 # Here we test; @@ -12,6 +13,27 @@ import ForwardDiff # here we partially control the gradient error by limiting step size (dtmax) +@testset "ForwardDiff on ODE Problems" begin + +# schroedinger equation +b = SpinBasis(10//1) +psi0 = spindown(b) +H(p) = p[1]*sigmax(b) + p[2]*sigmam(b) +f_schrod!(dpsi, psi, p, t) = timeevolution.dschroedinger!(dpsi, H(p), psi) +function cost_schrod(p) + prob = ODEProblem(f_schrod!, psi0, (0.0, pi), p) + sol = solve(prob, DP5(); save_everystep=false) + return 1 - norm(sol[end]) +end + +p = [rand(), rand()] +fordiff_schrod = ForwardDiff.gradient(cost_schrod, p) +findiff_schrod = FiniteDiff.finite_difference_gradient(cost_schrod, p) + +@test isapprox(fordiff_schrod, findiff_schrod; atol=1e-2) + +end + @testset "ForwardDiff with schroedinger" begin # system @@ -73,3 +95,37 @@ Ftdop(1.0) @test ForwardDiff.derivative(Ftdop, 1.0) isa Any end # testset + + +@testset "ForwardDiff with master" begin + +b = SpinBasis(1//2) +psi0 = spindown(b) +rho0 = dm(psi0) +params = [rand(), rand()] + +for f in (:(timeevolution.master), :(timeevolution.master_h), :(timeevolution.master_nh)) + # test to see if parameter propagates through Hamiltonian + H(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # Hamiltonian + function cost_H(p) # + tf, psif = eval(f)((0.0, pi), rho0, H(p), [sigmax(b)]) + return 1 - norm(psif) + end + + forwarddiff_H = ForwardDiff.gradient(cost_H, params) + finitediff_H = FiniteDiff.finite_difference_gradient(cost_H, params) + @test isapprox(forwarddiff_H, finitediff_H; atol=1e-2) + + # test to see if parameter propagates through Jump operator + J(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # jump operator + function cost_J(p) + tf, psif = eval(f)((0.0, pi), rho0, sigmax(b), [J(p)]) + return 1 - norm(psif) + end + + forwarddiff_J = ForwardDiff.gradient(cost_J, params) + finitediff_J = FiniteDiff.finite_difference_gradient(cost_J, params) + @test isapprox(forwarddiff_J, finitediff_J; atol=1e-2) +end + +end