From b6c03ed91525962d504d70ee8400a8d87d5c49e0 Mon Sep 17 00:00:00 2001 From: Katherine Mantel Date: Tue, 8 Nov 2022 19:58:49 -0500 Subject: [PATCH] EpisodicMemoryMechanism: make memory (_memory_init) a FunctionParameter Shared with its function initializer. Changes conflict behavior to be consistent with other SharedParameters (function value favored over owner value). For discussion on this, see https://github.com/PrincetonUniversity/PsyNeuLink/issues/2600 --- .../integrator/episodicmemorymechanism.py | 37 +++++++++---------- tests/mechanisms/test_episodic_memory.py | 7 +++- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/psyneulink/library/components/mechanisms/processing/integrator/episodicmemorymechanism.py b/psyneulink/library/components/mechanisms/processing/integrator/episodicmemorymechanism.py index 891d377d17c..815aa969ab3 100644 --- a/psyneulink/library/components/mechanisms/processing/integrator/episodicmemorymechanism.py +++ b/psyneulink/library/components/mechanisms/processing/integrator/episodicmemorymechanism.py @@ -404,6 +404,7 @@ """ +import copy import warnings from typing import Optional, Union @@ -416,7 +417,7 @@ from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism_Base from psyneulink.core.components.ports.inputport import InputPort from psyneulink.core.globals.keywords import EPISODIC_MEMORY_MECHANISM, INITIALIZER, NAME, OWNER_VALUE, VARIABLE -from psyneulink.core.globals.parameters import Parameter, check_user_specified +from psyneulink.core.globals.parameters import FunctionParameter, Parameter, check_user_specified from psyneulink.core.globals.preferences.basepreferenceset import is_pref_set from psyneulink.core.globals.utilities import deprecation_warning, convert_to_np_array, convert_all_elements_to_np_array @@ -508,6 +509,13 @@ class Parameters(ProcessingMechanism_Base.Parameters): """ variable = Parameter([[0,0]], pnl_internal=True, constructor_argument='default_variable') function = Parameter(ContentAddressableMemory, stateful=False, loggable=False) + memory = FunctionParameter(None, function_parameter_name='initializer') + + def _parse_memory(self, memory): + if memory is None: + return memory + + return ContentAddressableMemory._enforce_memory_shape(memory) @check_user_specified def __init__(self, @@ -538,8 +546,6 @@ def __init__(self, size += kwargs['assoc_size'] kwargs.pop('assoc_size') - self._memory_init = memory - super().__init__( default_variable=default_variable, size=size, @@ -547,6 +553,7 @@ def __init__(self, params=params, name=name, prefs=prefs, + memory=memory, **kwargs ) @@ -564,18 +571,15 @@ def _handle_default_variable(self, default_variable=None, size=None, input_ports variable_shape = convert_all_elements_to_np_array(default_variable).shape \ if default_variable is not None else None function_instance = self.function if isinstance(self.function, Function) else None - function_type = self.function if isinstance(self.function, type) else self.function.__class__ # **memory** arg is specified in constructor, so use that to initialize or validate default_variable - if self._memory_init: - try: - self._memory_init = function_type._enforce_memory_shape(self._memory_init) - except: - pass + if self.parameters.memory._user_specified: + memory = self.defaults.memory + if default_variable is None: - default_variable = self._memory_init[0] + default_variable = copy.deepcopy(memory[0]) else: - entry_shape = convert_all_elements_to_np_array(self._memory_init[0]).shape + entry_shape = convert_all_elements_to_np_array(memory[0]).shape if entry_shape != variable_shape: raise EpisodicMemoryMechanismError(f"Shape of 'variable' for {self.name} ({variable_shape}) " f"does not match the shape of entries ({entry_shape}) in " @@ -610,14 +614,9 @@ def _instantiate_input_ports(self, context=None): def _instantiate_function(self, function, function_params, context): """Assign memory to function if specified in Mechanism's constructor""" - if self._memory_init is not None: - if isinstance(function, type): - function_params.update({INITIALIZER:self._memory_init}) - else: - if len(function.memory): - warnings.warn(f"The 'memory' argument specified for {self.name} will override the specification " - f"for the {repr(INITIALIZER)} argument of its function ({self.function.name}).") - function.reset(self._memory_init) + memory = self.parameters.memory._get(context) + if memory is not None: + function.reset(memory) super()._instantiate_function(function, function_params, context) def _instantiate_output_ports(self, context=None): diff --git a/tests/mechanisms/test_episodic_memory.py b/tests/mechanisms/test_episodic_memory.py index 479becb96ee..409db0eed32 100644 --- a/tests/mechanisms/test_episodic_memory.py +++ b/tests/mechanisms/test_episodic_memory.py @@ -221,8 +221,11 @@ def test_with_contentaddressablememory(name, func, func_params, mech_params, tes def test_contentaddressable_memory_warnings_and_errors(): # both memory arg of Mechanism and initializer for its function are specified - text = "The 'memory' argument specified for EpisodicMemoryMechanism-0 will override the specification " \ - "for the 'initializer' argument of its function" + text = ( + r"Specification of the \"memory\" parameter[.\S\s]*The value" + + r" specified on \(ContentAddressableMemory ContentAddressableMemory" + + r" Function-\d\) will be used\." + ) with pytest.warns(UserWarning, match=text): em = EpisodicMemoryMechanism( memory = [[[1,2,3],[4,5,6]]],