pytorch: implement input port function execution
Previously, pytorch Mechanisms would only use the equivalent of a
SUM LinearCombination for each of their input ports.

This implements use in pytorch of the actual Functions assigned to
the input ports. When the pytorch equivalent of a psyneulink Function
(via _gen_pytorch_fct) is unavailable, an error is raised.
kmantel committed Dec 18, 2024
1 parent b9c2b1a commit 7a4667f
Showing 4 changed files with 84 additions and 49 deletions.
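[Editor's note] For orientation, a minimal sketch (names hypothetical, not the actual implementation) of the dispatch this commit implements: each input port's Function is converted to a torch callable via its _gen_pytorch_fct method, and a Function without a pytorch equivalent raises instead of being silently approximated:

    import torch

    def build_input_port_function(pnl_function, device, context=None):
        # Assumes _gen_pytorch_fct is absent or raises NotImplementedError
        # when no pytorch equivalent exists.
        try:
            return pnl_function._gen_pytorch_fct(device, context)
        except (AttributeError, NotImplementedError) as e:
            raise RuntimeError(
                f"{type(pnl_function).__name__} has no pytorch equivalent "
                f"and cannot be used in an AutodiffComposition") from e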
@@ -3591,7 +3591,7 @@ def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
     _mask = (torch.abs(_input) > mask_threshold)
     # Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask
     masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0]))
-    if not any(masked_exp):
+    if (masked_exp == 0).all():
         return masked_exp
     return masked_exp / torch.sum(masked_exp, -1, keepdim=True)
 # Return the function
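[Editor's note] Python's builtin any() iterates over a tensor's first dimension and calls bool() on each slice; for a 2d masked_exp each slice is a row, and bool() of a multi-element tensor raises. The elementwise test works for any shape and preserves the early return that avoids dividing by a zero sum:

    import torch

    masked_exp = torch.zeros(2, 3)      # 2d: every element masked out
    try:
        any(masked_exp)                 # bool() is called on each row
    except RuntimeError as e:
        print(e)                        # "Boolean value of Tensor with more than one element is ambiguous"
    print((masked_exp == 0).all())      # tensor(True): safe guard for any shape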
@@ -1593,14 +1593,14 @@ def _gen_pytorch_fct(self, device, context=None):
         weights = torch.tensor(weights, device=device).double()
     if self.operation == SUM:
         if weights is not None:
-            return lambda x: torch.sum(torch.stack(x) * weights, 0)
+            return lambda x: torch.sum(x * weights, 0)
         else:
-            return lambda x: torch.sum(torch.stack(x), 0)
+            return lambda x: torch.sum(x, 0)
     elif self.operation == PRODUCT:
         if weights is not None:
-            return lambda x: torch.prod(torch.stack(x) * weights, 0)
+            return lambda x: torch.prod(x * weights, 0)
         else:
-            return lambda x: torch.prod(torch.stack(x), 0)
+            return lambda x: torch.prod(x, 0)
     else:
         from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
         raise AutodiffCompositionError(f"The 'operation' parameter of {function.componentName} is not supported "
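[Editor's note] The torch.stack calls were dropped because aggregate_afferents (see pytorchwrappers.py below) now hands these lambdas an already-stacked tensor, so they reduce over dim 0 directly; weights, when present, broadcast against the stacked rows. A quick check of the unweighted paths:

    import torch

    x = torch.stack([torch.tensor([1.0, 2.0]),     # value from afferent 1
                     torch.tensor([3.0, 4.0])])    # value from afferent 2
    sum_fct = lambda x: torch.sum(x, 0)            # operation == SUM
    prod_fct = lambda x: torch.prod(x, 0)          # operation == PRODUCT
    print(sum_fct(x))    # tensor([4., 6.])
    print(prod_fct(x))   # tensor([3., 8.])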
9 changes: 7 additions & 2 deletions psyneulink/library/compositions/autodiffcomposition.py
@@ -1124,8 +1124,13 @@ def autodiff_forward(self, inputs, targets,
         for component in curr_tensors_for_trained_outputs.keys():
             trial_loss = 0
             for i in range(len(curr_tensors_for_trained_outputs[component])):
-                trial_loss += self.loss_function(curr_tensors_for_trained_outputs[component][i],
-                                                 curr_target_tensors_for_trained_outputs[component][i])
+                # loss only accepts 0 or 1d target. reshape assuming pytorch_rep.minibatch_loss dim is correct
+                comp_loss = self.loss_function(
+                    curr_tensors_for_trained_outputs[component][i],
+                    torch.atleast_1d(curr_target_tensors_for_trained_outputs[component][i].squeeze())
+                )
+                comp_loss = comp_loss.reshape_as(pytorch_rep.minibatch_loss)
+                trial_loss += comp_loss
             pytorch_rep.minibatch_loss += trial_loss
             pytorch_rep.minibatch_loss_count += 1

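[Editor's note] As the new in-line comment says, the loss here accepts only a 0d or 1d target: squeeze() plus atleast_1d() normalizes a (1, n) target to (n,) while keeping a fully-squeezed scalar target 1d, and reshape_as() makes the per-component loss shape-compatible with the minibatch_loss accumulator (assumed 0d in this example):

    import torch

    target = torch.tensor([[0.0, 1.0, 0.0]])            # shape (1, 3)
    print(torch.atleast_1d(target.squeeze()).shape)     # torch.Size([3])

    scalar = torch.tensor([[2.0]])                      # shape (1, 1)
    print(scalar.squeeze().shape)                       # torch.Size([]) -- 0d
    print(torch.atleast_1d(scalar.squeeze()).shape)     # torch.Size([1])

    loss = torch.nn.MSELoss()(torch.tensor([1.0, 0.0]), torch.zeros(2))
    print(loss.reshape_as(torch.tensor(0.0)).shape)     # torch.Size([]) -- matches a 0d accumulator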
114 changes: 72 additions & 42 deletions psyneulink/library/compositions/pytorchwrappers.py
@@ -17,7 +17,6 @@

 from enum import Enum, auto

-from psyneulink.core.components.functions.nonstateful.transformfunctions import LinearCombination, PRODUCT, SUM
 from psyneulink.core.components.functions.stateful.integratorfunctions import IntegratorFunction
 from psyneulink.core.components.functions.stateful import StatefulFunction
 from psyneulink.core.components.mechanisms.processing.transfermechanism import TransferMechanism
@@ -30,7 +29,7 @@
     NODE, NODE_VALUES, NODE_VARIABLES, OUTPUTS, RESULTS, RUN,
     TARGETS, TARGET_MECHANISM, )
 from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
-from psyneulink.core.globals.utilities import convert_to_np_array, get_deepcopy_with_shared, convert_to_list
+from psyneulink.core.globals.utilities import convert_to_list, convert_to_np_array, get_deepcopy_with_shared
 from psyneulink.core.globals.log import LogCondition
 from psyneulink.core import llvm as pnlvm

@@ -632,7 +631,7 @@ def forward(self, inputs, optimization_rep, context=None)->dict:
                         variable.append(input[i])
                     elif input_port.default_input == DEFAULT_VARIABLE:
                         # input_port uses a bias, so get that
-                        variable.append(input_port.defaults.variable)
+                        variable.append(torch.from_numpy(input_port.defaults.variable))

             # Input for the Mechanism is *not* explicitly specified, but its input_port(s) may have been
             else:
@@ -644,15 +643,14 @@
                         variable.append(inputs[input_port])
                     elif input_port.default_input == DEFAULT_VARIABLE:
                         # input_port uses a bias, so get that
-                        variable.append(input_port.defaults.variable)
+                        variable.append(torch.from_numpy(input_port.defaults.variable))
                     elif not input_port.internal_only:
                         # otherwise, use the node's input_port's afferents
-                        variable.append(node.aggregate_afferents(i).squeeze(0))
-                if len(variable) == 1:
-                    variable = variable[0]
+                        variable.append(node.aggregate_afferents(i))
             else:
                 # Node is not INPUT to Composition or BIAS, so get all input from its afferents
                 variable = node.aggregate_afferents()
+            variable = node.execute_input_ports(variable)

             if node.exclude_from_gradient_calc:
                 if node.exclude_from_gradient_calc == AFTER:
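[Editor's note] Two things change in forward() here: bias input ports now contribute tensors (torch.from_numpy) rather than raw numpy arrays, so every entry of variable has a uniform type for later stacking, and the per-port squeeze/unwrap logic is gone because the new execute_input_ports (below) owns per-port shaping. A small check of the conversion:

    import numpy as np
    import torch

    default_variable = np.array([0.0, 0.0])          # an InputPort's default (bias) value
    bias_entry = torch.from_numpy(default_variable)  # shares memory and dtype with numpy
    print(bias_entry.dtype)                          # torch.float64
    print(torch.stack([bias_entry, torch.ones(2, dtype=torch.float64)]))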
@@ -926,6 +924,11 @@ def __init__(self,
             self.integrator_function = PytorchFunctionWrapper(mechanism.integrator_function, device, context)
             self.integrator_previous_value = mechanism.integrator_function._get_pytorch_fct_param_value('initializer', device, context)

+        self.input_ports = [
+            PytorchFunctionWrapper(ip.function, device, context)
+            for ip in mechanism.input_ports
+        ]
+
     def add_efferent(self, efferent):
         """Add ProjectionWrapper for efferent from MechanismWrapper.
         Implemented for completeness; not currently used
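[Editor's note] Each InputPort's Function now gets its own PytorchFunctionWrapper, mirroring the existing wrapping of the Mechanism's function and integrator_function. Judging from how these wrappers are used below (._pnl_function for isinstance checks, .function for execution), a wrapper behaves roughly like this sketch (a hypothetical stand-in, not the actual class):

    class FunctionWrapperSketch:
        def __init__(self, pnl_function, device, context=None):
            # keep the original Function for type checks ...
            self._pnl_function = pnl_function
            # ... and its generated torch callable for execution
            self.function = pnl_function._gen_pytorch_fct(device, context)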
@@ -955,54 +958,83 @@ def aggregate_afferents(self, port=None):
                 proj_wrapper._curr_sender_value = proj_wrapper.sender.output[proj_wrapper._value_idx]
             else:
                 proj_wrapper._curr_sender_value = torch.tensor(proj_wrapper.default_value)
+            proj_wrapper._curr_sender_value = torch.atleast_1d(proj_wrapper._curr_sender_value)

         # Specific port is specified
         # FIX: USING _port_idx TO INDEX INTO sender.value GETS IT WRONG IF THE MECHANISM HAS AN OUTPUT PORT
         #      USED BY A PROJECTION NOT IN THE CURRENT COMPOSITION
         if port is not None:
-            return sum(proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0)
-                       for proj_wrapper in self.afferents
-                       if proj_wrapper._pnl_proj
-                       in self._mechanism.input_ports[port].path_afferents)
-        # Has only one input_port
-        elif len(self._mechanism.input_ports) == 1:
-            # Get value corresponding to port from which each afferent projects
-            return sum((proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0)
-                        for proj_wrapper in self.afferents))
-        # Has multiple input_ports
+            res = [
+                proj_wrapper.execute(proj_wrapper._curr_sender_value)
+                for proj_wrapper in self.afferents
+                if proj_wrapper._pnl_proj in self._mechanism.input_ports[port].path_afferents
+            ]
         else:
-            return [sum(proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0)
-                        for proj_wrapper in self.afferents
-                        if proj_wrapper._pnl_proj in input_port.path_afferents)
-                    for input_port in self._mechanism.input_ports]
+            res = []
+            for input_port in self._mechanism.input_ports:
+                ip_res = []
+                for proj_wrapper in self.afferents:
+                    if proj_wrapper._pnl_proj in input_port.path_afferents:
+                        ip_res.append(proj_wrapper.execute(proj_wrapper._curr_sender_value))
+                res.append(torch.stack(ip_res))
+        try:
+            res = torch.stack(res)
+        except (RuntimeError, TypeError):
+            # is ragged, will handle ports individually during execute
+            pass
+        return res
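[Editor's note] torch.stack requires equal shapes, so when the mechanism's input ports carry different-sized values the stack raises and the per-port list is returned as-is; execute_input_ports then iterates ports individually. For example:

    import torch

    uniform = [torch.ones(3), torch.zeros(3)]
    print(torch.stack(uniform).shape)          # torch.Size([2, 3])

    ragged = [torch.ones(3), torch.zeros(2)]   # ports of different sizes
    try:
        torch.stack(ragged)
    except RuntimeError:
        print("ragged: keep the list, handle ports one by one")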

+    def execute_input_ports(self, variable):
+        from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
+
+        if not isinstance(variable, torch.Tensor):
+            try:
+                variable = torch.stack(variable)
+            except (RuntimeError, TypeError):
+                # ragged
+                pass
+
+        # must iterate over at least 1d input per port
+        variable = torch.atleast_2d(variable)
+
+        res = []
+        for i in range(len(self.input_ports)):
+            v = variable[i]
+            if isinstance(self.input_ports[i]._pnl_function, TransformFunction):
+                # atleast_2d to account for input port dimension reduction
+                v = torch.atleast_2d(v)
+
+            res.append(self.input_ports[i].function(v))
+
+        try:
+            res = torch.stack(res)
+        except (RuntimeError, TypeError):
+            # ragged
+            pass
+        return res
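[Editor's note] The atleast_2d promotion matters because TransformFunctions (e.g. the LinearCombination lambdas above) reduce along dim 0; a port value already reduced to 1d needs a leading axis so the reduction consumes a "row" dimension rather than the elements themselves:

    import torch

    combine = lambda x: torch.sum(x, 0)   # SUM LinearCombination, as generated above

    v = torch.tensor([1.0, 2.0, 3.0])     # single afferent value, already 1d
    print(combine(v))                     # tensor(6.) -- wrong: collapses the elements
    print(combine(torch.atleast_2d(v)))   # tensor([1., 2., 3.]) -- sums over one row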

     def execute(self, variable, context):
         """Execute Mechanism's _gen_pytorch version of function on variable.
         Enforce result to be 2d, and assign to self.output
         """
-        def execute_function(function, variable, fct_has_mult_args=False, is_combination_fct=False):
+        def execute_function(function, variable, fct_has_mult_args=False):
             """Execute _gen_pytorch_fct on variable, enforce result to be 2d, and return it
             If fct_has_mult_args is True, treat each item in variable as an arg to the function
             If False, compute function for each item in variable and return results in a list
             """
-            if ((isinstance(variable, list) and len(variable) == 1)
-                    or (isinstance(variable, torch.Tensor) and len(variable.squeeze(0).shape) == 1)
-                    or isinstance(self._mechanism.function, LinearCombination)):
-                # Enforce 2d on value of MechanismWrapper (using unsqueeze) for single InputPort
-                # or if TransformFunction (which reduces output to single item from multi-item input)
-                if isinstance(variable, torch.Tensor):
-                    variable = variable.squeeze(0)
-                return function(variable).unsqueeze(0)
-            elif is_combination_fct:
-                # Function combines the elements
-                return function(variable)
-            elif fct_has_mult_args:
-                # Assign each element of variable as an arg to the function
-                return function(*variable)
+            from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
+            if fct_has_mult_args:
+                res = function(*variable)
+            # variable is ragged
+            elif isinstance(variable, list):
+                res = [function(variable[i]) for i in range(len(variable))]
             else:
-                # Treat each item in variable as a separate input to the function and get result for each in a list:
-                # make return value 2d by creating list of the results of function returned for each item in variable
-                return [function(variable[i].squeeze(0)) for i in range(len(variable))]
+                res = function(variable)
+            # TransformFunction can reduce output to single item from
+            # multi-item input
+            if isinstance(function._pnl_function, TransformFunction):
+                res = res.unsqueeze(0)
+            return res

         # If mechanism has an integrator_function and integrator_mode is True,
         # execute it first and use result as input to the main function;
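[Editor's note] The final unsqueeze is the mirror image of the promotion in execute_input_ports: a TransformFunction collapses its stacked input to a single item, so its result gets the leading dimension restored, keeping execute()'s output 2d as the docstring promises:

    import torch

    stacked = torch.tensor([[1.0, 2.0],
                            [3.0, 4.0]])   # two items of input
    reduced = torch.sum(stacked, 0)        # TransformFunction-style reduction
    print(reduced.shape)                   # torch.Size([2])
    print(reduced.unsqueeze(0).shape)      # torch.Size([1, 2]) -- 2d restored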
@@ -1017,9 +1049,7 @@ def execute_function(function, variable, fct_has_mult_args=False, is_combination_fct=False):
             self.input = variable

         # Compute main function of mechanism and return result
-        from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
-        self.output = execute_function(self.function, variable,
-                                       is_combination_fct=isinstance(self._mechanism.function, TransformFunction))
+        self.output = execute_function(self.function, variable)
         return self.output

     def _gen_llvm_execute(self, ctx, builder, state, params, mech_input, data):