Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcpni committed Jan 7, 2025
1 parent 19234a1 commit 328e59e
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions psyneulink/library/compositions/pytorchwrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,25 @@ class PytorchCompositionWrapper(torch.nn.Module):
# # MODIFIED 7/29/24 NEW: NEEDED FOR torch MPS SUPPORT
# class PytorchCompositionWrapper(torch.jit.ScriptModule):
# MODIFIED 7/29/24 END
"""Wrapper for a Composition as a Pytorch Module
Class that wraps a `Composition <Composition>` as a PyTorch module.
"""Wrapper for a Composition as a Pytorch Module.
Wraps an `AutodiffComposition` as a `PyTorch module
<https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_, with each `Mechanism <Mechanism>` in the
AutodiffComposition wrapped as a `PytorchMechanismWrapper`, each `Projection <Projection>` wrapped as a
`PytorchProjectionWrapper`, and any nested Compositions wrapped as `PytorchCompositionWrapper`\\s. Each
PytorchMechanismWrapper implements a Pytorch version of the `function(s) <Mechanism_Base.function>` of the wrapped
`Mechanism`, which are executed in the PyTorchCompositionWrapper's `forward <PyTorchCompositionWrapper.forward>`
method in the order specified by the AutodiffComposition's `scheduler <Composition.scheduler>`. The
`matrix <MappingProjection.matrix>` Parameters of each wrapped `Projection` are assigned as parameters of the
`PytorchMechanismWrapper` Pytorch module and used, together with a Pytorch `matmul
<https://pytorch.org/docs/main/generated/torch.matmul.html>`_ operation, to generate the input to each
PyTorch function as specified by the `PytorchProjectionWrapper`\\'s `graph <Composition.graph>`. The graph
can be visualized using the AutodiffComposition's `show_graph <ShowGraph.show_graph>` method and setting its
*show_pytorch* argument to True (see `PytorchShowGraph` for additional information).
Two main responsibilities:
1) Set up parameters of PyTorch model & information required for forward computation:
1) Set up functions and parameters of PyTorch module required for it forward computation:
Handle nested compositions (flattened in infer_backpropagation_learning_pathways):
Deal with Projections into and/or out of a nested Composition as shown in figure below:
(note: Projections in outer Composition to/from a nested Composition's CIMs are learnable,
Expand Down Expand Up @@ -115,11 +128,12 @@ class PytorchCompositionWrapper(torch.nn.Module):
`AutodiffComposition` being wrapped.
wrapped_nodes : List[PytorchMechanismWrapper]
list of nodes in the PytorchCompositionWrapper corresponding to PyTorch functions. Generally these are
`Mechanisms <Mechanism>` wrapped in a `PytorchMechanismWrapper`, however, if the `AutodiffComposition`
being wrapped is itself a nested Composition, then the wrapped nodes are `PytorchCompositionWrapper` objects.
list of nodes in the PytorchCompositionWrapper corresponding to the PyTorch functions that comprise the
forward method of the Pytorch module implemented by the PytorchCompositionWrapper. Generally these are
`Mechanisms <Mechanism>` wrapped in a `PytorchMechanismWrapper`, however, if the `AutodiffComposition` Node
being wrapped is a nested Composition, then the wrapped node is itself a `PytorchCompositionWrapper` object.
When the PyTorch model is executed, all of these are "flattened" into a single PyTorch module, corresponding
to the outermost AutodiffComposition being wrapped,, which can be visualized using that AutodiffComposition's
to the outermost AutodiffComposition being wrapped, which can be visualized using that AutodiffComposition's
`show_graph <ShowGraph.show_graph>` method and setting its *show_pytorch* argument to True (see
`PytorchShowGraph` for additional information).
Expand Down

0 comments on commit 328e59e

Please sign in to comment.