Consider the simulation loop in the `Network.run()` function:
```python
# Simulate network activity for `time` timesteps.
for t in range(timesteps):
->  for l in self.layers:
        # Update each layer of nodes.
        if isinstance(self.layers[l], AbstractInput):
            self.layers[l].step(inpts[l][t], self.dt)
        else:
            self.layers[l].step(inpts[l], self.dt)

        # Clamp neurons to spike.
        clamp = clamps.get(l, None)
        if clamp is not None:
            self.layers[l].s[clamp] = 1

    # Run synapse updates.
->  for c in self.connections:
        self.connections[c].update(
            reward=reward, mask=masks.get(c, None), learning=self.learning
        )

    # Get input to all layers.
    inpts.update(self.get_inputs())

    # Record state variables of interest.
    for m in self.monitors:
        self.monitors[m].record()

    # Re-normalize connections.
->  for c in self.connections:
        self.connections[c].normalize()
```
Where I've marked a `->`, there might be an opportunity to use `torch.multiprocessing`. Since updates at time `t` depend only on the network state at time `t-1`, all of the `Nodes` / `Connection` updates within a timestep could be performed in separate processes (or threads) at once. Letting `k` = no. of layers and `m` = no. of connections, and given enough CPU / GPU resources, the loops marked with `->` would run in O(1) time instead of O(k) and O(m) in the number of layers and connections, respectively.
I think it'd be good to keep two (?) `multiprocessing.Pool` objects around, one for `Nodes` objects and another for `Connection` objects. Instead of statements of the form:
```python
for l in self.layers:
    self.layers[l].step(...)
```
We might rewrite this as something like:
```python
self.nodes_pool.map(Nodes.step, self.layers)
```
Here, `nodes_pool` is defined as an attribute in the `Network` constructor. This last bit probably won't work straightaway; we'd need to figure out the right syntax (if it exists).
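For what it's worth, here's a rough, untested sketch of what that might look like with `torch.multiprocessing` (the `_step_layer` helper and `nodes_pool` attribute are hypothetical names, and since `Pool.map` pickles its arguments into worker processes, the stepped layers have to be returned and reassigned):

```python
import torch.multiprocessing as mp

def _step_layer(args):
    # Step a single layer in a worker process and return the updated copy.
    layer, inpt, dt = args
    layer.step(inpt, dt)
    return layer

# In the Network constructor (sketch):
#     self.nodes_pool = mp.Pool(processes=len(self.layers))

# Replacing the first -> loop in Network.run() (sketch); the clamping logic
# would either move into the helper or run after the pool call:
args = [
    (
        self.layers[l],
        inpts[l][t] if isinstance(self.layers[l], AbstractInput) else inpts[l],
        self.dt,
    )
    for l in self.layers
]
stepped = self.nodes_pool.map(_step_layer, args)
self.layers = dict(zip(self.layers.keys(), stepped))
```

Whether the pickling / IPC overhead of shipping layer state to worker processes outweighs the gain from parallelism is an open question, especially for small layers.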
This same idea can also be applied in the Network's reset() and get_inputs() functions.
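Correspondingly, a sketch for the `Connection` side (again, `connections_pool` and `_update_connection` are hypothetical names; the normalization loop could be mapped the same way):

```python
def _update_connection(args):
    # Apply a single connection's learning update in a worker process.
    connection, reward, mask, learning = args
    connection.update(reward=reward, mask=mask, learning=learning)
    return connection

# Replacing the second -> loop in Network.run() (sketch):
args = [
    (self.connections[c], reward, masks.get(c, None), self.learning)
    for c in self.connections
]
updated = self.connections_pool.map(_update_connection, args)
self.connections = dict(zip(self.connections.keys(), updated))
```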
@djsaunde any progress on this? I'll start looking into it, since I'm working with pretty small networks and GPUs don't give you much of an advantage there. This seems like the way to get a speedup in that case.
@Huizerd nope, just an idea we had some time ago. I'm not sure that it will speed things up, but it might be worth a shot. Let me know if you need any help.
Check out this branch for a start on the multiprocessing work (I'm pretty sure it fails as-is). It'll need to be fast-forwarded to the current state of the master branch.