From 472831c98a91e49e7e338b88ff0da9a5694c445d Mon Sep 17 00:00:00 2001 From: Yury Hayeu Date: Mon, 4 Mar 2024 20:52:40 +0100 Subject: [PATCH] Fix Discrete Action. Implement template CLI --- diploma_thesis/agents/utils/memory/memory.py | 1 + .../agents/utils/policy/discrete_action.py | 11 ++--- diploma_thesis/simulator/td.py | 1 + diploma_thesis/utils/modified.py | 45 +++++++++++++++++++ 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/diploma_thesis/agents/utils/memory/memory.py b/diploma_thesis/agents/utils/memory/memory.py index 9762a03..d18725b 100644 --- a/diploma_thesis/agents/utils/memory/memory.py +++ b/diploma_thesis/agents/utils/memory/memory.py @@ -18,6 +18,7 @@ class Record: POLICY_KEY = "policy" VALUES_KEY = "values" REWARD_KEY = "reward" + RETURN_KEY = "return" ACTION_KEY = "actions" ADVANTAGE_KEY = "advantage" diff --git a/diploma_thesis/agents/utils/policy/discrete_action.py b/diploma_thesis/agents/utils/policy/discrete_action.py index 27f4e65..4e687c5 100644 --- a/diploma_thesis/agents/utils/policy/discrete_action.py +++ b/diploma_thesis/agents/utils/policy/discrete_action.py @@ -34,12 +34,13 @@ def __init__(self, def update(self, phase: Phase): self.phase = phase - for module in [self.value_model, self.advantage_model, self.action_selector]: + for module in [self.value_model, self.action_model, self.action_selector]: if isinstance(module, PhaseUpdatable): module.update(phase) def __call__(self, state: State, parameters: Input) -> Record: - values, actions = self.predict(state).view(-1) + values, actions = self.predict(state) + values, actions = values.squeeze(), actions.squeeze() action, policy = self.action_selector(actions) action = action if torch.is_tensor(action) else torch.tensor(action, dtype=torch.long) @@ -52,13 +53,13 @@ def __call__(self, state: State, parameters: Input) -> Record: return Record(state, action, info, batch_size=[]) def predict(self, state: State): - values = None - actions = None + values = torch.tensor(0, dtype=torch.long) + actions = torch.tensor(0, dtype=torch.long) if self.value_model is not None: values = self.value_model(state) - if self.advantage_model is not None: + if self.action_model is not None: actions = self.action_model(state) match self.policy_estimation_method: diff --git a/diploma_thesis/simulator/td.py b/diploma_thesis/simulator/td.py index ea20a08..664bf3b 100644 --- a/diploma_thesis/simulator/td.py +++ b/diploma_thesis/simulator/td.py @@ -3,6 +3,7 @@ from agents.base import Agent + class TDSimulator(Simulator): """ A simulator, which estimates returns in Temporal Difference manner and send information for training as soon as diff --git a/diploma_thesis/utils/modified.py b/diploma_thesis/utils/modified.py index 88f4dd4..c85b6fd 100644 --- a/diploma_thesis/utils/modified.py +++ b/diploma_thesis/utils/modified.py @@ -10,6 +10,11 @@ def modified(parameters): with open(base_path) as file: base_parameters = yaml.safe_load(file) + template = dict() + + if 'template' in parameters: + template = __load_template__(parameters, base_path) + mods = parameters['mods'] mods_dir = os.path.dirname(base_path) @@ -20,4 +25,44 @@ def modified(parameters): mod = yaml.safe_load(file) base_parameters = merge_dicts(base_parameters, mod) + base_parameters = __apply_template__(base_parameters, template) + return base_parameters + + +def __load_template__(parameters, base_path): + template_path = os.path.dirname(base_path) + template_path = os.path.join(template_path, 'templates', parameters['template']) + + values = dict() + + for file in os.listdir(template_path): + if file.endswith('.yml'): + with open(os.path.join(template_path, file)) as f: + template = yaml.safe_load(f) + values[os.path.basename(file).split('.')[0]] = template + + values = {f'__{k}__': v for k, v in values.items()} + + return values + + +def __apply_template__(parameters, template): + if isinstance(parameters, dict): + updates = dict() + + for k, v in parameters.items(): + if k in template: + updates[k] = template[k] + else: + parameters[k] = __apply_template__(v, template) + + for k, v in updates.items(): + del parameters[k] + parameters.update(v) + + if isinstance(parameters, list): + for i, v in enumerate(parameters): + parameters[i] = __apply_template__(v, template) + + return parameters