diff --git a/src/operators_dense.jl b/src/operators_dense.jl index de077558..d68fa073 100644 --- a/src/operators_dense.jl +++ b/src/operators_dense.jl @@ -331,40 +331,45 @@ struct OperatorStyle{BL<:Basis,BR<:Basis} <: DataOperatorStyle{BL,BR} end Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL<:Basis,BR<:Basis} = OperatorStyle{BL,BR}() Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(IncompatibleBases()) +# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`) +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T() + # Out-of-place broadcasting @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) bl,br = find_basis(bcf.args) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - return Operator{BL,BR}(bl, br, copy(bc_)) + T = find_dType(bcf) + data = zeros(T, length(bl), length(br)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Operator{BL,BR}(bl, br, data) end -find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} -function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes) - args_ = Tuple(a.data for a=args) - return Broadcast.Broadcasted(f, args_, axes) -end -function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:DataOperator}}, axes) - throw(error("Cannot broadcast function `$f` on type `$(eltype(args))`")) -end +find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) +find_dType(a::DataOperator, rest) = eltype(a) +Base.getindex(a::DataOperator, idx) = getindex(a.data, idx) +Base.iterate(a::DataOperator) = iterate(a.data) +Base.iterate(a::DataOperator, idx) = iterate(a.data, idx) # In-place broadcasting @inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle{BL,BR},Axes,F,Args} axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of operators and broadcast them as arrays - bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A) @inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle,Axes,F,Args} = throw(IncompatibleBases()) + +# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl +Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init +Base.any(f::Function, ρ::Operator; kwargs...) = any(f, ρ.data; kwargs...) # ODE nan checks +Base.all(f::Function, ρ::Operator; kwargs...) = all(f, ρ.data; kwargs...) +Broadcast.similar(ρ::Operator, t) = typeof(ρ)(ρ.basis_l, ρ.basis_r, copy(ρ.data)) +using RecursiveArrayTools +RecursiveArrayTools.recursivecopy!(dst::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copy!(dst.data,src.data) # ODE in-place equations