Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/kmm-integration' into kmm-integr…
Browse files Browse the repository at this point in the history
…ation
  • Loading branch information
isazi committed Feb 7, 2024
2 parents 9c0e1fb + c087c24 commit eb6ea53
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 146 deletions.
5 changes: 2 additions & 3 deletions CompasToolkit.jl/examples/signal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,14 @@ signal_ref = simulate(CUDALibs(), pssfp_ref, parameters_ref, trajectory_ref, coi
signal_ref = collect(signal_ref)
signal_ref = reshape(signal_ref, ns, nr)

signal = zeros(ComplexF32, ns, nr, ncoils)
CompasToolkit.magnetization_to_signal(
signal = CompasToolkit.magnetization_to_signal(
context,
signal,
echos,
parameters,
trajectory,
coil_sensitivities)

signal = collect(signal)

for c in 1:ncoils
expected = map(x -> x[c], signal_ref)
Expand Down
174 changes: 99 additions & 75 deletions CompasToolkit.jl/src/CompasToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ make_context(device::Integer)::Context = Context(device)
make_context() = make_context(0)

mutable struct CompasArray{T, N} <: AbstractArray{T, N}
context::Context
ptr::Ptr{Cvoid}
sizes::Dims{N}

function CompasArray{T, N}(ptr::Ptr{Cvoid}, sizes::Dims{N}) where {T, N}
obj = new(ptr, sizes)
function CompasArray{T, N}(context::Context, ptr::Ptr{Cvoid}, sizes::Dims{N}) where {T, N}
obj = new(context, ptr, sizes)
destroy = (obj) -> @ccall LIBRARY.compas_destroy_array(ptr::Ptr{Cvoid})::Cvoid
finalizer(destroy, obj)
end
Expand All @@ -50,23 +51,47 @@ Base.getindex(array::CompasArray, i) = throw(ArgumentError("cannot index into a
function make_array(context::Context, input::Array{Float32, N})::CompasArray{Float32, N} where {N}
sizes::Vector{Int64} = [reverse(size(input))...]

@ccall LIBRARY.compas_make_array_float(
ptr = @ccall LIBRARY.compas_make_array_float(
context.ptr::Ptr{Cvoid},
pointer(input)::Ptr{Float32},
N::Int32,
pointer(sizes)::Ptr{Int64}
)::Ptr{Cvoid}

return CompasArray{Float32, N}(context, ptr, Dims{N}(sizes))
end

function collect(input::CompasArray{ComplexF32, N}) where {N}
result = Array{ComplexF32, N}(undef, input.sizes...)
@ccall LIBRARY.compas_read_array_complex(
input.context::Ptr{Cvoid},
input.ptr::Ptr{Cvoid},
pointer(result)::Ptr{ComplexF32}
)::Cvoid
return result
end

function collect(input::CompasArray{Float32, N}) where {N}
result = Array{Float32, N}(undef, input.sizes...)
@ccall LIBRARY.compas_read_array_float(
input.context::Ptr{Cvoid},
input.ptr::Ptr{Cvoid},
pointer(result)::Ptr{Float32}
)::Cvoid
return result
end

function make_array(context::Context, input::Array{ComplexF32, N})::CompasArray{ComplexF32, N} where {N}
sizes::Vector{Int64} = [reverse(size(input))...]

@ccall LIBRARY.compas_make_array_complex(
ptr = @ccall LIBRARY.compas_make_array_complex(
context.ptr::Ptr{Cvoid},
pointer(input)::Ptr{ComplexF32},
N::Int32,
pointer(sizes)::Ptr{Int64}
)::Cvoid
)::Ptr{Cvoid}

return CompasArray{ComplexF32, N}(context, ptr, Dims{N}(sizes))
end

function assert_size(input::AbstractArray, expected::Dims{N}) where {N}
Expand All @@ -76,7 +101,7 @@ function assert_size(input::AbstractArray, expected::Dims{N}) where {N}
end
end

function convert_array(
function convert_array_old(
ty::Type{T},
dims::Dims{N},
input::Array{T,N},
Expand All @@ -85,7 +110,7 @@ function convert_array(
return input
end

function convert_array(
function convert_array_old(
ty::Type{T},
dims::Dims{N},
input::AbstractArray,
Expand All @@ -94,10 +119,20 @@ function convert_array(
return convert(Array{T,N}, input)
end

function convert_array(ty::Type{T}, dims::Dims{N}, input::Number)::Array{T,N} where {T,N}
function convert_array_old(ty::Type{T}, dims::Dims{N}, input::Number)::Array{T,N} where {T,N}
return fill(convert(ty, input), dims)
end

function convert_array(
context::Context,
ty::Type{T},
dims::Dims{N},
input::Array{T,N},
)::CompasArray{T,N} where {T,N}
assert_size(input, dims)
return make_array(context, input)
end

function unsafe_destroy_object!(obj)
@ccall LIBRARY.compas_destroy(obj.ptr::Ptr{Cvoid})::Cvoid
end
Expand All @@ -117,7 +152,7 @@ mutable struct CartesianTrajectory <: Trajectory
k_start::AbstractVector,
delta_k::Number,
)
k_start = convert_array(ComplexF32, (nreadouts,), k_start)
k_start = convert_array_old(ComplexF32, (nreadouts,), k_start)

ptr = @ccall LIBRARY.compas_make_cartesian_trajectory(
pointer(context)::Ptr{Cvoid},
Expand Down Expand Up @@ -146,8 +181,8 @@ mutable struct SpiralTrajectory <: Trajectory
k_start::AbstractVector,
delta_k::AbstractVector,
)
k_start = convert_array(ComplexF32, (nreadouts,), k_start)
delta_k = convert_array(ComplexF32, (nreadouts,), delta_k)
k_start = convert_array_old(ComplexF32, (nreadouts,), k_start)
delta_k = convert_array_old(ComplexF32, (nreadouts,), delta_k)

ptr = @ccall LIBRARY.compas_make_spiral_trajectory(
pointer(context)::Ptr{Cvoid},
Expand Down Expand Up @@ -181,15 +216,15 @@ mutable struct TissueParameters
y::AbstractVector,
z::AbstractVector,
)
T1 = convert_array(Float32, (nvoxels,), T1)
T2 = convert_array(Float32, (nvoxels,), T2)
B1 = convert_array(Float32, (nvoxels,), B1)
B0 = convert_array(Float32, (nvoxels,), B0)
rho_x = convert_array(Float32, (nvoxels,), rho_x)
rho_y = convert_array(Float32, (nvoxels,), rho_y)
x = convert_array(Float32, (nvoxels,), x)
y = convert_array(Float32, (nvoxels,), y)
z = convert_array(Float32, (nvoxels,), z)
T1 = convert_array_old(Float32, (nvoxels,), T1)
T2 = convert_array_old(Float32, (nvoxels,), T2)
B1 = convert_array_old(Float32, (nvoxels,), B1)
B0 = convert_array_old(Float32, (nvoxels,), B0)
rho_x = convert_array_old(Float32, (nvoxels,), rho_x)
rho_y = convert_array_old(Float32, (nvoxels,), rho_y)
x = convert_array_old(Float32, (nvoxels,), x)
y = convert_array_old(Float32, (nvoxels,), y)
z = convert_array_old(Float32, (nvoxels,), z)


ptr = @ccall LIBRARY.compas_make_tissue_parameters(
Expand Down Expand Up @@ -217,8 +252,12 @@ mutable struct TissueParameters
end

mutable struct FispSequence
ptr::Ptr{Cvoid}
nreadouts::Int32
RF_train::CompasArray{ComplexF32, 1}
slice_profiles::CompasArray{ComplexF32, 2}
TR::Float32
TE::Float32
max_state::Int32
TI::Float32

function FispSequence(
context::Context,
Expand All @@ -232,29 +271,22 @@ mutable struct FispSequence
nreadouts = size(RF_train, 1)
nslices = size(slice_profiles, 2)

RF_train = convert_array(ComplexF32, (nreadouts,), RF_train)
slice_profiles = convert_array(ComplexF32, (nreadouts, nslices), slice_profiles)
RF_train = convert_array_old(ComplexF32, (nreadouts,), RF_train)
slice_profiles = convert_array_old(ComplexF32, (nreadouts, nslices), slice_profiles)

ptr = @ccall LIBRARY.compas_make_fisp_sequence(
pointer(context)::Ptr{Cvoid},
nreadouts::Int32,
nslices::Int32,
pointer(RF_train)::Ptr{ComplexF32},
pointer(slice_profiles)::Ptr{ComplexF32},
TR::Float32,
TE::Float32,
max_state::Int32,
TI::Float32,
)::Ptr{Cvoid}

obj = new(ptr, nreadouts)
finalizer(unsafe_destroy_object!, obj)
return new(RF_train, slice_profiles, TR, TE, max_state, TI)
end
end

mutable struct pSSFPSequence
ptr::Ptr{Cvoid}
nreadouts::Int32
RF_train::CompasArray{ComplexF32, 1}
TR::Float32
nRF::Int32
nTR::Int32
gamma_dt_GRz::CompasArray{ComplexF32, 1}
dt::NTuple{3, Float32}
gamma_dt_GRz::NTuple{3, Float32}
z::CompasArray{Float32, 1}

function pSSFPSequence(
context::Context,
Expand All @@ -269,29 +301,11 @@ mutable struct pSSFPSequence
nRF = size(gamma_dt_RF, 1)
nslices = size(z, 1)

RF_train = convert_array(ComplexF32, (nreadouts,), RF_train)
gamma_dt_RF = convert_array(ComplexF32, (nRF,), gamma_dt_RF)
z = convert_array(Float32, (nslices,), z)
RF_train = convert_array_old(ComplexF32, (nreadouts,), RF_train)
gamma_dt_RF = convert_array_old(ComplexF32, (nRF,), gamma_dt_RF)
z = convert_array_old(Float32, (nslices,), z)

ptr = @ccall LIBRARY.compas_make_pssfp_sequence(
pointer(context)::Ptr{Cvoid},
nRF::Int32,
nreadouts::Int32,
nslices::Int32,
pointer(RF_train)::Ptr{ComplexF32},
TR::Float32,
pointer(gamma_dt_RF)::Ptr{ComplexF32},
dt[1]::Float32,
dt[2]::Float32,
dt[3]::Float32,
gamma_dt_GRz[1]::Float32,
gamma_dt_GRz[2]::Float32,
gamma_dt_GRz[3]::Float32,
pointer(z)::Ptr{Float32},
)::Ptr{Cvoid}

obj = new(ptr, nreadouts)
finalizer(unsafe_destroy_object!, obj)
return new(RF_train, TR, nRF, nTR, gamma_dt_RF, dt, gamma_dt_GRz, z)
end
end

Expand All @@ -303,13 +317,18 @@ function simulate_magnetization(
)
nvoxels::Int64 = parameters.nvoxels
nreadouts::Int64 = sequence.nreadouts
echos = convert_array(ComplexF32, (nvoxels, nreadouts), echos)
echos = convert_array_old(ComplexF32, (nvoxels, nreadouts), echos)

@ccall LIBRARY.compas_simulate_magnetization_fisp(
pointer(context)::Ptr{Cvoid},
pointer(echos)::Ptr{ComplexF32},
parameters.ptr::Ptr{Cvoid},
sequence.ptr::Ptr{Cvoid},
sequence.RF_train.ptr::Ptr{Cvoid},
sequence.slice_profiles.ptr::Ptr{Cvoid},
sequence.TR,
sequence.TE,
sequence.max_state,
sequence.TI
)::Cvoid

return echos
Expand All @@ -323,21 +342,28 @@ function simulate_magnetization(
)
nvoxels::Int64 = parameters.nvoxels
nreadouts::Int64 = sequence.nreadouts
echos = convert_array(ComplexF32, (nvoxels, nreadouts), echos)
echos = convert_array_old(ComplexF32, (nvoxels, nreadouts), echos)

@ccall LIBRARY.compas_simulate_magnetization_pssfp(
pointer(context)::Ptr{Cvoid},
pointer(echos)::Ptr{ComplexF32},
parameters.ptr::Ptr{Cvoid},
sequence.ptr::Ptr{Cvoid},
sequence.RF_train.ptr::Ptr{Cvoid},
sequence.TR::Float32,
sequence.gamma_dt_RF.ptr::Ptr{Cvoid},
sequence.dt[0]::Float32,
sequence.dt[1]::Float32,
sequence.dt[2]::Float32,
sequence.gamma_dt_GRz[0]::Float32,
sequence.gamma_dt_GRz[1]::Float32,
sequence.gamma_dt_GRz[2]::Float32
)::Cvoid

return echos
end

function magnetization_to_signal(
context::Context,
signal::AbstractArray{<:Any,3},
echos::AbstractMatrix,
parameters::TissueParameters,
trajectory::Trajectory,
Expand All @@ -348,21 +374,19 @@ function magnetization_to_signal(
samples_per_readout::Int64 = trajectory.samples_per_readout
nvoxels::Int64 = parameters.nvoxels

signal = convert_array(ComplexF32, (samples_per_readout, nreadouts, ncoils), signal)
echos = convert_array(ComplexF32, (nvoxels, nreadouts), echos)
coils = convert_array(Float32, (nvoxels, ncoils), coils)
echos = convert_array(context, ComplexF32, (nvoxels, nreadouts), echos)
coils = convert_array(context, Float32, (nvoxels, ncoils), coils)

@ccall LIBRARY.compas_magnetization_to_signal(
signal_ptr = @ccall LIBRARY.compas_magnetization_to_signal(
pointer(context)::Ptr{Cvoid},
ncoils::Int32,
pointer(signal)::Ptr{ComplexF32},
pointer(echos)::Ptr{ComplexF32},
pointer(echos)::Ptr{Cvoid},
parameters.ptr::Ptr{Cvoid},
trajectory.ptr::Ptr{Cvoid},
pointer(coils)::Ptr{Float32}
)::Cvoid
)::Ptr{Cvoid}

return signal
return CompasArray{ComplexF32, 3}(context, signal_ptr, (ncoils, nreadouts, samples_per_readout))
end

Base.pointer(c::Context) = c.ptr
Expand Down
Loading

0 comments on commit eb6ea53

Please sign in to comment.