diff --git a/psyneulink/library/compositions/pytorchwrappers.py b/psyneulink/library/compositions/pytorchwrappers.py index 35d4d5a0dc..60a6d4ad6d 100644 --- a/psyneulink/library/compositions/pytorchwrappers.py +++ b/psyneulink/library/compositions/pytorchwrappers.py @@ -58,12 +58,25 @@ class PytorchCompositionWrapper(torch.nn.Module): # # MODIFIED 7/29/24 NEW: NEEDED FOR torch MPS SUPPORT # class PytorchCompositionWrapper(torch.jit.ScriptModule): # MODIFIED 7/29/24 END - """Wrapper for a Composition as a Pytorch Module - Class that wraps a `Composition ` as a PyTorch module. + """Wrapper for a Composition as a Pytorch Module. + + Wraps an `AutodiffComposition` as a `PyTorch module + `_, with each `Mechanism ` in the + AutodiffComposition wrapped as a `PytorchMechanismWrapper`, each `Projection ` wrapped as a + `PytorchProjectionWrapper`, and any nested Compositions wrapped as `PytorchCompositionWrapper`\\s. Each + PytorchMechanismWrapper implements a Pytorch version of the `function(s) ` of the wrapped + `Mechanism`, which are executed in the PyTorchCompositionWrapper's `forward ` + method in the order specified by the AutodiffComposition's `scheduler `. The + `matrix ` Parameters of each wrapped `Projection` are assigned as parameters of the + `PytorchMechanismWrapper` Pytorch module and used, together with a Pytorch `matmul + `_ operation, to generate the input to each + PyTorch function as specified by the `PytorchProjectionWrapper`\\'s `graph `. The graph + can be visualized using the AutodiffComposition's `show_graph ` method and setting its + *show_pytorch* argument to True (see `PytorchShowGraph` for additional information). Two main responsibilities: - 1) Set up parameters of PyTorch model & information required for forward computation: + 1) Set up functions and parameters of PyTorch module required for it forward computation: Handle nested compositions (flattened in infer_backpropagation_learning_pathways): Deal with Projections into and/or out of a nested Composition as shown in figure below: (note: Projections in outer Composition to/from a nested Composition's CIMs are learnable, @@ -115,11 +128,12 @@ class PytorchCompositionWrapper(torch.nn.Module): `AutodiffComposition` being wrapped. wrapped_nodes : List[PytorchMechanismWrapper] - list of nodes in the PytorchCompositionWrapper corresponding to PyTorch functions. Generally these are - `Mechanisms ` wrapped in a `PytorchMechanismWrapper`, however, if the `AutodiffComposition` - being wrapped is itself a nested Composition, then the wrapped nodes are `PytorchCompositionWrapper` objects. + list of nodes in the PytorchCompositionWrapper corresponding to the PyTorch functions that comprise the + forward method of the Pytorch module implemented by the PytorchCompositionWrapper. Generally these are + `Mechanisms ` wrapped in a `PytorchMechanismWrapper`, however, if the `AutodiffComposition` Node + being wrapped is a nested Composition, then the wrapped node is itself a `PytorchCompositionWrapper` object. When the PyTorch model is executed, all of these are "flattened" into a single PyTorch module, corresponding - to the outermost AutodiffComposition being wrapped,, which can be visualized using that AutodiffComposition's + to the outermost AutodiffComposition being wrapped, which can be visualized using that AutodiffComposition's `show_graph ` method and setting its *show_pytorch* argument to True (see `PytorchShowGraph` for additional information).