Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Waveform Relaxation for Gap Junctions #1810

Draft
wants to merge 41 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
53d18e0
Add WR for gap junction mechanism
kanzl Jan 17, 2022
27877f8
Merge branch 'gj-wfr' of github.com:kanzl/arbor into gj-wfr
kanzl Jan 17, 2022
e1529cb
Modified remaining_steps reset
kanzl Jan 17, 2022
9ed3fb1
Merge branch 'arbor-sim:master' into gj-wfr
kanzl Jan 21, 2022
a3a5035
Map CVs in peer_index to position in traces_v for previous voltage fe…
kanzl Jan 21, 2022
4df45ad
Reset remaining_steps at the beginning of each WR iteration
kanzl Jan 21, 2022
556dc77
Increased number of iterations in Waveform Relaxation for testing
kanzl Jan 21, 2022
1db4211
Clean up traces
kanzl Jan 21, 2022
993b915
Add break condition for WR
kanzl Jan 23, 2022
6a01255
Merge branch 'arbor-sim:master' into gj-wfr
kanzl Jan 23, 2022
c5bd170
Merge branch 'gj-wfr' of github.com:kanzl/arbor into gj-wfr
kanzl Jan 23, 2022
0607a09
Fix error calculation for WR break condition
kanzl Jan 24, 2022
5c5c6d2
Fix break condition
kanzl Jan 25, 2022
c726e16
Merge
kanzl Jan 25, 2022
5eb3aea
Sync
kanzl Jan 25, 2022
9d3c9c7
Clean up
kanzl Jan 25, 2022
34b8fc0
Merge branch 'arbor-sim:master' into gj-wfr
kanzl Jan 29, 2022
21da1ff
Add state reset
kanzl Jan 29, 2022
1fc5a22
Merge branch 'gj-wfr' of github.com:kanzl/arbor into gj-wfr
kanzl Jan 29, 2022
c30cf63
Fix peer voltage and break condition.
kanzl Feb 4, 2022
e9b7fd7
Clean up output.
kanzl Feb 4, 2022
d088c8b
Merge branch 'arbor-sim:master' into gj-wfr
kanzl Mar 23, 2022
e8f2499
domain decomposition gap junctions
kanzl Mar 23, 2022
e3cd28f
Merge branch 'arbor-sim:master' into gj-wfr
kanzl Apr 5, 2022
b0e7042
Setup infrastructure for gap junctions spanning different cell groups
kanzl May 17, 2022
2d49ff2
Merge branch 'arbor-sim:master' into gj-wfr
kanzl May 17, 2022
030c9f2
update output
kanzl May 17, 2022
6b91874
Merge remote-tracking branch 'upstream/master' into gj-wfr
kanzl May 17, 2022
889686b
Merge remote-tracking branch 'origin/gj-wfr' into gj-wfr
kanzl May 17, 2022
d9362f4
Updated maps for switching between local and global cv index
kanzl Jun 8, 2022
f1e799b
Fix differentiation between groups for peer index calc
kanzl Jun 14, 2022
3169d0e
remove peer index reset
kanzl Jun 21, 2022
2caadb6
Adapt trace structure to state_->voltage, fix example, gather traces
kanzl Jun 27, 2022
a2772c3
Global resolution map
kanzl Jun 29, 2022
23fa12a
debugging
kanzl Aug 5, 2022
18c4c34
Change node index for vec_v
kanzl Aug 8, 2022
64813b7
Add cell group of peer to ppack
kanzl Aug 15, 2022
52c23c7
fix peer_cg and trace
kanzl Aug 15, 2022
ad48fa0
backup
kanzl Aug 23, 2022
f3e1bab
Ignore redundant time steps and fix trace gather
kanzl Aug 24, 2022
b7dae2c
Index fix
kanzl Sep 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion arbor/backends/multicore/shared_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <utility>
#include <vector>

#include <iostream>

#include <arbor/assert.hpp>
#include <arbor/common_types.hpp>
#include <arbor/constants.hpp>
Expand Down Expand Up @@ -479,6 +481,7 @@ void shared_state::instantiate(arb::mechanism& m, unsigned id, const mechanism_o
m.ppack_.vec_di = cv_to_intdom.data();
m.ppack_.vec_dt = dt_cv.data();
m.ppack_.vec_v = voltage.data();
m.ppack_.vec_v_peer = voltage.data();
m.ppack_.vec_i = current_density.data();
m.ppack_.vec_g = conductivity.data();
m.ppack_.temperature_degC = temperature_degC.data();
Expand Down Expand Up @@ -549,7 +552,7 @@ void shared_state::instantiate(arb::mechanism& m, unsigned id, const mechanism_o
{
// Allocate bulk storage
std::size_t index_width_padded = extend_width<arb_index_type>(m, pos_data.cv.size());
std::size_t count = mult_in_place + peer_indices + m.mech_.n_ions + 1;
std::size_t count = mult_in_place + peer_indices + m.mech_.n_ions + 2;
store.indices_ = iarray(count*index_width_padded, 0, pad);
chunk_writer writer(store.indices_.data(), index_width_padded);
// Setup node indices
Expand Down Expand Up @@ -587,6 +590,7 @@ void shared_state::instantiate(arb::mechanism& m, unsigned id, const mechanism_o
// Peer CVs are only filled for gap junction mechanisms. They are used
// to index the voltage at the other side of a gap-junction connection.
if (peer_indices) m.ppack_.peer_index = writer.append(pos_data.peer_cv, pos_data.peer_cv.back());
if (peer_indices) m.ppack_.peer_cg = writer.append(pos_data.peer_cg, 0);
}
}

Expand Down
6 changes: 3 additions & 3 deletions arbor/cell_group_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ cell_group_ptr make_cell_group(Args&&... args) {
}

ARB_ARBOR_API cell_group_factory cell_kind_implementation(
cell_kind ck, backend_kind bk, const execution_context& ctx)
cell_kind ck, cell_gid_type cg, backend_kind bk, const execution_context& ctx)
{
using gid_vector = std::vector<cell_gid_type>;

switch (ck) {
case cell_kind::cable:
return [bk, ctx](const gid_vector& gids, const recipe& rec, cell_label_range& cg_sources, cell_label_range& cg_targets) {
return make_cell_group<mc_cell_group>(gids, rec, cg_sources, cg_targets, make_fvm_lowered_cell(bk, ctx));
return [bk, ctx, cg](const gid_vector& gids, const recipe& rec, cell_label_range& cg_sources, cell_label_range& cg_targets) {
return make_cell_group<mc_cell_group>(gids, rec, cg_sources, cg_targets, cg, make_fvm_lowered_cell(bk, ctx, cg));
};

case cell_kind::spike_source:
Expand Down
6 changes: 3 additions & 3 deletions arbor/cell_group_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
namespace arb {

using cell_group_factory = std::function<
cell_group_ptr(const std::vector<cell_gid_type>&, const recipe&, cell_label_range& cg_sources, cell_label_range& cg_targets)>;
cell_group_ptr(const std::vector<cell_gid_type>&, const recipe&, cell_label_range& cg_sources, cell_label_range& cg_targets)>;

ARB_ARBOR_API cell_group_factory cell_kind_implementation(
cell_kind, backend_kind, const execution_context&);
cell_kind, cell_gid_type, backend_kind, const execution_context&);

inline bool cell_kind_supported(
cell_kind c, backend_kind b, const execution_context& ctx)
{
return static_cast<bool>(cell_kind_implementation(c, b, ctx));
return static_cast<bool>(cell_kind_implementation(c, 0, b, ctx));
}

} // namespace arb
10 changes: 10 additions & 0 deletions arbor/communication/dry_run_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ struct dry_run_context_impl {
return gathered_vector<cell_gid_type>(std::move(gathered_gids), std::move(partition));
}

std::vector<int>
gather_cg_cv_map(const std::vector<int>& local_map) const {
return {};
}

std::vector<double>
gather_trace(const std::vector<double>& trace) const {
return {};
}

std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const std::vector<std::vector<cell_gid_type>> & local_connections) const {
auto local_size = local_connections.size();
Expand Down
10 changes: 10 additions & 0 deletions arbor/communication/mpi_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ struct mpi_context_impl {
return mpi::gather_all_with_partition(local_gids, comm_);
}

std::vector<int>
gather_cg_cv_map(const std::vector<int>& cg_cv_map) const {
return mpi::gather_all(cg_cv_map, comm_);
}

std::vector<double>
gather_trace(const std::vector<double>& trace) const {
return mpi::gather_all(trace, comm_);
}

std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
return mpi::gather_all(local_connections, comm_);
Expand Down
32 changes: 32 additions & 0 deletions arbor/distributed_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <memory>
#include <string>
#include <iostream>

#include <arbor/export.hpp>
#include <arbor/spike.hpp>
Expand Down Expand Up @@ -66,6 +67,14 @@ class distributed_context {
return impl_->gather_gids(local_gids);
}

std::vector<int> gather_cg_cv_map(const std::vector<int>& cg_cv_map) const {
return impl_->gather_cg_cv_map(cg_cv_map);
}

std::vector<double> gather_trace(const std::vector<double>& trace) const {
return impl_->gather_trace(trace);
}

gj_connection_vector gather_gj_connections(const gj_connection_vector& local_connections) const {
return impl_->gather_gj_connections(local_connections);
}
Expand Down Expand Up @@ -106,6 +115,10 @@ class distributed_context {
gather_spikes(const spike_vector& local_spikes) const = 0;
virtual gathered_vector<cell_gid_type>
gather_gids(const gid_vector& local_gids) const = 0;
virtual std::vector<int>
gather_cg_cv_map(const std::vector<int>& cg_cv_map) const = 0;
virtual std::vector<double>
gather_trace(const std::vector<double>& trace) const = 0;
virtual gj_connection_vector
gather_gj_connections(const gj_connection_vector& local_connections) const = 0;
virtual cell_label_range
Expand Down Expand Up @@ -137,10 +150,21 @@ class distributed_context {
gather_gids(const gid_vector& local_gids) const override {
return wrapped.gather_gids(local_gids);
}
std::vector<int>
gather_cg_cv_map(const std::vector<int>& cg_cv_map) const override {
//std::cout << wrapped.id() << " Gather CG CV Map\n";
return wrapped.gather_cg_cv_map(cg_cv_map);
}
std::vector<double>
gather_trace(const std::vector<double>& trace) const override {
return wrapped.gather_trace(trace);
}
std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const gj_connection_vector& local_connections) const override {
//std::cout << wrapped.id() << " Gather GJ Map\n";
return wrapped.gather_gj_connections(local_connections);
}
//cell_label_range includes sizes, labels, ranges
cell_label_range
gather_cell_label_range(const cell_label_range& local_ranges) const override {
return wrapped.gather_cell_label_range(local_ranges);
Expand Down Expand Up @@ -191,6 +215,14 @@ struct local_context {
{0u, static_cast<count_type>(local_gids.size())}
);
}
std::vector<int>
gather_cg_cv_map(const std::vector<int>& cg_cv_map) const {
return {};
}
std::vector<double>
gather_trace(const std::vector<double>& cg_cv_map) const {
return {};
}
std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
return local_connections;
Expand Down
3 changes: 2 additions & 1 deletion arbor/domain_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ domain_decomposition::domain_decomposition(
}
for (const auto& gj: rec.gap_junctions_on(gid)) {
if (!gid_set.count(gj.peer.gid)) {
throw invalid_gj_cell_group(gid, gj.peer.gid);
//throw invalid_gj_cell_group(gid, gj.peer.gid);
std::cerr << "Warning: Need to use Waveform Relaxation.\n";
}
}
}
Expand Down
162 changes: 161 additions & 1 deletion arbor/fvm_layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "util/transform.hpp"
#include "util/unique.hpp"

#include <iostream>

namespace arb {

using util::assign;
Expand Down Expand Up @@ -626,6 +628,7 @@ fvm_mechanism_data& append(fvm_mechanism_data& left, const fvm_mechanism_data& r
append(L.multiplicity, R.multiplicity);
append(L.norm_area, R.norm_area);
append(L.local_weight, R.local_weight);
append(L.peer_cg, R.peer_cg);
append_offset(L.target, target_offset, R.target);

arb_assert(util::equal(L.param_values, R.param_values,
Expand Down Expand Up @@ -655,6 +658,124 @@ fvm_mechanism_data& append(fvm_mechanism_data& left, const fvm_mechanism_data& r
return left;
}

// build cg_cv_map = { {gid, lid} -> {cg, cv} }
// gather function:
// 1) take cg_cv_map & split into 4 arrays (gids, lids, cgs, cvs)
// 2) gather separately
// 3) update index map function { {gid, cg, cv} ->index } with gathered gids, cgs, cvs
// 4) use index map as gj_cvs map equivalent


// 1) split cg_cv_map into 4 arrays instead of unordered_map
ARB_ARBOR_API std::vector<std::vector<int>> fvm_build_gap_junction_cv_arr(
const std::vector<cable_cell>& cells,
const std::vector<cell_gid_type>& gids,
unsigned cg,
const fvm_cv_discretization& D)
{
std::vector<int> gid, lids, cgs, cvs;
arb_assert(cells.size() == gids.size());
std::unordered_map<cell_member_type, cell_member_type> gj_cg_cvs;
for (auto cell_idx: util::make_span(0, cells.size())) {
for (const auto& mech : cells[cell_idx].junctions()) {
for (const auto& gj: mech.second) {
gid.push_back(gids[cell_idx]);
lids.push_back(gj.lid);
cgs.push_back(cg);
cvs.push_back(D.geometry.location_cv(cell_idx, gj.loc, cv_prefer::cv_nonempty));
}
}
}
return {gid, lids, cgs, cvs};
}

//3) index map function
using cell_id = std::tuple<int, int, int, int>;

ARB_ARBOR_API std::map<std::tuple<int, int>, int> fvm_cell_to_index_lowered(
const std::vector<int>& cgs,
const std::vector<int>& cvs
)
{
std::map<std::tuple<int, int>, int> cell_to_index;

for (int i = 0; i<cgs.size(); ++i) {
std::tuple<int, int> element{cvs[i], cgs[i]};
cell_to_index[element] = i;
}

return cell_to_index;
}

ARB_ARBOR_API std::map<cell_id, int> fvm_cell_to_index(
const std::vector<int>& gids,
const std::vector<int>& cgs,
const std::vector<int>& cvs,
const std::vector<int>& lids
)
{
std::map<cell_id, int> cell_to_index;

for (int i = 0; i<gids.size(); ++i) {
cell_id element{gids[i], cgs[i], cvs[i], lids[i]};
cell_to_index[element] = i;
}

return cell_to_index;
}

ARB_ARBOR_API std::map<int, cell_id> fvm_index_to_cell(
std::map<cell_id, int>& cell_to_index
)
{
std::map<int, cell_id> index_to_cell;

for (const auto& [value, index]: cell_to_index) {
index_to_cell[index] = value;
}

return index_to_cell;
}

/*
ARB_ARBOR_API std::unordered_map<cell_member_type, fvm_size_type> fvm_index_to_cv_map(
const std::vector<int>& gids,
const std::vector<int>& lids,
const std::vector<int>& cgs,
const std::vector<int>& cvs,
const std::map<cell_id, int>& cell_to_index
)
{
std::unordered_map<cell_member_type, fvm_size_type> gj_cvs_index;
for (int i = 0; i<gids.size(); ++i){

cell_id cell = {gids[i], cgs[i], cvs[i], lids[i]};
int index = cell_to_index.at(cell);
//gj_cvs_index.insert({cell_member_type{unsigned(gids[i]), unsigned(lids[i])}, unsigned(index)});
gj_cvs_index.insert({cell_member_type{unsigned(gids[i]), unsigned(lids[i])}, unsigned(cvs[i])});
}

return gj_cvs_index;
}*/

ARB_ARBOR_API std::unordered_map<cell_member_type, cell_member_type> fvm_index_to_cv_map(
const std::vector<int>& gids,
const std::vector<int>& lids,
const std::vector<int>& cgs,
const std::vector<int>& cvs,
const std::map<cell_id, int>& cell_to_index
)
{
std::unordered_map<cell_member_type, cell_member_type> gj_cvs_index;
for (int i = 0; i<gids.size(); ++i){
cell_id cell = {gids[i], cgs[i], cvs[i], lids[i]};
gj_cvs_index.insert({cell_member_type{unsigned(gids[i]), unsigned(lids[i])}, cell_member_type{unsigned(cvs[i]), unsigned(cgs[i])}});
}

return gj_cvs_index;
}

/*
ARB_ARBOR_API std::unordered_map<cell_member_type, fvm_size_type> fvm_build_gap_junction_cv_map(
const std::vector<cable_cell>& cells,
const std::vector<cell_gid_type>& gids,
Expand All @@ -670,8 +791,10 @@ ARB_ARBOR_API std::unordered_map<cell_member_type, fvm_size_type> fvm_build_gap_
}
}
return gj_cvs;
}
}*/

//make resolution map with gj_data and gids global -> maybe try to gather cell_label_range gj_data
/*
ARB_ARBOR_API std::unordered_map<cell_gid_type, std::vector<fvm_gap_junction>> fvm_resolve_gj_connections(
const std::vector<cell_gid_type>& gids,
const cell_label_range& gj_data,
Expand All @@ -692,10 +815,44 @@ ARB_ARBOR_API std::unordered_map<cell_gid_type, std::vector<fvm_gap_junction>> f
auto peer_cv = gj_cvs.at({conn.peer.gid, peer_idx});

local_conns.push_back({local_idx, local_cv, peer_cv, conn.weight});
//std::cout <<"local_idx = " << local_idx << " local_cv = " << local_cv << " peer_cv = " << peer_cv << " conn.weight = " << conn.weight << std::endl;
}
// Sort local_conns by local_cv.
util::sort(local_conns);
gj_conns[gid] = std::move(local_conns);
}
return gj_conns;
}*/


ARB_ARBOR_API std::unordered_map<cell_gid_type, std::vector<fvm_gap_junction>> fvm_resolve_gj_connections(
const std::vector<cell_gid_type>& gids,
const cell_label_range& gj_data,
const std::unordered_map<cell_member_type, cell_member_type>& gj_cvs,
const recipe& rec)
{
// Construct and resolve all gj_connections.
std::unordered_map<cell_gid_type, std::vector<fvm_gap_junction>> gj_conns;
label_resolution_map resolution_map({gj_data, gids});
auto gj_resolver = resolver(&resolution_map);
for (const auto& gid: gids) {
std::vector<fvm_gap_junction> local_conns;
for (const auto& conn: rec.gap_junctions_on(gid)) {
auto local_idx = gj_resolver.resolve({gid, conn.local});
auto peer_idx = gj_resolver.resolve(conn.peer);

auto local_cv = gj_cvs.at({gid, local_idx}).gid;
auto peer_cv = gj_cvs.at({conn.peer.gid, peer_idx}).gid;

auto peer_cg = gj_cvs.at({conn.peer.gid, peer_idx}).index;

local_conns.push_back({local_idx, local_cv, peer_cv, conn.weight, peer_cg});
//std::cout <<"local_idx = " << local_idx << " local_cv = " << local_cv << " peer_cv = " << peer_cv << " peer_cg = " << peer_cg << " conn.weight = " << conn.weight << std::endl;
}
// Sort local_conns by local_cv.
util::sort(local_conns);
gj_conns[gid] = std::move(local_conns);

}
return gj_conns;
}
Expand Down Expand Up @@ -1129,10 +1286,13 @@ fvm_mechanism_data fvm_build_mechanism_data(
config.cv.push_back(conn.local_cv);
config.peer_cv.push_back(conn.peer_cv);
config.local_weight.push_back(conn.weight);
config.peer_cg.push_back(conn.peer_cg);
//std::cout << "node cv = " << conn.local_cv <<"peer cv = " << conn.peer_cv <<"peer cg = " << config.peer_cg.back() << std::endl;
for (unsigned i = 0; i < local_junction_desc.param_values.size(); ++i) {
config.param_values[i].second.push_back(local_junction_desc.param_values[i]);
}
}


// Add non-empty fvm_mechanism_config to the fvm_mechanism_data
for (auto [name, config]: junction_configs) {
Expand Down
Loading