From a6e322cafa5d39785ff98476a4b09f363d53dd78 Mon Sep 17 00:00:00 2001
From: Katherine Mantel
Date: Thu, 4 Jul 2024 01:19:32 +0000
Subject: [PATCH 1/3] PytorchMechanismWrapper: throw function unavailable error more strictly

---
 .../library/compositions/pytorchwrappers.py | 44 +++++++++++++------
 1 file changed, 30 insertions(+), 14 deletions(-)

diff --git a/psyneulink/library/compositions/pytorchwrappers.py b/psyneulink/library/compositions/pytorchwrappers.py
index 5fc9f41dbb..5439300570 100644
--- a/psyneulink/library/compositions/pytorchwrappers.py
+++ b/psyneulink/library/compositions/pytorchwrappers.py
@@ -42,6 +42,18 @@ class DataTypeEnum(Enum):
     TARGETS = auto()
     LOSSES = auto()
 
+
+def _get_pytorch_function(obj, device, context):
+    pytorch_fct = getattr(obj, '_gen_pytorch_fct', None)
+    if pytorch_fct is None:
+        from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
+        raise AutodiffCompositionError(
+            f"Function {obj} is not currently supported by AutodiffComposition"
+        )
+    else:
+        return pytorch_fct(device, context)
+
+
 # # MODIFIED 7/29/24 OLD:
 class PytorchCompositionWrapper(torch.nn.Module):
 # # MODIFIED 7/29/24 NEW:  NEEDED FOR torch MPS SUPPORT
@@ -908,21 +920,11 @@ def __init__(self,
 
         self.afferents = []
         self.efferents = []
 
-        from psyneulink.core.components.functions.function import FunctionError
-        from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
-        try:
-            pnl_fct = mechanism.function
-            self.function = pnl_fct._gen_pytorch_fct(device, context)
-            if hasattr(mechanism, 'integrator_function'):
-                pnl_fct = mechanism.integrator_function
-                self.integrator_function = pnl_fct._gen_pytorch_fct(device, context)
-                self.integrator_previous_value = pnl_fct._get_pytorch_fct_param_value('initializer', device, context)
-        except FunctionError as error:
-            from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
-            raise AutodiffCompositionError(error.args[0])
-        except:
-            raise AutodiffCompositionError(f"Function {pnl_fct} is not currently supported by AutodiffComposition")
+        self.function = PytorchFunctionWrapper(mechanism.function, device, context)
+        if hasattr(mechanism, 'integrator_function'):
+            self.integrator_function = PytorchFunctionWrapper(mechanism.integrator_function, device, context)
+            self.integrator_previous_value = mechanism.integrator_function._get_pytorch_fct_param_value('initializer', device, context)
 
     def add_efferent(self, efferent):
         """Add ProjectionWrapper for efferent from MechanismWrapper.
@@ -1238,3 +1240,17 @@ def _gen_llvm_execute(self, ctx, builder, state, params, data):
 
     def __repr__(self):
         return "PytorchWrapper for: " +self._projection.__repr__()
+
+
+class PytorchFunctionWrapper():
+    def __init__(self, function, device, context=None):
+        self._pnl_function = function
+        self.name = f"PytorchFunctionWrapper[{function.name}]"
+        self._context = context
+        self.function = _get_pytorch_function(function, device, context)
+
+    def __repr__(self):
+        return "PytorchWrapper for: " + self._pnl_function.__repr__()
+
+    def __call__(self, *args, **kwargs):
+        return self.function(*args, **kwargs)

From b9c2b1a60d50cfbb118316c6c825e42551e007bc Mon Sep 17 00:00:00 2001
From: Katherine Mantel
Date: Wed, 11 Dec 2024 03:29:56 +0000
Subject: [PATCH 2/3] SoftMax: use last dimension for pytorch softmax

Handles 1d or 2d input.
---
 .../components/functions/nonstateful/transferfunctions.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/psyneulink/core/components/functions/nonstateful/transferfunctions.py b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
index 5e2550b150..f3936f74be 100644
--- a/psyneulink/core/components/functions/nonstateful/transferfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
@@ -3583,22 +3583,22 @@ def _gen_pytorch_fct(self, device, context=None):
         mask_threshold = self._get_pytorch_fct_param_value('mask_threshold', device, context)
 
         if isinstance(gain, str) and gain == ADAPTIVE:
-            return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, 0))
+            return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, -1))
 
         elif mask_threshold:
             def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
                 # Mask elements of input below threshold
                 _mask = (torch.abs(_input) > mask_threshold)
                 # Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask
-                masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, 0, keepdim=True)[0]))
+                masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0]))
                 if not any(masked_exp):
                     return masked_exp
-                return masked_exp / torch.sum(masked_exp, 0, keepdim=True)
+                return masked_exp / torch.sum(masked_exp, -1, keepdim=True)
             # Return the function
             return pytorch_thresholded_softmax
 
         else:
-            return lambda x: (torch.softmax(gain * x, 0))
+            return lambda x: (torch.softmax(gain * x, -1))
 
     def _gen_pytorch_adapt_gain_fct(self, device, context=None):
         scale = self._get_pytorch_fct_param_value('adapt_scale', device, context)

From 7a4667fde2698ae590c443c8f46d92af028f92d0 Mon Sep 17 00:00:00 2001
From: Katherine Mantel
Date: Tue, 17 Dec 2024 02:42:04 +0000
Subject: [PATCH 3/3] pytorch: implement input port function execution

Previously, pytorch Mechanisms would only use the equivalent of a SUM
LinearCombination for each of their input ports. This implements the
use in pytorch of the actual Functions assigned to the input ports.

When the pytorch-equivalent of a psyneulink Function (via
_gen_pytorch_fct) is unavailable, an error is thrown.
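As a rough sketch of the intended per-port behavior (the names
port_functions and port_afferents below are illustrative only, not part
of this patch or the psyneulink API): each input port's
pytorch-equivalent function, obtained via its Function's
_gen_pytorch_fct, is applied to the stacked values of that port's
afferent projections, e.g. with a SUM and a PRODUCT LinearCombination
reducing over dim 0 of the stacked afferents:

    import torch

    # stand-ins for two input ports' pytorch-equivalent functions
    # (a SUM and a PRODUCT LinearCombination, reducing over afferents)
    port_functions = [lambda x: torch.sum(x, 0), lambda x: torch.prod(x, 0)]

    # one stacked tensor of afferent projection values per input port
    port_afferents = [torch.ones((2, 3)), 2 * torch.ones((2, 3))]

    variable = [fct(v) for fct, v in zip(port_functions, port_afferents)]
    # -> [tensor([2., 2., 2.]), tensor([4., 4., 4.])]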
---
 .../nonstateful/transferfunctions.py        |   2 +-
 .../nonstateful/transformfunctions.py       |   8 +-
 .../compositions/autodiffcomposition.py     |   9 +-
 .../library/compositions/pytorchwrappers.py | 114 +++++++++++-------
 4 files changed, 84 insertions(+), 49 deletions(-)

diff --git a/psyneulink/core/components/functions/nonstateful/transferfunctions.py b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
index f3936f74be..26c96cd3a8 100644
--- a/psyneulink/core/components/functions/nonstateful/transferfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
@@ -3591,7 +3591,7 @@ def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
                 _mask = (torch.abs(_input) > mask_threshold)
                 # Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask
                 masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0]))
-                if not any(masked_exp):
+                if (masked_exp == 0).all():
                     return masked_exp
                 return masked_exp / torch.sum(masked_exp, -1, keepdim=True)
             # Return the function
diff --git a/psyneulink/core/components/functions/nonstateful/transformfunctions.py b/psyneulink/core/components/functions/nonstateful/transformfunctions.py
index 99733aad2b..e3258ffc0f 100644
--- a/psyneulink/core/components/functions/nonstateful/transformfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/transformfunctions.py
@@ -1593,14 +1593,14 @@ def _gen_pytorch_fct(self, device, context=None):
             weights = torch.tensor(weights, device=device).double()
         if self.operation == SUM:
             if weights is not None:
-                return lambda x: torch.sum(torch.stack(x) * weights, 0)
+                return lambda x: torch.sum(x * weights, 0)
             else:
-                return lambda x: torch.sum(torch.stack(x), 0)
+                return lambda x: torch.sum(x, 0)
         elif self.operation == PRODUCT:
             if weights is not None:
-                return lambda x: torch.prod(torch.stack(x) * weights, 0)
+                return lambda x: torch.prod(x * weights, 0)
             else:
-                return lambda x: torch.prod(torch.stack(x), 0)
+                return lambda x: torch.prod(x, 0)
         else:
             from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
             raise AutodiffCompositionError(f"The 'operation' parameter of {function.componentName} is not supported "
diff --git a/psyneulink/library/compositions/autodiffcomposition.py b/psyneulink/library/compositions/autodiffcomposition.py
index 003b43db76..f9223aa761 100644
--- a/psyneulink/library/compositions/autodiffcomposition.py
+++ b/psyneulink/library/compositions/autodiffcomposition.py
@@ -1124,8 +1124,13 @@ def autodiff_forward(self, inputs, targets,
         for component in curr_tensors_for_trained_outputs.keys():
             trial_loss = 0
             for i in range(len(curr_tensors_for_trained_outputs[component])):
-                trial_loss += self.loss_function(curr_tensors_for_trained_outputs[component][i],
-                                                 curr_target_tensors_for_trained_outputs[component][i])
+                # loss only accepts 0 or 1d target. reshape assuming pytorch_rep.minibatch_loss dim is correct
+                comp_loss = self.loss_function(
+                    curr_tensors_for_trained_outputs[component][i],
+                    torch.atleast_1d(curr_target_tensors_for_trained_outputs[component][i].squeeze())
+                )
+                comp_loss = comp_loss.reshape_as(pytorch_rep.minibatch_loss)
+                trial_loss += comp_loss
             pytorch_rep.minibatch_loss += trial_loss
             pytorch_rep.minibatch_loss_count += 1
 
diff --git a/psyneulink/library/compositions/pytorchwrappers.py b/psyneulink/library/compositions/pytorchwrappers.py
index 5439300570..ab8ce37e7c 100644
--- a/psyneulink/library/compositions/pytorchwrappers.py
+++ b/psyneulink/library/compositions/pytorchwrappers.py
@@ -17,7 +17,6 @@
 
 from enum import Enum, auto
 
-from psyneulink.core.components.functions.nonstateful.transformfunctions import LinearCombination, PRODUCT, SUM
 from psyneulink.core.components.functions.stateful.integratorfunctions import IntegratorFunction
 from psyneulink.core.components.functions.stateful import StatefulFunction
 from psyneulink.core.components.mechanisms.processing.transfermechanism import TransferMechanism
@@ -30,7 +29,7 @@
     NODE, NODE_VALUES, NODE_VARIABLES, OUTPUTS, RESULTS, RUN, TARGETS, TARGET_MECHANISM, \
 )
 from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
-from psyneulink.core.globals.utilities import convert_to_np_array, get_deepcopy_with_shared, convert_to_list
+from psyneulink.core.globals.utilities import convert_to_list, convert_to_np_array, get_deepcopy_with_shared
 from psyneulink.core.globals.log import LogCondition
 from psyneulink.core import llvm as pnlvm
 
@@ -632,7 +631,7 @@ def forward(self, inputs, optimization_rep, context=None)->dict:
                         variable.append(input[i])
                     elif input_port.default_input == DEFAULT_VARIABLE:
                         # input_port uses a bias, so get that
-                        variable.append(input_port.defaults.variable)
+                        variable.append(torch.from_numpy(input_port.defaults.variable))
 
             # Input for the Mechanism is *not* explicitly specified, but its input_port(s) may have been
             else:
@@ -644,15 +643,14 @@ def forward(self, inputs, optimization_rep, context=None)->dict:
                         variable.append(inputs[input_port])
                     elif input_port.default_input == DEFAULT_VARIABLE:
                         # input_port uses a bias, so get that
-                        variable.append(input_port.defaults.variable)
+                        variable.append(torch.from_numpy(input_port.defaults.variable))
                     elif not input_port.internal_only:
                         # otherwise, use the node's input_port's afferents
-                        variable.append(node.aggregate_afferents(i).squeeze(0))
-                if len(variable) == 1:
-                    variable = variable[0]
+                        variable.append(node.aggregate_afferents(i))
             else:
                 # Node is not INPUT to Composition or BIAS, so get all input from its afferents
                 variable = node.aggregate_afferents()
+            variable = node.execute_input_ports(variable)
 
             if node.exclude_from_gradient_calc:
                 if node.exclude_from_gradient_calc == AFTER:
@@ -926,6 +924,11 @@ def __init__(self,
             self.integrator_function = PytorchFunctionWrapper(mechanism.integrator_function, device, context)
             self.integrator_previous_value = mechanism.integrator_function._get_pytorch_fct_param_value('initializer', device, context)
 
+        self.input_ports = [
+            PytorchFunctionWrapper(ip.function, device, context)
+            for ip in mechanism.input_ports
+        ]
+
     def add_efferent(self, efferent):
         """Add ProjectionWrapper for efferent from MechanismWrapper.
         Implemented for completeness; not currently used
@@ -955,54 +958,83 @@ def aggregate_afferents(self, port=None):
                 proj_wrapper._curr_sender_value = proj_wrapper.sender.output[proj_wrapper._value_idx]
             else:
                 proj_wrapper._curr_sender_value = torch.tensor(proj_wrapper.default_value)
+            proj_wrapper._curr_sender_value = torch.atleast_1d(proj_wrapper._curr_sender_value)
 
         # Specific port is specified
         # FIX: USING _port_idx TO INDEX INTO sender.value GETS IT WRONG IF THE MECHANISM HAS AN OUTPUT PORT
         #      USED BY A PROJECTION NOT IN THE CURRENT COMPOSITION
         if port is not None:
-            return sum(proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0)
-                       for proj_wrapper in self.afferents
-                       if proj_wrapper._pnl_proj
-                       in self._mechanism.input_ports[port].path_afferents)
-        # Has only one input_port
-        elif len(self._mechanism.input_ports) == 1:
-            # Get value corresponding to port from which each afferent projects
-            return sum((proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0)
-                        for proj_wrapper in self.afferents))
-        # Has multiple input_ports
+            res = [
+                proj_wrapper.execute(proj_wrapper._curr_sender_value)
+                for proj_wrapper in self.afferents
+                if proj_wrapper._pnl_proj in self._mechanism.input_ports[port].path_afferents
+            ]
         else:
-            return [sum(proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0)
-                        for proj_wrapper in self.afferents
-                        if proj_wrapper._pnl_proj in input_port.path_afferents)
-                    for input_port in self._mechanism.input_ports]
+            res = []
+            for input_port in self._mechanism.input_ports:
+                ip_res = []
+                for proj_wrapper in self.afferents:
+                    if proj_wrapper._pnl_proj in input_port.path_afferents:
+                        ip_res.append(proj_wrapper.execute(proj_wrapper._curr_sender_value))
+                res.append(torch.stack(ip_res))
+        try:
+            res = torch.stack(res)
+        except (RuntimeError, TypeError):
+            # is ragged, will handle ports individually during execute
+            pass
+        return res
+
+    def execute_input_ports(self, variable):
+        from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
+
+        if not isinstance(variable, torch.Tensor):
+            try:
+                variable = torch.stack(variable)
+            except (RuntimeError, TypeError):
+                # ragged
+                pass
+
+        # must iterate over at least 1d input per port
+        variable = torch.atleast_2d(variable)
+
+        res = []
+        for i in range(len(self.input_ports)):
+            v = variable[i]
+            if isinstance(self.input_ports[i]._pnl_function, TransformFunction):
+                # atleast_2d to account for input port dimension reduction
+                v = torch.atleast_2d(v)
+
+            res.append(self.input_ports[i].function(v))
+
+        try:
+            res = torch.stack(res)
+        except (RuntimeError, TypeError):
+            # ragged
+            pass
+        return res
 
     def execute(self, variable, context):
         """Execute Mechanism's _gen_pytorch version of function on variable.
        Enforce result to be 2d, and assign to self.output
        """
-        def execute_function(function, variable, fct_has_mult_args=False, is_combination_fct=False):
+        def execute_function(function, variable, fct_has_mult_args=False):
             """Execute _gen_pytorch_fct on variable, enforce result to be 2d, and return it
             If fct_has_mult_args is True, treat each item in variable as an arg to the function
             If False, compute function for each item in variable and return results in a list
             """
-            if ((isinstance(variable, list) and len(variable) == 1)
-                or (isinstance(variable, torch.Tensor) and len(variable.squeeze(0).shape) == 1)
-                    or isinstance(self._mechanism.function, LinearCombination)):
-                # Enforce 2d on value of MechanismWrapper (using unsqueeze) for single InputPort
-                # or if TransformFunction (which reduces output to single item from multi-item input)
-                if isinstance(variable, torch.Tensor):
-                    variable = variable.squeeze(0)
-                return function(variable).unsqueeze(0)
-            elif is_combination_fct:
-                # Function combines the elements
-                return function(variable)
-            elif fct_has_mult_args:
-                # Assign each element of variable as an arg to the function
-                return function(*variable)
+            from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
+            if fct_has_mult_args:
+                res = function(*variable)
+            # variable is ragged
+            elif isinstance(variable, list):
+                res = [function(variable[i]) for i in range(len(variable))]
             else:
-                # Treat each item in variable as a separate input to the function and get result for each in a list:
-                # make return value 2d by creating list of the results of function returned for each item in variable
-                return [function(variable[i].squeeze(0)) for i in range(len(variable))]
+                res = function(variable)
+            # TransformFunction can reduce output to single item from
+            # multi-item input
+            if isinstance(function._pnl_function, TransformFunction):
+                res = res.unsqueeze(0)
+            return res
 
         # If mechanism has an integrator_function and integrator_mode is True,
         # execute it first and use result as input to the main function;
@@ -1017,9 +1049,7 @@ def execute_function(function, variable, fct_has_mult_args=False, is_combination
             self.input = variable
 
         # Compute main function of mechanism and return result
-        from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
-        self.output = execute_function(self.function, variable,
-                                       is_combination_fct=isinstance(self._mechanism.function, TransformFunction))
+        self.output = execute_function(self.function, variable)
        return self.output
 
     def _gen_llvm_execute(self, ctx, builder, state, params, mech_input, data):
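
For intuition on why PATCH 2/3 moves the softmax to the last dimension:
with 1d input, dim=0 and dim=-1 are the same axis, but once input can
arrive as 2d (an extra batch or port dimension), dim=0 normalizes down
each column across rows instead of within each row. A quick standalone
illustration (not part of the patches above):

    import torch

    x = torch.tensor([[1.0, 2.0, 3.0],
                      [1.0, 2.0, 3.0]])

    torch.softmax(x, 0)   # normalizes down each column: every entry is 0.5
    torch.softmax(x, -1)  # normalizes within each row:
                          # tensor([[0.0900, 0.2447, 0.6652],
                          #         [0.0900, 0.2447, 0.6652]])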