Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] ENH: add new property to access spike times by cell type #916

Merged
merged 3 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Bug

API
~~~
- Add :func:`~hnn_core.CellResponse.spike_times_by_type` to get cell spiking times
organized by cell type, by `Mainak Jas`_ in :gh:`916`.

.. _0.4:

Expand Down
20 changes: 20 additions & 0 deletions hnn_core/cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,26 @@ def __eq__(self, other):
def spike_times(self):
return self._spike_times

@property
def cell_types(self):
"""Get unique cell types."""
spike_types_data = np.concatenate(np.array(self.spike_types,
dtype=object))
return np.unique(spike_types_data).tolist()
Comment on lines +156 to +161
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would also include spike types that are from bursty drives. Should this filter for only the cell types? Without hardcoding one way to do this is to have a bidirectional relationship between Network and CellResponse. Similar to what matplotlib does with Figure and Axes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually want the spike times of the input drives for my work :) They are "artificial cells" ... but yeah not part of the local network. We would need the network object to know which cells are drive cells and which ones belong to the local network. Are you thinking of a net.cell and cell.net attribute? It will create complications for IO, probably more than what I want to bite for this PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, net.cell_reponse is already an attribute. But there's not a way to get information about the parent network at the moment.

This came up in work on the spike raster plot. Currently we hardcode the local cell type names in the plotting function. I was thinking this might not be very flexible if local cell types are ever expanded or names changed.

But now that you mention it... should that plot also include artificial cells and not just the 4 local network? Then I can just grab them all with this property and just call it a day!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @ntolley may be able to answer better. But I think if you want to preserve the functionality, the cell_types (of the local network) from the default jones model may be declared somewhere as a global variable and imported where needed in the codebase ... acknowledging that it's not the best solution but we also don't want to over-engineer

Copy link
Collaborator

@asoplata asoplata Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I also think cell_types needs to be re-evaluated, but it's a big task and would need its own issue. The same goes for differentiating between drives and cells using `cell_types.


@property
def spike_times_by_type(self):
"""Get a dictionary of spike times by cell type"""
spike_times = dict()
for cell_type in self.cell_types:
spike_times[cell_type] = list()
for trial_spike_times, trial_spike_types in zip(self.spike_times,
self.spike_types):
mask = np.isin(trial_spike_types, cell_type)
cell_spike_times = np.array(trial_spike_times)[mask].tolist()
spike_times[cell_type].append(cell_spike_times)
return spike_times

@property
def spike_gids(self):
return self._spike_gids
Expand Down
5 changes: 5 additions & 0 deletions hnn_core/tests/test_cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def test_cell_response(tmp_path):
spike_gids=spike_gids,
spike_types=spike_types,
times=sim_times)

assert set(cell_response.cell_types) == set(gid_ranges.keys())
assert cell_response.spike_times_by_type['L2_basket'] == [[7.89], []]
assert cell_response.spike_times_by_type['L5_pyramidal'] == [[], [4.2812]]

kwargs_hist = dict(alpha=0.25)
fig = cell_response.plot_spikes_hist(show=False, **kwargs_hist)
assert all(patch.get_alpha() == kwargs_hist['alpha']
Expand Down
Loading