Skip to content

Commit

Permalink
Implement the OrdinaryDiffEq interface for Kets
Browse files Browse the repository at this point in the history
The following now works

```julia
using QuantumOptics
using DifferentialEquations

ℋ = SpinBasis(1//2)

σx = sigmax(ℋ)

↓ = s =  spindown(ℋ)

schrod(ψ,p,t) = im * σx * ψ

t₀, t₁ = (0.0, pi)
Δt = 0.1

prob = ODEProblem(schrod, ↓, (t₀, t₁))
sol = solve(prob,Tsit5())
```

It works for Bras as well.
It works for in-place operations and in some situations it is
faster than the standard `timeevolution.schroedinger`.

```julia
ℋ = SpinBasis(20//1)
↓ = spindown(ℋ)
t₀, t₁ = (0.0, pi)
const σx = sigmax(ℋ)
const iσx = im * σx
schrod!(dψ,ψ,p,t) = mul!(dψ, iσx, ψ)
prob! = ODEProblem(schrod!, ↓, (t₀, t₁))

julia> @benchmark sol = solve($prob!,DP5(),save_everystep=false)
BenchmarkTools.Trial:
  memory estimate:  22.67 KiB
  allocs estimate:  178
  --------------
  minimum time:     374.463 μs (0.00% GC)
  median time:      397.327 μs (0.00% GC)
  mean time:        406.738 μs (0.37% GC)
  maximum time:     4.386 ms (89.76% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark timeevolution.schroedinger([$t₀,$t₁], $↓, $σx)
BenchmarkTools.Trial:
  memory estimate:  23.34 KiB
  allocs estimate:  161
  --------------
  minimum time:     748.106 μs (0.00% GC)
  median time:      774.601 μs (0.00% GC)
  mean time:        786.933 μs (0.14% GC)
  maximum time:     4.459 ms (80.46% GC)
  --------------
  samples:          6350
  evals/sample:     1
```
  • Loading branch information
Krastanov committed Apr 7, 2021
1 parent 201b630 commit e08608b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 45 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ version = "v0.2.7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

[compat]
julia = "1.3"
FFTW = "1.2"
Adapt = "1, 2"
RecursiveArrayTools = "2.11"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 1 addition & 1 deletion src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
Expand Down
86 changes: 45 additions & 41 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,72 +209,76 @@ Broadcast.BroadcastStyle(::Type{<:Bra{B}}) where {B<:Basis} = BraStyle{B}()
Broadcast.BroadcastStyle(::KetStyle{B1}, ::KetStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(IncompatibleBases())
Broadcast.BroadcastStyle(::BraStyle{B1}, ::BraStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(IncompatibleBases())

# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:KetStyle{B}} = T()
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:BraStyle{B}} = T()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
b = find_basis(bcf)
return Ket{B}(b, copy(bc_))
T = find_dType(bcf)
data = zeros(T, length(b))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Ket{B}(b, data)
end
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
b = find_basis(bcf)
return Bra{B}(b, copy(bc_))
end
find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args)
find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args))
find_basis(x) = x
find_basis(a::StateVector, rest) = a.basis
find_basis(::Any, rest) = find_basis(rest)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
T = find_dType(bcf)
data = zeros(T, length(b))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Bra{B}(b, data)
end
function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector
throw(error("Cannot broadcast function `$f` on type `$T`"))
for f [:find_basis,:find_dType]
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
@eval ($f)(x) = x
@eval ($f)(::Any, rest) = ($f)(rest)
end

find_basis(a::StateVector, rest) = a.basis
find_dType(a::StateVector, rest) = eltype(a)
Base.getindex(st::StateVector, idx) = getindex(st.data, idx)

# In-place broadcasting for Kets
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of kets and broadcast them as arrays
bcf = Broadcast.flatten(bc)
args_ = Tuple(a.data for a=bcf.args)
bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:KetStyle{B2},Axes,F,Args} =
throw(IncompatibleBases())

# In-place broadcasting for Bras
@inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of bras and broadcast them as arrays
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:BraStyle{B2},Axes,F,Args} =
throw(IncompatibleBases())

@inline Base.copyto!(A::T,B::T) where T<:StateVector = (copyto!(A.data,B.data); A)

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init
Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N
Base.zero(k::StateVector) = typeof(k)(k.basis, zero(k.data)) # ODE init
Base.any(f::Function, x::StateVector; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks
Base.all(f::Function, x::StateVector; kwargs...) = all(f, x.data; kwargs...)
Broadcast.similar(k::StateVector, t) = typeof(k)(k.basis, copy(k.data))
using RecursiveArrayTools
RecursiveArrayTools.recursivecopy!(dst::Ket{B,A},src::Ket{B,A}) where {B,A} = copy!(dst.data,src.data) # ODE in-place equations
RecursiveArrayTools.recursivecopy!(dst::Bra{B,A},src::Bra{B,A}) where {B,A} = copy!(dst.data,src.data)
2 changes: 1 addition & 1 deletion src/superoperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ end
# end
find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
Expand Down
15 changes: 13 additions & 2 deletions test/test_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,18 @@ psi_ .+= psi123
bra_ = copy(bra123)
bra_ .= 3*bra123
@test bra_ == 3*dagger(psi123)
@test_throws ErrorException cos.(psi_)
@test_throws ErrorException cos.(bra_)
@test bra_ .* 2 == bra_ .+ bra_
@test bra_ * 2 == bra_ .+ bra_
z = zero(bra_)
z .= bra_ .* 2
@test_broken all(z .== bra_ .+ bra_)
@test z == bra_ .+ bra_
ket_ = bra_'
@test ket_ .* 2 == ket_ .+ ket_
@test ket_ * 2 == ket_ .+ ket_
z = zero(ket_)
z .= ket_ .* 2
@test_broken all(z .== ket_.+ ket_)
@test z == ket_ .+ ket_

end # testset

0 comments on commit e08608b

Please sign in to comment.