Skip to content

Commit

Permalink
Make compas work with latest KMM
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Feb 9, 2024
1 parent 49ec9ee commit 0096101
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 151 deletions.
112 changes: 55 additions & 57 deletions CompasToolkit.jl/src/CompasToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,6 @@ function make_array(context::Context, input::Array{Float32, N})::CompasArray{Flo
return CompasArray{Float32, N}(context, ptr, Dims{N}(sizes))
end

function collect(input::CompasArray{ComplexF32, N}) where {N}
result = Array{ComplexF32, N}(undef, input.sizes...)
@ccall LIBRARY.compas_read_array_complex(
input.context::Ptr{Cvoid},
input.ptr::Ptr{Cvoid},
pointer(result)::Ptr{ComplexF32}
)::Cvoid
return result
end

function collect(input::CompasArray{Float32, N}) where {N}
result = Array{Float32, N}(undef, input.sizes...)
@ccall LIBRARY.compas_read_array_float(
input.context::Ptr{Cvoid},
input.ptr::Ptr{Cvoid},
pointer(result)::Ptr{Float32}
)::Cvoid
return result
end

function make_array(context::Context, input::Array{ComplexF32, N})::CompasArray{ComplexF32, N} where {N}
sizes::Vector{Int64} = [reverse(size(input))...]

Expand All @@ -94,14 +74,36 @@ function make_array(context::Context, input::Array{ComplexF32, N})::CompasArray{
return CompasArray{ComplexF32, N}(context, ptr, Dims{N}(sizes))
end

function Base.collect(input::CompasArray{Float32, N}) where {N}
result = Array{Float32, N}(undef, reverse(input.sizes)...)
@ccall LIBRARY.compas_read_array_float(
pointer(input.context)::Ptr{Cvoid},
input.ptr::Ptr{Cvoid},
pointer(result)::Ptr{Float32},
length(result)::Int64
)::Cvoid
return result
end

function Base.collect(input::CompasArray{ComplexF32, N}) where {N}
result = Array{ComplexF32, N}(undef, reverse(input.sizes)...)
@ccall LIBRARY.compas_read_array_complex(
pointer(input.context)::Ptr{Cvoid},
input.ptr::Ptr{Cvoid},
pointer(result)::Ptr{ComplexF32},
length(result)::Int64
)::Cvoid
return result
end

function assert_size(input::AbstractArray, expected::Dims{N}) where {N}
gotten = size(input)
if gotten != expected
throw(ArgumentError("Invalid argument dimensions $gotten, should be $expected"))
end
end

function convert_array_old(
function convert_array_host(
ty::Type{T},
dims::Dims{N},
input::Array{T,N},
Expand All @@ -110,7 +112,7 @@ function convert_array_old(
return input
end

function convert_array_old(
function convert_array_host(
ty::Type{T},
dims::Dims{N},
input::AbstractArray,
Expand All @@ -119,7 +121,7 @@ function convert_array_old(
return convert(Array{T,N}, input)
end

function convert_array_old(ty::Type{T}, dims::Dims{N}, input::Number)::Array{T,N} where {T,N}
function convert_array_host(ty::Type{T}, dims::Dims{N}, input::Number)::Array{T,N} where {T,N}
return fill(convert(ty, input), dims)
end

Expand Down Expand Up @@ -152,7 +154,7 @@ mutable struct CartesianTrajectory <: Trajectory
k_start::AbstractVector,
delta_k::Number,
)
k_start = convert_array_old(ComplexF32, (nreadouts,), k_start)
k_start = convert_array_host(ComplexF32, (nreadouts,), k_start)

ptr = @ccall LIBRARY.compas_make_cartesian_trajectory(
pointer(context)::Ptr{Cvoid},
Expand Down Expand Up @@ -181,8 +183,8 @@ mutable struct SpiralTrajectory <: Trajectory
k_start::AbstractVector,
delta_k::AbstractVector,
)
k_start = convert_array_old(ComplexF32, (nreadouts,), k_start)
delta_k = convert_array_old(ComplexF32, (nreadouts,), delta_k)
k_start = convert_array_host(ComplexF32, (nreadouts,), k_start)
delta_k = convert_array_host(ComplexF32, (nreadouts,), delta_k)

ptr = @ccall LIBRARY.compas_make_spiral_trajectory(
pointer(context)::Ptr{Cvoid},
Expand Down Expand Up @@ -216,15 +218,15 @@ mutable struct TissueParameters
y::AbstractVector,
z::AbstractVector,
)
T1 = convert_array_old(Float32, (nvoxels,), T1)
T2 = convert_array_old(Float32, (nvoxels,), T2)
B1 = convert_array_old(Float32, (nvoxels,), B1)
B0 = convert_array_old(Float32, (nvoxels,), B0)
rho_x = convert_array_old(Float32, (nvoxels,), rho_x)
rho_y = convert_array_old(Float32, (nvoxels,), rho_y)
x = convert_array_old(Float32, (nvoxels,), x)
y = convert_array_old(Float32, (nvoxels,), y)
z = convert_array_old(Float32, (nvoxels,), z)
T1 = convert_array_host(Float32, (nvoxels,), T1)
T2 = convert_array_host(Float32, (nvoxels,), T2)
B1 = convert_array_host(Float32, (nvoxels,), B1)
B0 = convert_array_host(Float32, (nvoxels,), B0)
rho_x = convert_array_host(Float32, (nvoxels,), rho_x)
rho_y = convert_array_host(Float32, (nvoxels,), rho_y)
x = convert_array_host(Float32, (nvoxels,), x)
y = convert_array_host(Float32, (nvoxels,), y)
z = convert_array_host(Float32, (nvoxels,), z)


ptr = @ccall LIBRARY.compas_make_tissue_parameters(
Expand Down Expand Up @@ -271,8 +273,8 @@ mutable struct FispSequence
nreadouts = size(RF_train, 1)
nslices = size(slice_profiles, 2)

RF_train = convert_array_old(ComplexF32, (nreadouts,), RF_train)
slice_profiles = convert_array_old(ComplexF32, (nreadouts, nslices), slice_profiles)
RF_train = convert_array(context, ComplexF32, (nreadouts,), RF_train)
slice_profiles = convert_array(context, ComplexF32, (nreadouts, nslices), slice_profiles)

return new(RF_train, slice_profiles, TR, TE, max_state, TI)
end
Expand All @@ -283,7 +285,7 @@ mutable struct pSSFPSequence
TR::Float32
nRF::Int32
nTR::Int32
gamma_dt_GRz::CompasArray{ComplexF32, 1}
gamma_dt_RF::CompasArray{ComplexF32, 1}
dt::NTuple{3, Float32}
gamma_dt_GRz::NTuple{3, Float32}
z::CompasArray{Float32, 1}
Expand All @@ -301,9 +303,9 @@ mutable struct pSSFPSequence
nRF = size(gamma_dt_RF, 1)
nslices = size(z, 1)

RF_train = convert_array_old(ComplexF32, (nreadouts,), RF_train)
gamma_dt_RF = convert_array_old(ComplexF32, (nRF,), gamma_dt_RF)
z = convert_array_old(Float32, (nslices,), z)
RF_train = convert_array(context, ComplexF32, (nreadouts,), RF_train)
gamma_dt_RF = convert_array(context, ComplexF32, (nRF,), gamma_dt_RF)
z = convert_array(context, Float32, (nslices,), z)

return new(RF_train, TR, nRF, nTR, gamma_dt_RF, dt, gamma_dt_GRz, z)
end
Expand All @@ -314,39 +316,35 @@ function simulate_magnetization(
echos::AbstractMatrix,
parameters::TissueParameters,
sequence::FispSequence,
)
)::CompasArray{ComplexF32, 2}
nvoxels::Int64 = parameters.nvoxels
nreadouts::Int64 = sequence.nreadouts
echos = convert_array_old(ComplexF32, (nvoxels, nreadouts), echos)

@ccall LIBRARY.compas_simulate_magnetization_fisp(
pointer(context)::Ptr{Cvoid},
pointer(echos)::Ptr{ComplexF32},
parameters.ptr::Ptr{Cvoid},
sequence.RF_train.ptr::Ptr{Cvoid},
sequence.slice_profiles.ptr::Ptr{Cvoid},
sequence.TR,
sequence.TE,
sequence.max_state,
sequence.TI
)::Cvoid
sequence.TR::Float32,
sequence.TE::Float32,
sequence.max_state::Int32,
sequence.TI::Float32
)::Ptr{Cvoid}

return echos
return CompasArray{ComplexF32, 2}(context, echos_ptr, (nreadouts, nvoxels))
end

function simulate_magnetization(
context::Context,
echos::AbstractMatrix,
parameters::TissueParameters,
sequence::pSSFPSequence,
)
)::CompasArray{ComplexF32, 2}
nvoxels::Int64 = parameters.nvoxels
nreadouts::Int64 = sequence.nreadouts
echos = convert_array_old(ComplexF32, (nvoxels, nreadouts), echos)

@ccall LIBRARY.compas_simulate_magnetization_pssfp(
echos_ptr = @ccall LIBRARY.compas_simulate_magnetization_pssfp(
pointer(context)::Ptr{Cvoid},
pointer(echos)::Ptr{ComplexF32},
parameters.ptr::Ptr{Cvoid},
sequence.RF_train.ptr::Ptr{Cvoid},
sequence.TR::Float32,
Expand All @@ -357,9 +355,9 @@ function simulate_magnetization(
sequence.gamma_dt_GRz[0]::Float32,
sequence.gamma_dt_GRz[1]::Float32,
sequence.gamma_dt_GRz[2]::Float32
)::Cvoid
)::Ptr{Cvoid}

return echos
return CompasArray{ComplexF32, 2}(context, echos_ptr, (nreadouts, nvoxels))
end

function magnetization_to_signal(
Expand All @@ -368,7 +366,7 @@ function magnetization_to_signal(
parameters::TissueParameters,
trajectory::Trajectory,
coils::AbstractMatrix,
)
)::CompasArray{ComplexF32, 3}
ncoils = size(coils, 2)
nreadouts::Int64 = trajectory.nreadouts
samples_per_readout::Int64 = trajectory.samples_per_readout
Expand Down
71 changes: 28 additions & 43 deletions julia-bindings/src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ extern "C" const kmm::ArrayBase* compas_make_array_float(
extern "C" void compas_read_array_float(
const compas::CudaContext* context,
const kmm::ArrayBase* input_array,
float* dest_ptr) {
catch_exceptions([&]() -> kmm::ArrayBase* {
input_array->block()->read(dest_ptr, input_array->size() * sizeof(float));
float* dest_ptr,
int64_t length) {
catch_exceptions([&]() {
size_t num_bytes = kmm::checked_mul(kmm::checked_cast<size_t>(length), sizeof(float));
input_array->read_bytes(dest_ptr, num_bytes);
});
}

Expand All @@ -94,9 +96,11 @@ extern "C" const kmm::ArrayBase* compas_make_array_complex(
extern "C" void compas_read_array_complex(
const compas::CudaContext* context,
const kmm::ArrayBase* input_array,
compas::cfloat* dest_ptr) {
catch_exceptions([&]() -> kmm::ArrayBase* {
input_array->block()->read(dest_ptr, input_array->size() * sizeof(compas::cfloat));
compas::cfloat* dest_ptr,
int64_t length) {
catch_exceptions([&]() {
size_t num_bytes = kmm::checked_mul(kmm::checked_cast<size_t>(length), 2 * sizeof(float));
input_array->read_bytes(dest_ptr, num_bytes);
});
}

Expand Down Expand Up @@ -174,46 +178,34 @@ extern "C" const compas::TissueParameters* compas_make_tissue_parameters(
});
}

extern "C" void compas_simulate_magnetization_fisp(
extern "C" kmm::ArrayBase* compas_simulate_magnetization_fisp(
const compas::CudaContext* context,
compas::cfloat* echos_ptr,
const compas::TissueParameters* parameters,
compas::CudaArray<compas::cfloat>* RF_train,
compas::CudaArray<compas::cfloat, 2>* sliceprofiles,
float TR,
float TE,
int max_state,
float TI
) {
float TI) {
return catch_exceptions([&] {
int nreadouts = RF_train->size();
int nvoxels = parameters->nvoxels;

auto echos = make_view(echos_ptr, nreadouts, nvoxels);
auto d_echos = compas::CudaArray<compas::cfloat, 2>(nreadouts, nvoxels);

auto sequence = compas::FISPSequence {
*RF_train,
*sliceprofiles,
TR,
TE,
max_state,
TI
};
auto* echos = new compas::CudaArray<compas::cfloat, 2>(nreadouts, nvoxels);
auto sequence = compas::FISPSequence {*RF_train, *sliceprofiles, TR, TE, max_state, TI};

context->submit_device(
compas::simulate_magnetization_fisp,
write(d_echos),
write(*echos),
*parameters,
sequence);

d_echos.read(echos);
return echos;
});
}

extern "C" void compas_simulate_magnetization_pssfp(
extern "C" kmm::ArrayBase* compas_simulate_magnetization_pssfp(
const compas::CudaContext* context,
compas::cfloat* echos_ptr,
const compas::TissueParameters* parameters,
const compas::CudaArray<compas::cfloat>* RF_train,
float TR,
Expand All @@ -228,26 +220,23 @@ extern "C" void compas_simulate_magnetization_pssfp(
return catch_exceptions([&] {
int nreadouts = RF_train->size();
int nvoxels = parameters->nvoxels;

auto echos = make_view(echos_ptr, nreadouts, nvoxels);
auto d_echos = context->allocate(echos);
auto* echos = new compas::CudaArray<compas::cfloat, 2>(nreadouts, nvoxels);

auto sequence = compas::pSSFPSequence {
*RF_train,
TR,
*gamma_dt_RF,
{dt_ex, dt_inv, dt_pr},
{gamma_dt_GRz_ex, gamma_dt_GRz_inv, gamma_dt_GRz_pr},
*z
};
*RF_train,
TR,
*gamma_dt_RF,
{dt_ex, dt_inv, dt_pr},
{gamma_dt_GRz_ex, gamma_dt_GRz_inv, gamma_dt_GRz_pr},
*z};

context->submit_device(
compas::simulate_magnetization_pssfp,
write(d_echos),
write(*echos),
*parameters,
sequence);

d_echos.read(echos);
return echos;
});
}

Expand All @@ -257,18 +246,14 @@ extern "C" kmm::ArrayBase* compas_magnetization_to_signal(
const compas::CudaArray<compas::cfloat, 2>* echos,
const compas::TissueParameters* parameters,
const compas::Trajectory* trajectory,
const float* coils_ptr) {
const compas::CudaArray<float, 2>* coils) {
return catch_exceptions([&] {
int nreadouts = trajectory->nreadouts;
int samples_per_readout = trajectory->samples_per_readout;
int nvoxels = parameters->nvoxels;

auto coils = make_view(coils_ptr, ncoils, nvoxels);

auto d_coils = context->allocate(coils);

auto signal =
compas::magnetization_to_signal(*context, *echos, *parameters, *trajectory, d_coils);
compas::magnetization_to_signal(*context, *echos, *parameters, *trajectory, *coils);

return new compas::CudaArray<compas::cfloat, 3>(signal);
});
Expand Down
Loading

0 comments on commit 0096101

Please sign in to comment.