Skip to content

Commit

Permalink
Tests::Py.
Browse files Browse the repository at this point in the history
  • Loading branch information
thorstenhater committed Dec 1, 2023
1 parent 8457d41 commit a085dbc
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 205 deletions.
70 changes: 36 additions & 34 deletions python/test/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import arbor
import arbor as A
from arbor import units as U
import functools
from functools import lru_cache as cache
from pathlib import Path
import subprocess
import atexit
import inspect

_mpi_enabled = arbor.__config__["mpi"]
_mpi4py_enabled = arbor.__config__["mpi4py"]
_mpi_enabled = A.__config__["mpi"]
_mpi4py_enabled = A.__config__["mpi4py"]

# The API of `functools`'s caches went through a bunch of breaking changes from
# 3.6 to 3.9. Patch them up in a local `cache` function.
Expand Down Expand Up @@ -78,13 +79,13 @@ def _finalize_mpi():

MPI.Finalize()
else:
arbor.mpi_finalize()
A.mpi_finalize()


@_fixture
def context():
"""
Fixture that produces an MPI sensitive `arbor.context`
Fixture that produces an MPI sensitive `A.context`
"""
if _mpi_enabled:
if _mpi4py_enabled:
Expand All @@ -94,13 +95,13 @@ def context():
print("Context fixture initializing mpi4py", flush=True)
MPI.Initialize()
atexit.register(_finalize_mpi)
return arbor.context(arbor.proc_allocation(), mpi=MPI.COMM_WORLD)
elif not arbor.mpi_is_initialized():
return A.context(A.proc_allocation(), mpi=MPI.COMM_WORLD)
elif not A.mpi_is_initialized():
print("Context fixture initializing mpi", flush=True)
arbor.mpi_init()
A.mpi_init()
atexit.register(_finalize_mpi)
return arbor.context(arbor.proc_allocation(), mpi=arbor.mpi_comm())
return arbor.context(arbor.proc_allocation())
return A.context(A.proc_allocation(), mpi=A.mpi_comm())
return A.context(A.proc_allocation())


class _BuildCatError(Exception):
Expand Down Expand Up @@ -165,16 +166,16 @@ def _build_cat(name, path, context):
@repo_path()
def dummy_catalogue(repo_path):
"""
Fixture that returns a dummy `arbor.catalogue`
Fixture that returns a dummy `A.catalogue`
which contains the `dummy` mech.
"""
path = repo_path / "test" / "unit" / "dummy"
cat_path = _build_cat("dummy", path)
return arbor.load_catalogue(str(cat_path))
return A.load_catalogue(str(cat_path))


@_fixture
class empty_recipe(arbor.recipe):
class empty_recipe(A.recipe):
"""
Blank recipe fixture.
"""
Expand All @@ -183,24 +184,24 @@ class empty_recipe(arbor.recipe):


@_fixture
class art_spiker_recipe(arbor.recipe):
class art_spiker_recipe(A.recipe):
"""
Recipe fixture with 3 artificial spiking cells and one cable cell.
"""

def __init__(self):
super().__init__()
self.the_props = arbor.neuron_cable_properties()
self.the_props = A.neuron_cable_properties()
self.trains = [[0.8, 2, 2.1, 3], [0.4, 2, 2.2, 3.1, 4.5], [0.2, 2, 2.8, 3]]

def num_cells(self):
return 4

def cell_kind(self, gid):
if gid < 3:
return arbor.cell_kind.spike_source
return A.cell_kind.spike_source
else:
return arbor.cell_kind.cable
return A.cell_kind.cable

def connections_on(self, gid):
return []
Expand All @@ -215,41 +216,42 @@ def probes(self, gid):
if gid < 3:
return []
else:
return [arbor.cable_probe_membrane_voltage('"midpoint"', "Um")]
return [A.cable_probe_membrane_voltage('"midpoint"', "Um")]

def _cable_cell_elements(self):
# (1) Create a morphology with a single (cylindrical) segment of length=diameter
# = # 6 μm
tree = arbor.segment_tree()
tree = A.segment_tree()
tree.append(
arbor.mnpos,
arbor.mpoint(-3, 0, 0, 3),
arbor.mpoint(3, 0, 0, 3),
A.mnpos,
(-3, 0, 0, 3),
(3, 0, 0, 3),
tag=1,
)

# (2) Define the soma and its midpoint
labels = arbor.label_dict({"soma": "(tag 1)", "midpoint": "(location 0 0.5)"})
labels = A.label_dict({"soma": "(tag 1)", "midpoint": "(location 0 0.5)"})

# (3) Create cell and set properties
decor = arbor.decor()
decor = A.decor()
decor.set_property(Vm=-40)
decor.paint('"soma"', arbor.density("hh"))
decor.place('"midpoint"', arbor.iclamp(10, 2, 0.8), "iclamp")
decor.place('"midpoint"', arbor.threshold_detector(-10), "detector")
decor.paint('"soma"', A.density("hh"))
decor.place('"midpoint"', A.iclamp(10 * U.ms, 2 * U.ms, 0.8 * U.nA), "iclamp")
decor.place('"midpoint"', A.threshold_detector(-10 * U.mV), "detector")

# return tuple of tree, labels, and decor for creating a cable cell (can still
# be modified before calling arbor.cable_cell())
# be modified before calling A.cable_cell())
return tree, labels, decor

def cell_description(self, gid):
if gid < 3:
return arbor.spike_source_cell(
"src", arbor.explicit_schedule(self.trains[gid])
)
return A.spike_source_cell("src", self.schedule(gid))
else:
tree, labels, decor = self._cable_cell_elements()
return arbor.cable_cell(tree, decor, labels)
return A.cable_cell(tree, decor, labels)

def schedule(self, gid):
return A.explicit_schedule([t * U.ms for t in self.trains[gid]])


@_fixture
Expand All @@ -274,5 +276,5 @@ def sum_weight_hh_spike_2():
@context()
@art_spiker_recipe()
def art_spiking_sim(context, art_spiker_recipe):
dd = arbor.partition_load_balance(art_spiker_recipe, context)
return arbor.simulation(art_spiker_recipe, context, dd)
dd = A.partition_load_balance(art_spiker_recipe, context)
return A.simulation(art_spiker_recipe, context, dd)
43 changes: 22 additions & 21 deletions python/test/unit/test_catalogues.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
from .. import fixtures
import unittest
import arbor as arb
import arbor as A
from arbor import units as U

"""
tests for (dynamically loaded) catalogues
"""


class recipe(arb.recipe):
class recipe(A.recipe):
def __init__(self):
arb.recipe.__init__(self)
self.tree = arb.segment_tree()
self.tree.append(arb.mnpos, (0, 0, 0, 10), (1, 0, 0, 10), 1)
self.props = arb.neuron_cable_properties()
A.recipe.__init__(self)
self.tree = A.segment_tree()
self.tree.append(A.mnpos, (0, 0, 0, 10), (1, 0, 0, 10), 1)
self.props = A.neuron_cable_properties()
try:
self.props.catalogue = arb.load_catalogue("dummy-catalogue.so")
self.props.catalogue = A.load_catalogue("dummy-catalogue.so")
except Exception:
print("Catalogue not found. Are you running from build directory?")
raise
self.props.catalogue = arb.default_catalogue()
self.props.catalogue = A.default_catalogue()

d = arb.decor()
d.paint("(all)", arb.density("pas"))
d = A.decor()
d.paint("(all)", A.density("pas"))
d.set_property(Vm=0.0)
self.cell = arb.cable_cell(self.tree, d)
self.cell = A.cable_cell(self.tree, d)

def global_properties(self, _):
return self.props
Expand All @@ -32,7 +33,7 @@ def num_cells(self):
return 1

def cell_kind(self, gid):
return arb.cell_kind.cable
return A.cell_kind.cable

def cell_description(self, gid):
return self.cell
Expand All @@ -41,7 +42,7 @@ def cell_description(self, gid):
class TestCatalogues(unittest.TestCase):
def test_nonexistent(self):
with self.assertRaises(FileNotFoundError):
arb.load_catalogue("_NO_EXIST_.so")
A.load_catalogue("_NO_EXIST_.so")

@fixtures.dummy_catalogue()
def test_shared_catalogue(self, dummy_catalogue):
Expand All @@ -58,10 +59,10 @@ def test_shared_catalogue(self, dummy_catalogue):

def test_simulation(self):
rcp = recipe()
ctx = arb.context()
dom = arb.partition_load_balance(rcp, ctx)
sim = arb.simulation(rcp, ctx, dom)
sim.run(tfinal=30)
ctx = A.context()
dom = A.partition_load_balance(rcp, ctx)
sim = A.simulation(rcp, ctx, dom)
sim.run(tfinal=30 * U.ms)

def test_empty(self):
def len(cat):
Expand All @@ -70,9 +71,9 @@ def len(cat):
def hash_(cat):
return hash(" ".join(sorted(cat)))

cat = arb.catalogue()
ref = arb.default_catalogue()
other = arb.default_catalogue()
cat = A.catalogue()
ref = A.default_catalogue()
other = A.default_catalogue()
# Test empty constructor
self.assertEqual(0, len(cat), "Expected no mechanisms in `arbor.catalogue()`.")
# Test empty extend
Expand All @@ -98,7 +99,7 @@ def hash_(cat):
hash_(cat),
"Extending empty with cat should turn empty into cat.",
)
cat = arb.catalogue()
cat = A.catalogue()
cat.extend(other, "prefix/")
self.assertNotEqual(
hash_(other),
Expand Down
11 changes: 6 additions & 5 deletions python/test/unit/test_clear_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import unittest
import arbor as A
import numpy as np
from arbor import units as U

from .. import fixtures
from .. import cases
Expand All @@ -21,11 +22,11 @@ class TestClearSamplers(unittest.TestCase):
def test_spike_clearing(self, art_spiking_sim):
sim = art_spiking_sim
sim.record(A.spike_recording.all)
handle = sim.sample((3, "Um"), A.regular_schedule(0.1))
handle = sim.sample((3, "Um"), A.regular_schedule(0.1*U.ms))

# baseline to test against Run in exactly the same stepping to make sure there are no rounding differences
sim.run(3, 0.01)
sim.run(5, 0.01)
sim.run(3*U.ms, 0.01*U.ms)
sim.run(5*U.ms, 0.01*U.ms)
spikes = sim.spikes()
times = spikes["time"].tolist()
gids = spikes["source"]["gid"].tolist()
Expand All @@ -34,7 +35,7 @@ def test_spike_clearing(self, art_spiking_sim):
sim.reset()

# simulated with clearing the memory inbetween the steppings
sim.run(3, 0.01)
sim.run(3*U.ms, 0.01*U.ms)
spikes = sim.spikes()
times_t = spikes["time"].tolist()
gids_t = spikes["source"]["gid"].tolist()
Expand All @@ -51,7 +52,7 @@ def test_spike_clearing(self, art_spiking_sim):
self.assertEqual(0, data_test.size)

# run the next part of the simulation
sim.run(5, 0.01)
sim.run(5*U.ms, 0.01*U.ms)
spikes = sim.spikes()
times_t.extend(spikes["time"].tolist())
gids_t.extend(spikes["source"]["gid"].tolist())
Expand Down
14 changes: 7 additions & 7 deletions python/test/unit/test_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import unittest
import arbor as A

from arbor import units as U
"""
Tests for decor and decoration wrappers.
TODO: Coverage for more than just iclamp.
Expand All @@ -12,29 +12,29 @@
class TestDecorClasses(unittest.TestCase):
def test_iclamp(self):
# Constant amplitude iclamp:
clamp = A.iclamp(10)
clamp = A.iclamp(10*U.nA)
self.assertEqual(0, clamp.frequency)
self.assertEqual([(0, 10)], clamp.envelope)

clamp = A.iclamp(10, frequency=20)
clamp = A.iclamp(current=10*U.nA, frequency=20*U.kHz)
self.assertEqual(20, clamp.frequency)
self.assertEqual([(0, 10)], clamp.envelope)

# Square pulse:
clamp = A.iclamp(100, 20, 3)
clamp = A.iclamp(100*U.ms, 20*U.ms, 3*U.nA)
self.assertEqual(0, clamp.frequency)
self.assertEqual([(100, 3), (120, 3), (120, 0)], clamp.envelope)

clamp = A.iclamp(100, 20, 3, frequency=7)
clamp = A.iclamp(100*U.ms, 20*U.ms, 3*U.nA, frequency=7*U.kHz)
self.assertEqual(7, clamp.frequency)
self.assertEqual([(100, 3), (120, 3), (120, 0)], clamp.envelope)

# Explicit envelope:
envelope = [(1, 10), (3, 30), (5, 50), (7, 0)]
clamp = A.iclamp(envelope)
clamp = A.iclamp([(t*U.ms, i*U.nA) for t, i in envelope])
self.assertEqual(0, clamp.frequency)
self.assertEqual(envelope, clamp.envelope)

clamp = A.iclamp(envelope, frequency=7)
clamp = A.iclamp([(t*U.ms, i*U.nA) for t, i in envelope], frequency=7*U.kHz)
self.assertEqual(7, clamp.frequency)
self.assertEqual(envelope, clamp.envelope)
8 changes: 4 additions & 4 deletions python/test/unit/test_event_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import unittest

import arbor as arb

from arbor import units as U
"""
all tests for event generators (regular, explicit, poisson)
"""
Expand All @@ -14,22 +14,22 @@
class TestEventGenerator(unittest.TestCase):
def test_event_generator_regular_schedule(self):
cm = arb.cell_local_label("tgt0")
rs = arb.regular_schedule(2.0, 1.0, 100.0)
rs = arb.regular_schedule(2.0*U.ms, 1.0*U.ms, 100.0*U.ms)
rg = arb.event_generator(cm, 3.14, rs)
self.assertEqual(rg.target.label, "tgt0")
self.assertEqual(rg.target.policy, arb.selection_policy.univalent)
self.assertAlmostEqual(rg.weight, 3.14)

def test_event_generator_explicit_schedule(self):
cm = arb.cell_local_label("tgt1", arb.selection_policy.round_robin)
es = arb.explicit_schedule([0, 1, 2, 3, 4.4])
es = arb.explicit_schedule([0*U.ms, 1*U.ms, 2*U.ms, 3*U.ms, 4.4*U.ms])
eg = arb.event_generator(cm, -0.01, es)
self.assertEqual(eg.target.label, "tgt1")
self.assertEqual(eg.target.policy, arb.selection_policy.round_robin)
self.assertAlmostEqual(eg.weight, -0.01)

def test_event_generator_poisson_schedule(self):
ps = arb.poisson_schedule(0.0, 10.0, 0)
ps = arb.poisson_schedule(freq=10.0*U.kHz, seed=0)
pg = arb.event_generator("tgt2", 42.0, ps)
self.assertEqual(pg.target.label, "tgt2")
self.assertEqual(pg.target.policy, arb.selection_policy.univalent)
Expand Down
Loading

0 comments on commit a085dbc

Please sign in to comment.