Skip to content

Commit

Permalink
Fix Discrete Action. Implement template CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
yura-hb committed Mar 4, 2024
1 parent 88735d1 commit 472831c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 5 deletions.
1 change: 1 addition & 0 deletions diploma_thesis/agents/utils/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Record:
POLICY_KEY = "policy"
VALUES_KEY = "values"
REWARD_KEY = "reward"
RETURN_KEY = "return"
ACTION_KEY = "actions"
ADVANTAGE_KEY = "advantage"

Expand Down
11 changes: 6 additions & 5 deletions diploma_thesis/agents/utils/policy/discrete_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ def __init__(self,
def update(self, phase: Phase):
self.phase = phase

for module in [self.value_model, self.advantage_model, self.action_selector]:
for module in [self.value_model, self.action_model, self.action_selector]:
if isinstance(module, PhaseUpdatable):
module.update(phase)

def __call__(self, state: State, parameters: Input) -> Record:
values, actions = self.predict(state).view(-1)
values, actions = self.predict(state)
values, actions = values.squeeze(), actions.squeeze()
action, policy = self.action_selector(actions)
action = action if torch.is_tensor(action) else torch.tensor(action, dtype=torch.long)

Expand All @@ -52,13 +53,13 @@ def __call__(self, state: State, parameters: Input) -> Record:
return Record(state, action, info, batch_size=[])

def predict(self, state: State):
values = None
actions = None
values = torch.tensor(0, dtype=torch.long)
actions = torch.tensor(0, dtype=torch.long)

if self.value_model is not None:
values = self.value_model(state)

if self.advantage_model is not None:
if self.action_model is not None:
actions = self.action_model(state)

match self.policy_estimation_method:
Expand Down
1 change: 1 addition & 0 deletions diploma_thesis/simulator/td.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from agents.base import Agent


class TDSimulator(Simulator):
"""
A simulator, which estimates returns in Temporal Difference manner and send information for training as soon as
Expand Down
45 changes: 45 additions & 0 deletions diploma_thesis/utils/modified.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ def modified(parameters):
with open(base_path) as file:
base_parameters = yaml.safe_load(file)

template = dict()

if 'template' in parameters:
template = __load_template__(parameters, base_path)

mods = parameters['mods']

mods_dir = os.path.dirname(base_path)
Expand All @@ -20,4 +25,44 @@ def modified(parameters):
mod = yaml.safe_load(file)
base_parameters = merge_dicts(base_parameters, mod)

base_parameters = __apply_template__(base_parameters, template)

return base_parameters


def __load_template__(parameters, base_path):
    """Load all template definitions referenced by the configuration.

    Looks up the directory ``templates/<parameters['template']>`` next to
    *base_path* and parses every ``.yml`` file inside it. Each file becomes
    one entry keyed ``__<file stem>__`` (stem = name before the first dot),
    so these keys can later act as placeholders to substitute into the
    configuration tree.

    Returns a dict mapping placeholder key -> parsed YAML content.
    """
    directory = os.path.join(os.path.dirname(base_path), 'templates', parameters['template'])

    result = dict()

    for entry in os.listdir(directory):
        if not entry.endswith('.yml'):
            continue

        with open(os.path.join(directory, entry)) as stream:
            content = yaml.safe_load(stream)

        # Key by the file stem, wrapped in dunder markers to avoid clashes
        # with ordinary configuration keys.
        stem = os.path.basename(entry).split('.')[0]
        result[f'__{stem}__'] = content

    return result


def __apply_template__(parameters, template):
if isinstance(parameters, dict):
updates = dict()

for k, v in parameters.items():
if k in template:
updates[k] = template[k]
else:
parameters[k] = __apply_template__(v, template)

for k, v in updates.items():
del parameters[k]
parameters.update(v)

if isinstance(parameters, list):
for i, v in enumerate(parameters):
parameters[i] = __apply_template__(v, template)

return parameters

0 comments on commit 472831c

Please sign in to comment.