Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Remove hardcoded celltypes #970

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions hnn_core/cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from glob import glob
from warnings import warn

import numpy as np

from .viz import plot_spikes_hist, plot_spikes_raster
Expand Down Expand Up @@ -80,8 +79,8 @@ class CellResponse(object):
Write spiking activity to a collection of spike trial files.
"""

def __init__(self, spike_times=None, spike_gids=None, spike_types=None,
times=None, cell_type_names=None):
def __init__(self, cell_type_names, spike_times=None,
spike_gids=None, spike_types=None, times=None):
if spike_times is None:
spike_times = list()
if spike_gids is None:
Expand All @@ -91,10 +90,6 @@ def __init__(self, spike_times=None, spike_gids=None, spike_types=None,
if times is None:
times = list()

if cell_type_names is None:
cell_type_names = ['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal']

# Validate arguments
arg_names = ['spike_times', 'spike_gids', 'spike_types']
for arg_idx, arg in enumerate([spike_times, spike_gids, spike_types]):
Expand Down Expand Up @@ -125,7 +120,8 @@ def __init__(self, spike_times=None, spike_gids=None, spike_types=None,
raise TypeError("'times' is an np.ndarray of simulation times")
self._times = np.array(times)
self._cell_type_names = cell_type_names

self._gid_ranges = dict()

def __repr__(self):
class_name = self.__class__.__name__
n_trials = len(self._spike_times)
Expand Down Expand Up @@ -502,7 +498,12 @@ def read_spikes(fname, gid_ranges=None):
spike_gids += [list()]
spike_types += [list()]

cell_response = CellResponse(spike_times=spike_times,
network_cell_names = ['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal']
cell_type_names = list(cell_name for cell_name in
network_cell_names if cell_name in spike_types)
cell_response = CellResponse(cell_type_names=cell_type_names,
spike_times=spike_times,
spike_gids=spike_gids,
spike_types=spike_types)
if gid_ranges is not None:
Expand Down
102 changes: 75 additions & 27 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import Union


def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff, inplane_distance):
def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff, inplane_distance, cell_types):
"""Creates coordinate grid and place cells in it.

Parameters
Expand All @@ -47,7 +47,7 @@ def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff, inplane_distance):
The grid spacing of pyramidal cells (in um). Note that basket cells are
placed in an uneven formation. Each one of them lies on a grid point
together with a pyramidal cell, though (overlapping).

cell_names: a dictionary of cells {'L2_pyramidal': 'L2_pyramidal'}
Returns
-------
pos_dict : dict of list of tuple (x, y, z)
Expand Down Expand Up @@ -94,18 +94,22 @@ def _calc_origin(xxrange, yyrange, zdiff):
xxrange = np.arange(n_pyr_x) * inplane_distance
yyrange = np.arange(n_pyr_y) * inplane_distance

pos_dict = {
'L5_pyramidal': _calc_pyramidal_coord(xxrange, yyrange, zdiff=0),
'L2_pyramidal': _calc_pyramidal_coord(xxrange, yyrange, zdiff=zdiff),
'L5_basket': _calc_basket_coord(n_pyr_x, n_pyr_y, zdiff,
inplane_distance, weight=0.2
),
'L2_basket': _calc_basket_coord(n_pyr_x, n_pyr_y, zdiff,
inplane_distance, weight=0.8
),
'origin': _calc_origin(xxrange, yyrange, zdiff),
cell_name_pos_mapping = {
'L5Pyr': _calc_pyramidal_coord(xxrange, yyrange, zdiff=0),
'L2Pyr': _calc_pyramidal_coord(xxrange, yyrange, zdiff=zdiff),
'L5Basket': _calc_basket_coord(n_pyr_x, n_pyr_y, zdiff,
inplane_distance, weight=0.2),
'L2Basket': _calc_basket_coord(n_pyr_x, n_pyr_y, zdiff,
inplane_distance, weight=0.8)
}

pos_dict = dict()
for cell_net_name, cell_template in cell_types.items():
cell_name = cell_template.name
pos_dict[cell_net_name] = cell_name_pos_mapping[cell_name]

pos_dict['origin'] = _calc_origin(xxrange, yyrange, zdiff),

return pos_dict


Expand Down Expand Up @@ -364,7 +368,6 @@ class Network:
produce a network with no cell-to-cell connections. As such,
connectivity information contained in ``params`` will be ignored.
"""

def __init__(self, params, add_drives_from_params=False,
legacy_mode=False, mesh_shape=(10, 10)):
# Save the parameters used to create the Network
Expand All @@ -391,7 +394,8 @@ def __init__(self, params, add_drives_from_params=False,
stacklevel=1)

# Source dict of names, first real ones only!
cell_types = {
# adding self before cell_types to make it an instance attribute
self.cell_types = {
'L2_basket': basket(cell_name=_short_name('L2_basket')),
'L2_pyramidal': pyramidal(cell_name=_short_name('L2_pyramidal')),
'L5_basket': basket(cell_name=_short_name('L5_basket')),
Expand All @@ -415,7 +419,6 @@ def __init__(self, params, add_drives_from_params=False,
# cell counts, real and artificial
self._n_cells = 0 # used in tests and MPIBackend checks
self.pos_dict = dict()
self.cell_types = dict()

# set the mesh shape
_validate_type(mesh_shape, tuple, 'mesh_shape')
Expand All @@ -433,12 +436,11 @@ def __init__(self, params, add_drives_from_params=False,
self._layer_separation = 1307.4 # XXX hard-coded default
self.set_cell_positions(inplane_distance=self._inplane_distance,
layer_separation=self._layer_separation)

# populates self.gid_ranges for the 1st time: order matters for
# NetworkBuilder!
for cell_name in cell_types:
for cell_name in self.cell_types:
self._add_cell_type(cell_name, self.pos_dict[cell_name],
cell_template=cell_types[cell_name])
cell_template=self.cell_types[cell_name])

if add_drives_from_params:
_add_drives_from_params(self)
Expand All @@ -448,20 +450,23 @@ def __init__(self, params, add_drives_from_params=False,

def __repr__(self):
class_name = self.__class__.__name__
s = ("%d x %d Pyramidal cells (L2, L5)"
% (self._N_pyr_x, self._N_pyr_y))
s += ("\n%d L2 basket cells\n%d L5 basket cells"
% (len(self.pos_dict['L2_basket']),
len(self.pos_dict['L5_basket'])))
return '<%s | %s>' % (class_name, s)
# Dynamically create the description based on the current cell types
descriptions = []
for cell_name in self.cell_types:
count = len(self.pos_dict.get(cell_name, []))
descriptions.append(f"{count} {cell_name} cells")

# Combine all descriptions into a single string
description_str = "\n".join(descriptions)
return f'<{class_name} | {description_str}>'

def __eq__(self, other):
if not isinstance(other, Network):
return NotImplemented

# Check connectivity
if ((len(self.connectivity) != len(other.connectivity)) or
not (_compare_lists(self.connectivity, other.connectivity))):
not (_compare_lists(self.onnectivity, other.connectivity))):
return False

# Check all other attributes
Expand Down Expand Up @@ -512,7 +517,8 @@ def set_cell_positions(self, *, inplane_distance=None,

pos = _create_cell_coords(n_pyr_x=self._N_pyr_x, n_pyr_y=self._N_pyr_y,
zdiff=layer_separation,
inplane_distance=inplane_distance)
inplane_distance=inplane_distance,
cell_types=self.cell_types)
# update positions of the real cells
for key in pos.keys():
self.pos_dict[key] = pos[key]
Expand Down Expand Up @@ -1196,12 +1202,54 @@ def _add_cell_type(self, cell_name, pos, cell_template=None):
ll = self._n_gids
self._n_gids += len(pos)
self.gid_ranges[cell_name] = range(ll, self._n_gids)

self.pos_dict[cell_name] = pos
if cell_template is not None:
self.cell_types.update({cell_name: cell_template})
self._n_cells += len(pos)

def rename_cell(self, original_name, new_name):
"""Renames cells in the network and clears connectivity so user can
set new connections.

Parameters
----------
original_name: str
The original cell name in the network to be changed
new_name: str
The desired new cell name in the network
"""
if original_name not in self.cell_types.keys():
# Raises error if the original name is not in cell_types
raise ValueError(
f" '{original_name}' is not in cell_types!")
elif new_name in self.cell_types.keys():
# Raises error if the new name is already in cell_types
raise ValueError(f"'{new_name}' is already in cell_types!")
elif original_name is None or new_name is None:
# Raises error if either arguments are not present.
raise TypeError
elif not isinstance(original_name, str):
# Raises error when original_name is not a string
raise TypeError(f"'{original_name}' must be a string")
elif not isinstance(new_name, str):
# Raises error when new_name is not a string
raise TypeError(f"'{new_name}' must be a string")
elif original_name in self.cell_types.keys():
# Update cell name in places where order doesn't matter
self.cell_types[new_name] = self.cell_types.pop(original_name)
self.pos_dict[new_name] = self.pos_dict.pop(original_name)

# Update cell name in gid_ranges: order matters for consistency!
for _ in range(len(self.gid_ranges)):
name, gid_range = self.gid_ranges.popitem(last=False)
if name == original_name:
# Insert the new name with the value of the original name
self.gid_ranges[new_name] = gid_range
else:
# Insert the value as it is
self.gid_ranges[name] = gid_range
self.clear_connectivity()

def gid_to_type(self, gid):
"""Reverse lookup of gid to type."""
return _gid_to_type(gid, self.gid_ranges)
Expand Down
5 changes: 5 additions & 0 deletions hnn_core/network_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def jones_2009_model(params=None, add_drives_from_params=False,
delay, lamtha, allow_autapses=False)

# layer2 Basket -> layer2 Pyr
# WAGDY: could try to create a dictionary
# of src_cell and target_cell and then a
# function that takes default cell names
# from this dictionary and overrides it
# with user cell names
src_cell = 'L2_basket'
target_cell = 'L2_pyramidal'
lamtha = 50.
Expand Down
34 changes: 26 additions & 8 deletions hnn_core/tests/test_cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def test_cell_response(tmp_path):
sim_times = np.arange(tstart, tstop, 1 / fs)
gid_ranges = {'L2_pyramidal': range(1, 2), 'L2_basket': range(3, 4),
'L5_pyramidal': range(5, 6), 'L5_basket': range(7, 8)}
cell_response = CellResponse(spike_times=spike_times,
cell_response = CellResponse(cell_type_names=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
spike_times=spike_times,
spike_gids=spike_gids,
spike_types=spike_types,
times=sim_times)
Expand Down Expand Up @@ -57,11 +59,15 @@ def test_cell_response(tmp_path):
# creates these check that we always know which response attributes are
# simulated see #291 for discussion; objective is to keep cell_response
# size small
assert list(cell_response.__dict__.keys()) == \
sim_attributes + net_attributes
print("cell_response.__dict__.keys():", sorted(list(cell_response.__dict__.keys())))
print("sim_attributes + net_attributes:", sorted(sim_attributes + net_attributes))
assert sorted(list(cell_response.__dict__.keys())) == \
sorted(sim_attributes + net_attributes)

# Test recovery of empty spike files
empty_spike = CellResponse(spike_times=[[], []], spike_gids=[[], []],
empty_spike = CellResponse(cell_type_names=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
spike_times=[[], []], spike_gids=[[], []],
spike_types=[[], []])
empty_spike.write(tmp_path / 'empty_spk_%d.txt')
empty_spike.write(tmp_path / 'empty_spk.txt')
Expand All @@ -72,23 +78,35 @@ def test_cell_response(tmp_path):

with pytest.raises(TypeError,
match="spike_times should be a list of lists"):
cell_response = CellResponse(spike_times=([2.3456, 7.89],
cell_response = CellResponse(cell_type_names=['L2_basket',
'L2_pyramidal', 'L5_basket',
'L5_pyramidal'],
spike_times=([2.3456, 7.89],
[4.2812, 93.2]),
spike_gids=spike_gids,
spike_types=spike_types)

with pytest.raises(TypeError,
match="spike_times should be a list of lists"):
cell_response = CellResponse(spike_times=[1, 2], spike_gids=spike_gids,
cell_response = CellResponse(cell_type_names=['L2_basket',
'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
spike_times=[1, 2], spike_gids=spike_gids,
spike_types=spike_types)

with pytest.raises(ValueError, match="spike times, gids, and types should "
"be lists of the same length"):
cell_response = CellResponse(spike_times=[[2.3456, 7.89]],
cell_response = CellResponse(cell_type_names=['L2_basket',
'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
spike_times=[[2.3456, 7.89]],
spike_gids=spike_gids,
spike_types=spike_types)

cell_response = CellResponse(spike_times=spike_times,
cell_response = CellResponse(cell_type_names=['L2_basket',
'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
spike_times=spike_times,
spike_gids=spike_gids,
spike_types=spike_types)

Expand Down
42 changes: 40 additions & 2 deletions hnn_core/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,9 @@ def test_network_drives():
# to CellResponse-constructor for storage (Network is agnostic of time)
with pytest.raises(TypeError,
match="'times' is an np.ndarray of simulation times"):
_ = CellResponse(times='blah')
_ = CellResponse(cell_type_names=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
times='blah')

# Check that all external drives are initialized with the expected amount
# of artificial cells assuming legacy_mode=False (i.e., dependent on
Expand Down Expand Up @@ -553,7 +555,9 @@ def test_network_drives_legacy():
# to CellResponse-constructor for storage (Network is agnostic of time)
with pytest.raises(TypeError,
match="'times' is an np.ndarray of simulation times"):
_ = CellResponse(times='blah')
_ = CellResponse(cell_type_names=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
times='blah')

# Assert that all external drives are initialized
# Assumes legacy mode where cell-specific drives create artificial cells
Expand Down Expand Up @@ -725,6 +729,7 @@ def test_add_cell_type():

new_cell = net.cell_types['L2_basket'].copy()
net._add_cell_type('new_type', pos=pos, cell_template=new_cell)
assert 'new_type' in net.cell_types.keys()
net.cell_types['new_type'].synapses['gabaa']['tau1'] = tau1

n_new_type = len(net.gid_ranges['new_type'])
Expand Down Expand Up @@ -1158,3 +1163,36 @@ def test_only_drives_specified(self, base_network, src_gids,
target_gids=target_gids
)
assert len(indices) == expected


def test_rename_cell():
"""Tests renaming cell function"""
params = read_params(params_fname)
net = hnn_core.jones_2009_model(params)
# adding a list of new_names
new_names = ['L2_basket_test', 'L2_pyramidal_test',
'L5_basket_test', 'L5_pyrmidal_test']
# avoid iteration through net.cell_type.keys() by creating tuples of old and new names
rename_pairs = list(zip(net.cell_types.keys(), new_names))
for original_name, new_name in rename_pairs:
net.rename_cell(original_name, new_name)
for new_name in new_names:
assert new_name in net.cell_types.keys()
assert new_name in net.pos_dict.keys()
assert not net.connectivity
# Tests for non-existent original_name
original_name = 'original_name'
with pytest.raises(ValueError,
match=f"'{original_name}' is not in cell_types!"):
net.rename_cell('original_name', 'L2_basket_2')

# Test for already existing new_name
new_name = 'L2_basket_test'
with pytest.raises(ValueError,
match=f"'{new_name}' is already in cell_types!"):
net.rename_cell('L2_basket_test', new_name)

# Tests for non-string new_name
new_name = 5
with pytest.raises(TypeError, match=f"'{new_name}' must be a string"):
net.rename_cell('L2_basket_test', 5)
Loading
Loading