diff --git a/python/test/fixtures.py b/python/test/fixtures.py index 11b9011750..af74318687 100644 --- a/python/test/fixtures.py +++ b/python/test/fixtures.py @@ -1,4 +1,5 @@ -import arbor +import arbor as A +from arbor import units as U import functools from functools import lru_cache as cache from pathlib import Path @@ -6,8 +7,8 @@ import atexit import inspect -_mpi_enabled = arbor.__config__["mpi"] -_mpi4py_enabled = arbor.__config__["mpi4py"] +_mpi_enabled = A.__config__["mpi"] +_mpi4py_enabled = A.__config__["mpi4py"] # The API of `functools`'s caches went through a bunch of breaking changes from # 3.6 to 3.9. Patch them up in a local `cache` function. @@ -78,13 +79,13 @@ def _finalize_mpi(): MPI.Finalize() else: - arbor.mpi_finalize() + A.mpi_finalize() @_fixture def context(): """ - Fixture that produces an MPI sensitive `arbor.context` + Fixture that produces an MPI sensitive `A.context` """ if _mpi_enabled: if _mpi4py_enabled: @@ -94,13 +95,13 @@ def context(): print("Context fixture initializing mpi4py", flush=True) MPI.Initialize() atexit.register(_finalize_mpi) - return arbor.context(arbor.proc_allocation(), mpi=MPI.COMM_WORLD) - elif not arbor.mpi_is_initialized(): + return A.context(A.proc_allocation(), mpi=MPI.COMM_WORLD) + elif not A.mpi_is_initialized(): print("Context fixture initializing mpi", flush=True) - arbor.mpi_init() + A.mpi_init() atexit.register(_finalize_mpi) - return arbor.context(arbor.proc_allocation(), mpi=arbor.mpi_comm()) - return arbor.context(arbor.proc_allocation()) + return A.context(A.proc_allocation(), mpi=A.mpi_comm()) + return A.context(A.proc_allocation()) class _BuildCatError(Exception): @@ -165,16 +166,16 @@ def _build_cat(name, path, context): @repo_path() def dummy_catalogue(repo_path): """ - Fixture that returns a dummy `arbor.catalogue` + Fixture that returns a dummy `A.catalogue` which contains the `dummy` mech. """ path = repo_path / "test" / "unit" / "dummy" cat_path = _build_cat("dummy", path) - return arbor.load_catalogue(str(cat_path)) + return A.load_catalogue(str(cat_path)) @_fixture -class empty_recipe(arbor.recipe): +class empty_recipe(A.recipe): """ Blank recipe fixture. """ @@ -183,14 +184,14 @@ class empty_recipe(arbor.recipe): @_fixture -class art_spiker_recipe(arbor.recipe): +class art_spiker_recipe(A.recipe): """ Recipe fixture with 3 artificial spiking cells and one cable cell. """ def __init__(self): super().__init__() - self.the_props = arbor.neuron_cable_properties() + self.the_props = A.neuron_cable_properties() self.trains = [[0.8, 2, 2.1, 3], [0.4, 2, 2.2, 3.1, 4.5], [0.2, 2, 2.8, 3]] def num_cells(self): @@ -198,9 +199,9 @@ def num_cells(self): def cell_kind(self, gid): if gid < 3: - return arbor.cell_kind.spike_source + return A.cell_kind.spike_source else: - return arbor.cell_kind.cable + return A.cell_kind.cable def connections_on(self, gid): return [] @@ -215,41 +216,42 @@ def probes(self, gid): if gid < 3: return [] else: - return [arbor.cable_probe_membrane_voltage('"midpoint"', "Um")] + return [A.cable_probe_membrane_voltage('"midpoint"', "Um")] def _cable_cell_elements(self): # (1) Create a morphology with a single (cylindrical) segment of length=diameter # = # 6 μm - tree = arbor.segment_tree() + tree = A.segment_tree() tree.append( - arbor.mnpos, - arbor.mpoint(-3, 0, 0, 3), - arbor.mpoint(3, 0, 0, 3), + A.mnpos, + (-3, 0, 0, 3), + (3, 0, 0, 3), tag=1, ) # (2) Define the soma and its midpoint - labels = arbor.label_dict({"soma": "(tag 1)", "midpoint": "(location 0 0.5)"}) + labels = A.label_dict({"soma": "(tag 1)", "midpoint": "(location 0 0.5)"}) # (3) Create cell and set properties - decor = arbor.decor() + decor = A.decor() decor.set_property(Vm=-40) - decor.paint('"soma"', arbor.density("hh")) - decor.place('"midpoint"', arbor.iclamp(10, 2, 0.8), "iclamp") - decor.place('"midpoint"', arbor.threshold_detector(-10), "detector") + decor.paint('"soma"', A.density("hh")) + decor.place('"midpoint"', A.iclamp(10 * U.ms, 2 * U.ms, 0.8 * U.nA), "iclamp") + decor.place('"midpoint"', A.threshold_detector(-10 * U.mV), "detector") # return tuple of tree, labels, and decor for creating a cable cell (can still - # be modified before calling arbor.cable_cell()) + # be modified before calling A.cable_cell()) return tree, labels, decor def cell_description(self, gid): if gid < 3: - return arbor.spike_source_cell( - "src", arbor.explicit_schedule(self.trains[gid]) - ) + return A.spike_source_cell("src", self.schedule(gid)) else: tree, labels, decor = self._cable_cell_elements() - return arbor.cable_cell(tree, decor, labels) + return A.cable_cell(tree, decor, labels) + + def schedule(self, gid): + return A.explicit_schedule([t * U.ms for t in self.trains[gid]]) @_fixture @@ -274,5 +276,5 @@ def sum_weight_hh_spike_2(): @context() @art_spiker_recipe() def art_spiking_sim(context, art_spiker_recipe): - dd = arbor.partition_load_balance(art_spiker_recipe, context) - return arbor.simulation(art_spiker_recipe, context, dd) + dd = A.partition_load_balance(art_spiker_recipe, context) + return A.simulation(art_spiker_recipe, context, dd) diff --git a/python/test/unit/test_catalogues.py b/python/test/unit/test_catalogues.py index fef3965e9c..833af85225 100644 --- a/python/test/unit/test_catalogues.py +++ b/python/test/unit/test_catalogues.py @@ -1,29 +1,30 @@ from .. import fixtures import unittest -import arbor as arb +import arbor as A +from arbor import units as U """ tests for (dynamically loaded) catalogues """ -class recipe(arb.recipe): +class recipe(A.recipe): def __init__(self): - arb.recipe.__init__(self) - self.tree = arb.segment_tree() - self.tree.append(arb.mnpos, (0, 0, 0, 10), (1, 0, 0, 10), 1) - self.props = arb.neuron_cable_properties() + A.recipe.__init__(self) + self.tree = A.segment_tree() + self.tree.append(A.mnpos, (0, 0, 0, 10), (1, 0, 0, 10), 1) + self.props = A.neuron_cable_properties() try: - self.props.catalogue = arb.load_catalogue("dummy-catalogue.so") + self.props.catalogue = A.load_catalogue("dummy-catalogue.so") except Exception: print("Catalogue not found. Are you running from build directory?") raise - self.props.catalogue = arb.default_catalogue() + self.props.catalogue = A.default_catalogue() - d = arb.decor() - d.paint("(all)", arb.density("pas")) + d = A.decor() + d.paint("(all)", A.density("pas")) d.set_property(Vm=0.0) - self.cell = arb.cable_cell(self.tree, d) + self.cell = A.cable_cell(self.tree, d) def global_properties(self, _): return self.props @@ -32,7 +33,7 @@ def num_cells(self): return 1 def cell_kind(self, gid): - return arb.cell_kind.cable + return A.cell_kind.cable def cell_description(self, gid): return self.cell @@ -41,7 +42,7 @@ def cell_description(self, gid): class TestCatalogues(unittest.TestCase): def test_nonexistent(self): with self.assertRaises(FileNotFoundError): - arb.load_catalogue("_NO_EXIST_.so") + A.load_catalogue("_NO_EXIST_.so") @fixtures.dummy_catalogue() def test_shared_catalogue(self, dummy_catalogue): @@ -58,10 +59,10 @@ def test_shared_catalogue(self, dummy_catalogue): def test_simulation(self): rcp = recipe() - ctx = arb.context() - dom = arb.partition_load_balance(rcp, ctx) - sim = arb.simulation(rcp, ctx, dom) - sim.run(tfinal=30) + ctx = A.context() + dom = A.partition_load_balance(rcp, ctx) + sim = A.simulation(rcp, ctx, dom) + sim.run(tfinal=30 * U.ms) def test_empty(self): def len(cat): @@ -70,9 +71,9 @@ def len(cat): def hash_(cat): return hash(" ".join(sorted(cat))) - cat = arb.catalogue() - ref = arb.default_catalogue() - other = arb.default_catalogue() + cat = A.catalogue() + ref = A.default_catalogue() + other = A.default_catalogue() # Test empty constructor self.assertEqual(0, len(cat), "Expected no mechanisms in `arbor.catalogue()`.") # Test empty extend @@ -98,7 +99,7 @@ def hash_(cat): hash_(cat), "Extending empty with cat should turn empty into cat.", ) - cat = arb.catalogue() + cat = A.catalogue() cat.extend(other, "prefix/") self.assertNotEqual( hash_(other), diff --git a/python/test/unit/test_clear_samplers.py b/python/test/unit/test_clear_samplers.py index fcfb6d7d95..3f9c5b68b4 100644 --- a/python/test/unit/test_clear_samplers.py +++ b/python/test/unit/test_clear_samplers.py @@ -5,6 +5,7 @@ import unittest import arbor as A import numpy as np +from arbor import units as U from .. import fixtures from .. import cases @@ -21,11 +22,11 @@ class TestClearSamplers(unittest.TestCase): def test_spike_clearing(self, art_spiking_sim): sim = art_spiking_sim sim.record(A.spike_recording.all) - handle = sim.sample((3, "Um"), A.regular_schedule(0.1)) + handle = sim.sample((3, "Um"), A.regular_schedule(0.1*U.ms)) # baseline to test against Run in exactly the same stepping to make sure there are no rounding differences - sim.run(3, 0.01) - sim.run(5, 0.01) + sim.run(3*U.ms, 0.01*U.ms) + sim.run(5*U.ms, 0.01*U.ms) spikes = sim.spikes() times = spikes["time"].tolist() gids = spikes["source"]["gid"].tolist() @@ -34,7 +35,7 @@ def test_spike_clearing(self, art_spiking_sim): sim.reset() # simulated with clearing the memory inbetween the steppings - sim.run(3, 0.01) + sim.run(3*U.ms, 0.01*U.ms) spikes = sim.spikes() times_t = spikes["time"].tolist() gids_t = spikes["source"]["gid"].tolist() @@ -51,7 +52,7 @@ def test_spike_clearing(self, art_spiking_sim): self.assertEqual(0, data_test.size) # run the next part of the simulation - sim.run(5, 0.01) + sim.run(5*U.ms, 0.01*U.ms) spikes = sim.spikes() times_t.extend(spikes["time"].tolist()) gids_t.extend(spikes["source"]["gid"].tolist()) diff --git a/python/test/unit/test_decor.py b/python/test/unit/test_decor.py index 59dcf0d900..3a9b29c3be 100644 --- a/python/test/unit/test_decor.py +++ b/python/test/unit/test_decor.py @@ -2,7 +2,7 @@ import unittest import arbor as A - +from arbor import units as U """ Tests for decor and decoration wrappers. TODO: Coverage for more than just iclamp. @@ -12,29 +12,29 @@ class TestDecorClasses(unittest.TestCase): def test_iclamp(self): # Constant amplitude iclamp: - clamp = A.iclamp(10) + clamp = A.iclamp(10*U.nA) self.assertEqual(0, clamp.frequency) self.assertEqual([(0, 10)], clamp.envelope) - clamp = A.iclamp(10, frequency=20) + clamp = A.iclamp(current=10*U.nA, frequency=20*U.kHz) self.assertEqual(20, clamp.frequency) self.assertEqual([(0, 10)], clamp.envelope) # Square pulse: - clamp = A.iclamp(100, 20, 3) + clamp = A.iclamp(100*U.ms, 20*U.ms, 3*U.nA) self.assertEqual(0, clamp.frequency) self.assertEqual([(100, 3), (120, 3), (120, 0)], clamp.envelope) - clamp = A.iclamp(100, 20, 3, frequency=7) + clamp = A.iclamp(100*U.ms, 20*U.ms, 3*U.nA, frequency=7*U.kHz) self.assertEqual(7, clamp.frequency) self.assertEqual([(100, 3), (120, 3), (120, 0)], clamp.envelope) # Explicit envelope: envelope = [(1, 10), (3, 30), (5, 50), (7, 0)] - clamp = A.iclamp(envelope) + clamp = A.iclamp([(t*U.ms, i*U.nA) for t, i in envelope]) self.assertEqual(0, clamp.frequency) self.assertEqual(envelope, clamp.envelope) - clamp = A.iclamp(envelope, frequency=7) + clamp = A.iclamp([(t*U.ms, i*U.nA) for t, i in envelope], frequency=7*U.kHz) self.assertEqual(7, clamp.frequency) self.assertEqual(envelope, clamp.envelope) diff --git a/python/test/unit/test_event_generators.py b/python/test/unit/test_event_generators.py index b43387029c..cf3eeb64c2 100644 --- a/python/test/unit/test_event_generators.py +++ b/python/test/unit/test_event_generators.py @@ -5,7 +5,7 @@ import unittest import arbor as arb - +from arbor import units as U """ all tests for event generators (regular, explicit, poisson) """ @@ -14,7 +14,7 @@ class TestEventGenerator(unittest.TestCase): def test_event_generator_regular_schedule(self): cm = arb.cell_local_label("tgt0") - rs = arb.regular_schedule(2.0, 1.0, 100.0) + rs = arb.regular_schedule(2.0*U.ms, 1.0*U.ms, 100.0*U.ms) rg = arb.event_generator(cm, 3.14, rs) self.assertEqual(rg.target.label, "tgt0") self.assertEqual(rg.target.policy, arb.selection_policy.univalent) @@ -22,14 +22,14 @@ def test_event_generator_regular_schedule(self): def test_event_generator_explicit_schedule(self): cm = arb.cell_local_label("tgt1", arb.selection_policy.round_robin) - es = arb.explicit_schedule([0, 1, 2, 3, 4.4]) + es = arb.explicit_schedule([0*U.ms, 1*U.ms, 2*U.ms, 3*U.ms, 4.4*U.ms]) eg = arb.event_generator(cm, -0.01, es) self.assertEqual(eg.target.label, "tgt1") self.assertEqual(eg.target.policy, arb.selection_policy.round_robin) self.assertAlmostEqual(eg.weight, -0.01) def test_event_generator_poisson_schedule(self): - ps = arb.poisson_schedule(0.0, 10.0, 0) + ps = arb.poisson_schedule(freq=10.0*U.kHz, seed=0) pg = arb.event_generator("tgt2", 42.0, ps) self.assertEqual(pg.target.label, "tgt2") self.assertEqual(pg.target.policy, arb.selection_policy.univalent) diff --git a/python/test/unit/test_multiple_connections.py b/python/test/unit/test_multiple_connections.py index 7f8af4fcc4..6cd2c517f2 100644 --- a/python/test/unit/test_multiple_connections.py +++ b/python/test/unit/test_multiple_connections.py @@ -6,7 +6,8 @@ import types import numpy as np -import arbor as arb +import arbor as A +from arbor import units as U from .. import fixtures """ @@ -38,7 +39,7 @@ def __init__(self, args): # Method creating a new mechanism for a synapse with STDP def create_syn_mechanism(self, scale_contrib=1): # create new synapse mechanism - syn_mechanism = arb.mechanism("expsyn_stdp") + syn_mechanism = A.mechanism("expsyn_stdp") # set pre- and postsynaptic contributions for STDP syn_mechanism.set("Apre", 0.01 * scale_contrib) @@ -89,9 +90,7 @@ def rr_main(self, context, art_spiker_recipe, weight, weight2): def cell_description(self, gid): # spike source neuron if gid < 3: - return arb.spike_source_cell( - "spike_source", arb.explicit_schedule(self.trains[gid]) - ) + return A.spike_source_cell("spike_source", self.schedule(gid)) # spike-receiving cable neuron elif gid == 3: @@ -101,16 +100,16 @@ def cell_description(self, gid): decor.place( '"midpoint"', - arb.synapse(create_syn_mechanism(scale_stdp)), + A.synapse(create_syn_mechanism(scale_stdp)), "postsyn_target", ) # place synapse for input from one presynaptic neuron at the center of the soma decor.place( '"midpoint"', - arb.synapse(create_syn_mechanism(scale_stdp)), + A.synapse(create_syn_mechanism(scale_stdp)), "postsyn_target", ) # place synapse for input from another presynaptic neuron at the center of the soma # (using the same label as above!) - return arb.cable_cell(tree, decor, labels) + return A.cable_cell(tree, decor, labels) art_spiker_recipe.cell_description = types.MethodType( cell_description, art_spiker_recipe @@ -140,15 +139,15 @@ def cell_description(self, gid): self.assertAlmostEqual(connections_from_recipe[3].delay, 1.4) # construct domain_decomposition and simulation object - sim = arb.simulation(art_spiker_recipe, context) - sim.record(arb.spike_recording.all) + sim = A.simulation(art_spiker_recipe, context) + sim.record(A.spike_recording.all) # create schedule and handle to record the membrane potential of neuron 3 - reg_sched = arb.regular_schedule(0, self.dt, self.runtime) + reg_sched = A.regular_schedule(0 * U.ms, self.dt * U.ms, self.runtime * U.ms) handle_mem = sim.sample((3, "Um"), reg_sched) # run the simulation - sim.run(self.runtime, self.dt) + sim.run(self.runtime * U.ms, self.dt * U.ms) return sim, handle_mem @@ -175,28 +174,28 @@ def connections_on(self, gid): # incoming to neuron 3 elif gid == 3: - source_label_0 = arb.cell_global_label( + source_label_0 = A.cell_global_label( 0, "spike_source" ) # referring to the "spike_source" label of neuron 0 - source_label_1 = arb.cell_global_label( + source_label_1 = A.cell_global_label( 1, "spike_source" ) # referring to the "spike_source" label of neuron 1 - target_label_rr = arb.cell_local_label( - "postsyn_target", arb.selection_policy.round_robin + target_label_rr = A.cell_local_label( + "postsyn_target", A.selection_policy.round_robin ) # referring to the current item in the "postsyn_target" label group of neuron 3, moving to the next item afterwards - conn_0_3_n1 = arb.connection( + conn_0_3_n1 = A.connection( source_label_0, target_label_rr, weight, 0.2 ) # first connection from neuron 0 to 3 - conn_0_3_n2 = arb.connection( + conn_0_3_n2 = A.connection( source_label_0, target_label_rr, weight, 0.2 ) # second connection from neuron 0 to 3 # NOTE: this is not connecting to the same target label item as 'conn_0_3_n1' because 'round_robin' has been used before! - conn_1_3_n1 = arb.connection( + conn_1_3_n1 = A.connection( source_label_1, target_label_rr, weight2, 1.4 ) # first connection from neuron 1 to 3 - conn_1_3_n2 = arb.connection( + conn_1_3_n2 = A.connection( source_label_1, target_label_rr, weight2, 1.4 ) # second connection from neuron 1 to 3 # NOTE: this is not connecting to the same target label item as 'conn_1_3_n1' because 'round_robin' has been used before! @@ -237,30 +236,30 @@ def connections_on(self, gid): # incoming to neuron 3 elif gid == 3: - source_label_0 = arb.cell_global_label( + source_label_0 = A.cell_global_label( 0, "spike_source" ) # referring to the "spike_source" label of neuron 0 - source_label_1 = arb.cell_global_label( + source_label_1 = A.cell_global_label( 1, "spike_source" ) # referring to the "spike_source" label of neuron 1 - target_label_rr_halt = arb.cell_local_label( - "postsyn_target", arb.selection_policy.round_robin_halt + target_label_rr_halt = A.cell_local_label( + "postsyn_target", A.selection_policy.round_robin_halt ) # referring to the current item in the "postsyn_target" label group of neuron 3 - target_label_rr = arb.cell_local_label( - "postsyn_target", arb.selection_policy.round_robin + target_label_rr = A.cell_local_label( + "postsyn_target", A.selection_policy.round_robin ) # referring to the current item in the "postsyn_target" label group of neuron 3, moving to the next item afterwards - conn_0_3_n1 = arb.connection( + conn_0_3_n1 = A.connection( source_label_0, target_label_rr_halt, weight, 0.2 ) # first connection from neuron 0 to 3 - conn_0_3_n2 = arb.connection( + conn_0_3_n2 = A.connection( source_label_0, target_label_rr, weight, 0.2 ) # second connection from neuron 0 to 3 - conn_1_3_n1 = arb.connection( + conn_1_3_n1 = A.connection( source_label_1, target_label_rr_halt, weight2, 1.4 ) # first connection from neuron 1 to 3 - conn_1_3_n2 = arb.connection( + conn_1_3_n2 = A.connection( source_label_1, target_label_rr, weight2, 1.4 ) # second connection from neuron 1 to 3 @@ -298,24 +297,24 @@ def connections_on(self, gid): # incoming to neuron 3 elif gid == 3: - source_label_0 = arb.cell_global_label( + source_label_0 = A.cell_global_label( 0, "spike_source" ) # referring to the "spike_source" label of neuron 0 - source_label_1 = arb.cell_global_label( + source_label_1 = A.cell_global_label( 1, "spike_source" ) # referring to the "spike_source" label of neuron 1 - target_label_uni_n1 = arb.cell_local_label( - "postsyn_target_1", arb.selection_policy.univalent + target_label_uni_n1 = A.cell_local_label( + "postsyn_target_1", A.selection_policy.univalent ) # referring to an only item in the "postsyn_target_1" label group of neuron 3 - target_label_uni_n2 = arb.cell_local_label( - "postsyn_target_2", arb.selection_policy.univalent + target_label_uni_n2 = A.cell_local_label( + "postsyn_target_2", A.selection_policy.univalent ) # referring to an only item in the "postsyn_target_2" label group of neuron 3 - conn_0_3 = arb.connection( + conn_0_3 = A.connection( source_label_0, target_label_uni_n1, weight, 0.2 ) # connection from neuron 0 to 3 - conn_1_3 = arb.connection( + conn_1_3 = A.connection( source_label_1, target_label_uni_n2, weight2, 1.4 ) # connection from neuron 1 to 3 @@ -331,9 +330,7 @@ def connections_on(self, gid): def cell_description(self, gid): # spike source neuron if gid < 3: - return arb.spike_source_cell( - "spike_source", arb.explicit_schedule(self.trains[gid]) - ) + return A.spike_source_cell("spike_source", self.schedule(gid)) # spike-receiving cable neuron elif gid == 3: @@ -341,17 +338,17 @@ def cell_description(self, gid): decor.place( '"midpoint"', - arb.synapse(create_syn_mechanism()), + A.synapse(create_syn_mechanism()), "postsyn_target_1", ) # place synapse for input from one presynaptic neuron at the center of the soma decor.place( '"midpoint"', - arb.synapse(create_syn_mechanism()), + A.synapse(create_syn_mechanism()), "postsyn_target_2", ) # place synapse for input from another presynaptic neuron at the center of the soma # (using another label as above!) - return arb.cable_cell(tree, decor, labels) + return A.cable_cell(tree, decor, labels) art_spiker_recipe.cell_description = types.MethodType( cell_description, art_spiker_recipe @@ -371,15 +368,15 @@ def cell_description(self, gid): self.assertAlmostEqual(connections_from_recipe[1].delay, 1.4) # construct simulation object - sim = arb.simulation(art_spiker_recipe, context) - sim.record(arb.spike_recording.all) + sim = A.simulation(art_spiker_recipe, context) + sim.record(A.spike_recording.all) # create schedule and handle to record the membrane potential of neuron 3 - reg_sched = arb.regular_schedule(0, self.dt, self.runtime) + reg_sched = A.regular_schedule(0 * U.ms, self.dt * U.ms, self.runtime * U.ms) handle_mem = sim.sample((3, "Um"), reg_sched) # run the simulation - sim.run(self.runtime, self.dt) + sim.run(self.runtime * U.ms, self.dt * U.ms) # evaluate the outcome self.evaluate_outcome(sim, handle_mem) diff --git a/python/test/unit/test_probes.py b/python/test/unit/test_probes.py index 417ea1d51d..0744fa2f5d 100644 --- a/python/test/unit/test_probes.py +++ b/python/test/unit/test_probes.py @@ -2,6 +2,7 @@ import unittest import arbor as A +from arbor import units as U import numpy as np """ @@ -22,7 +23,7 @@ def __init__(self): dec.place("(location 0 0.08)", A.synapse("expsyn"), "syn0") dec.place("(location 0 0.09)", A.synapse("exp2syn"), "syn1") - dec.place("(location 0 0.1)", A.iclamp(20.0), "iclamp") + dec.place("(location 0 0.1)", A.iclamp(20.0 * U.nA), "iclamp") dec.paint("(all)", A.density("hh")) self.cell = A.cable_cell(st, dec) @@ -202,8 +203,8 @@ def test_probe_addr_metadata(self): def test_probe_result(self): rec = lif_recipe() sim = A.simulation(rec) - hdl = sim.sample(0, "Um", A.regular_schedule(0.1)) - sim.run(1.0, 0.05) + hdl = sim.sample(0, "Um", A.regular_schedule(0.1 * U.ms)) + sim.run(1.0 * U.ms, 0.05 * U.ms) smp = sim.samples(hdl) exp = np.array( [ diff --git a/python/test/unit/test_schedules.py b/python/test/unit/test_schedules.py index 0a7c747ec1..d23d81e2a0 100644 --- a/python/test/unit/test_schedules.py +++ b/python/test/unit/test_schedules.py @@ -4,7 +4,8 @@ import unittest -import arbor as arb +import arbor as A +from arbor import units as U """ all tests for schedules (regular, explicit, poisson) @@ -13,28 +14,28 @@ class TestRegularSchedule(unittest.TestCase): def test_none_ctor_regular_schedule(self): - rs = arb.regular_schedule(tstart=0, dt=0.1, tstop=None) - self.assertEqual(rs.dt, 0.1) + rs = A.regular_schedule(tstart=0 * U.ms, dt=0.1 * U.ms, tstop=None) + self.assertEqual(rs.dt, 0.1 * U.ms) def test_tstart_dt_tstop_ctor_regular_schedule(self): - rs = arb.regular_schedule(10.0, 1.0, 20.0) - self.assertEqual(rs.tstart, 10.0) - self.assertEqual(rs.dt, 1.0) - self.assertEqual(rs.tstop, 20.0) + rs = A.regular_schedule(10.0 * U.ms, 1.0 * U.ms, 20.0 * U.ms) + self.assertEqual(rs.tstart, 10.0 * U.ms) + self.assertEqual(rs.dt, 1.0 * U.ms) + self.assertEqual(rs.tstop, 20.0 * U.ms) def test_set_tstart_dt_tstop_regular_schedule(self): - rs = arb.regular_schedule(0.1) - self.assertAlmostEqual(rs.dt, 0.1, places=1) - rs.tstart = 17.0 - rs.dt = 0.5 - rs.tstop = 42.0 - self.assertEqual(rs.tstart, 17.0) - self.assertAlmostEqual(rs.dt, 0.5, places=1) - self.assertEqual(rs.tstop, 42.0) + rs = A.regular_schedule(0.1 * U.ms) + self.assertAlmostEqual(rs.dt.value_as(U.ms), 0.1, places=1) + rs.tstart = 17.0 * U.ms + rs.dt = 0.5 * U.ms + rs.tstop = 42.0 * U.ms + self.assertEqual(rs.tstart, 17.0 * U.ms) + self.assertAlmostEqual(rs.dt.value_as(U.ms), 0.5, places=1) + self.assertEqual(rs.tstop, 42.0 * U.ms) def test_events_regular_schedule(self): expected = [0, 0.25, 0.5, 0.75, 1.0] - rs = arb.regular_schedule(tstart=0.0, dt=0.25, tstop=1.25) + rs = A.regular_schedule(tstart=0.0 * U.ms, dt=0.25 * U.ms, tstop=1.25 * U.ms) self.assertEqual(expected, rs.events(0.0, 1.25)) self.assertEqual(expected, rs.events(0.0, 5.0)) self.assertEqual([], rs.events(5.0, 10.0)) @@ -43,41 +44,34 @@ def test_exceptions_regular_schedule(self): with self.assertRaisesRegex( RuntimeError, "tstart must be a non-negative number" ): - arb.regular_schedule(tstart=-1.0, dt=0.1) + A.regular_schedule(tstart=-1.0 * U.ms, dt=0.1 * U.ms) with self.assertRaisesRegex(RuntimeError, "dt must be a positive number"): - arb.regular_schedule(dt=-0.1) + A.regular_schedule(dt=-0.1 * U.ms) with self.assertRaisesRegex(RuntimeError, "dt must be a positive number"): - arb.regular_schedule(dt=0) + A.regular_schedule(dt=0 * U.ms) with self.assertRaises(TypeError): - arb.regular_schedule(dt=None) + A.regular_schedule(dt=None) with self.assertRaises(TypeError): - arb.regular_schedule(dt="dt") - with self.assertRaisesRegex( - RuntimeError, "tstop must be a non-negative number, or None" - ): - arb.regular_schedule(tstart=0, dt=0.1, tstop="tstop") + A.regular_schedule(dt="dt") + with self.assertRaises(TypeError): + A.regular_schedule(tstart=0 * U.ms, dt=0.1 * U.ms, tstop="tstop") with self.assertRaisesRegex(RuntimeError, "t0 must be a non-negative number"): - rs = arb.regular_schedule(0.0, 1.0, 10.0) + rs = A.regular_schedule(0.0 * U.ms, 1.0 * U.ms, 10.0 * U.ms) rs.events(-1, 0) with self.assertRaisesRegex(RuntimeError, "t1 must be a non-negative number"): - rs = arb.regular_schedule(0.0, 1.0, 10.0) + rs = A.regular_schedule(0.0 * U.ms, 1.0 * U.ms, 10.0 * U.ms) rs.events(0, -10) class TestExplicitSchedule(unittest.TestCase): def test_times_contor_explicit_schedule(self): - es = arb.explicit_schedule([1, 2, 3, 4.5]) - self.assertEqual(es.times, [1, 2, 3, 4.5]) - - def test_set_times_explicit_schedule(self): - es = arb.explicit_schedule() - es.times = [42, 43, 44, 55.5, 100] - self.assertEqual(es.times, [42, 43, 44, 55.5, 100]) + es = A.explicit_schedule([t * U.ms for t in range(1, 6)]) + self.assertEqual(es.events(0, 1000000), [1, 2, 3, 4, 5]) def test_events_explicit_schedule(self): times = [0.1, 0.3, 1.0, 2.2, 1.25, 1.7] expected = [0.1, 0.3, 1.0] - es = arb.explicit_schedule(times) + es = A.explicit_schedule([t * U.ms for t in times]) for i in range(len(expected)): self.assertAlmostEqual(expected[i], es.events(0.0, 1.25)[i], places=2) expected = [0.3, 1.0, 1.25, 1.7] @@ -85,47 +79,47 @@ def test_events_explicit_schedule(self): self.assertAlmostEqual(expected[i], es.events(0.3, 1.71)[i], places=2) def test_exceptions_explicit_schedule(self): - with self.assertRaisesRegex( - RuntimeError, "explicit time schedule cannot contain negative values" - ): - arb.explicit_schedule([-1]) + with self.assertRaises(RuntimeError): + A.explicit_schedule([-1 * U.ms]) with self.assertRaises(TypeError): - arb.explicit_schedule(["times"]) + A.explicit_schedule(["times"]) with self.assertRaises(TypeError): - arb.explicit_schedule([None]) + A.explicit_schedule([None]) with self.assertRaises(TypeError): - arb.explicit_schedule([[1, 2, 3]]) + A.explicit_schedule([[1, 2, 3]]) with self.assertRaisesRegex(RuntimeError, "t1 must be a non-negative number"): - rs = arb.regular_schedule(0.1) + rs = A.regular_schedule(0.1 * U.ms) rs.events(1.0, -1.0) class TestPoissonSchedule(unittest.TestCase): def test_freq_poisson_schedule(self): - ps = arb.poisson_schedule(42.0) - self.assertEqual(ps.freq, 42.0) + ps = A.poisson_schedule(42.0 * U.kHz) + self.assertEqual(ps.freq, 42.0 * U.kHz) def test_freq_tstart_contor_poisson_schedule(self): - ps = arb.poisson_schedule(freq=5.0, tstart=4.3) - self.assertEqual(ps.freq, 5.0) - self.assertEqual(ps.tstart, 4.3) + ps = A.poisson_schedule(freq=5.0 * U.kHz, tstart=4.3 * U.ms) + self.assertEqual(ps.freq, 5.0 * U.kHz) + self.assertEqual(ps.tstart, 4.3 * U.ms) def test_freq_seed_contor_poisson_schedule(self): - ps = arb.poisson_schedule(freq=5.0, seed=42) - self.assertEqual(ps.freq, 5.0) + ps = A.poisson_schedule(freq=5.0 * U.kHz, seed=42) + self.assertEqual(ps.freq, 5.0 * U.kHz) self.assertEqual(ps.seed, 42) def test_tstart_freq_seed_contor_poisson_schedule(self): - ps = arb.poisson_schedule(10.0, 100.0, 1000) - self.assertEqual(ps.tstart, 10.0) - self.assertEqual(ps.freq, 100.0) + ps = A.poisson_schedule(tstart=10.0 * U.ms, freq=100.0 * U.kHz, seed=1000) + self.assertEqual(ps.tstart, 10.0 * U.ms) + self.assertEqual(ps.freq, 100.0 * U.kHz) self.assertEqual(ps.seed, 1000) def test_events_poisson_schedule(self): expected = [17.4107, 502.074, 506.111, 597.116] - ps = arb.poisson_schedule(0.0, 0.01, 0) + ps = A.poisson_schedule(tstart=0.0 * U.ms, freq=0.01 * U.kHz, seed=0) for i in range(len(expected)): - self.assertAlmostEqual(expected[i], ps.events(0.0, 600.0)[i], places=3) + self.assertAlmostEqual( + expected[i], ps.events(0.0 * U.ms, 600.0 * U.ms)[i], places=3 + ) expected = [ 5030.22, 5045.75, @@ -140,50 +134,52 @@ def test_events_poisson_schedule(self): 5808.33, ] for i in range(len(expected)): - self.assertAlmostEqual(expected[i], ps.events(5000.0, 6000.0)[i], places=2) + self.assertAlmostEqual( + expected[i], ps.events(5000.0 * U.ms, 6000.0 * U.ms)[i], places=2 + ) def test_exceptions_poisson_schedule(self): with self.assertRaises(TypeError): - arb.poisson_schedule() + A.poisson_schedule() with self.assertRaises(TypeError): - arb.poisson_schedule(tstart=10.0) + A.poisson_schedule(tstart=10.0 * U.ms) with self.assertRaises(TypeError): - arb.poisson_schedule(seed=1432) + A.poisson_schedule(seed=1432) with self.assertRaisesRegex( RuntimeError, "tstart must be a non-negative number" ): - arb.poisson_schedule(freq=34.0, tstart=-10.0) + A.poisson_schedule(freq=34.0 * U.kHz, tstart=-10.0 * U.ms) with self.assertRaises(TypeError): - arb.poisson_schedule(freq=34.0, tstart=None) + A.poisson_schedule(freq=34.0 * U.kHz, tstart=None) with self.assertRaises(TypeError): - arb.poisson_schedule(freq=34.0, tstart="tstart") + A.poisson_schedule(freq=34.0, tstart="tstart") with self.assertRaisesRegex( RuntimeError, "frequency must be a non-negative number" ): - arb.poisson_schedule(freq=-100.0) + A.poisson_schedule(freq=-100.0 * U.kHz) with self.assertRaises(TypeError): - arb.poisson_schedule(freq="freq") + A.poisson_schedule(freq="freq") with self.assertRaises(TypeError): - arb.poisson_schedule(freq=34.0, seed=-1) + A.poisson_schedule(freq=34.0 * U.kHz, seed=-1) with self.assertRaises(TypeError): - arb.poisson_schedule(freq=34.0, seed=10.0) + A.poisson_schedule(freq=34.0 * U.kHz, seed=10.0) with self.assertRaises(TypeError): - arb.poisson_schedule(freq=34.0, seed="seed") + A.poisson_schedule(freq=34.0 * U.kHz, seed="seed") with self.assertRaises(TypeError): - arb.poisson_schedule(freq=34.0, seed=None) + A.poisson_schedule(freq=34.0 * U.kHz, seed=None) with self.assertRaisesRegex(RuntimeError, "t0 must be a non-negative number"): - ps = arb.poisson_schedule(0, 0.01) - ps.events(-1.0, 1.0) + ps = A.poisson_schedule(tstart=0 * U.ms, freq=0.01 * U.kHz) + ps.events(-1.0 * U.ms, 1.0 * U.ms) with self.assertRaisesRegex(RuntimeError, "t1 must be a non-negative number"): - ps = arb.poisson_schedule(0, 0.01) - ps.events(1.0, -1.0) - with self.assertRaisesRegex( - RuntimeError, "tstop must be a non-negative number, or None" - ): - arb.poisson_schedule(0, 0.1, tstop="tstop") - ps.events(1.0, -1.0) + ps = A.poisson_schedule(tstart=0 * U.ms, freq=0.01 * U.kHz) + ps.events(1.0 * U.ms, -1.0 * U.ms) + with self.assertRaises(TypeError): + ps = A.poisson_schedule(tstart=0 * U.ms, freq=0.1 * U.kHz, tstop="tstop") + ps.events(1.0 * U.ms, -1.0 * U.ms) def test_tstop_poisson_schedule(self): tstop = 50 - events = arb.poisson_schedule(0.0, 1, 0, tstop).events(0, 100) + events = A.poisson_schedule( + tstart=0.0 * U.ms, freq=1 * U.kHz, seed=0, tstop=tstop * U.ms + ).events(0 * U.ms, 100 * U.ms) self.assertTrue(max(events) < tstop) diff --git a/python/test/unit/test_spikes.py b/python/test/unit/test_spikes.py index cd1e6b81e5..146a3a0acd 100644 --- a/python/test/unit/test_spikes.py +++ b/python/test/unit/test_spikes.py @@ -4,6 +4,7 @@ import unittest import arbor as A +from arbor import units as U from .. import fixtures """ @@ -18,11 +19,11 @@ def test_spikes_sorted(self, art_spiking_sim): sim = art_spiking_sim sim.record(A.spike_recording.all) # run simulation in 5 steps, forcing 5 epochs - sim.run(1, 0.01) - sim.run(2, 0.01) - sim.run(3, 0.01) - sim.run(4, 0.01) - sim.run(5, 0.01) + sim.run(1 * U.ms, 0.01 * U.ms) + sim.run(2 * U.ms, 0.01 * U.ms) + sim.run(3 * U.ms, 0.01 * U.ms) + sim.run(4 * U.ms, 0.01 * U.ms) + sim.run(5 * U.ms, 0.01 * U.ms) spikes = sim.spikes() times = spikes["time"].tolist() diff --git a/python/units.cpp b/python/units.cpp index 5ac7f82a20..92bccc5e2c 100644 --- a/python/units.cpp +++ b/python/units.cpp @@ -17,6 +17,8 @@ void register_units(py::module& m) { unit .def(py::self * py::self) + .def(py::self == py::self) + .def(py::self != py::self) .def(py::self / py::self) .def(py::self * double()) .def(py::self / double()) @@ -33,6 +35,8 @@ void register_units(py::module& m) { quantity .def(py::self * py::self) .def(py::self / py::self) + .def(py::self == py::self) + .def(py::self != py::self) .def(py::self + py::self) .def(py::self - py::self) .def(py::self * double())