Implement recurrence
yura-hb committed Mar 22, 2024
1 parent bd92c05 commit c49c377
Showing 29 changed files with 389 additions and 153 deletions.
17 changes: 1 addition & 16 deletions diploma_thesis/agents/base/model.py
@@ -43,27 +43,12 @@ def __init__(self, policy: Policy[Input]):
super().__init__()

self.policy = policy
self.memory = dict()

def __call__(self, state: State, parameters: Input) -> PolicyRecord:
if self.policy.is_recurrent:
key = self.memory_key(parameters)

assert key is not None, 'Expect that key definition for recurrent policy'

state.memory = self.memory.get(key)
record, memory = self.policy.select(state)

self.memory[key] = memory

return record
state.memory = parameters.memory

return self.policy.select(state)

@classmethod
def memory_key(cls, parameters: Input) -> None | str:
return None

def update(self, phase: Phase):
super().update(phase)

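With the per-model memory dictionary and the memory_key hook removed, recurrent state is no longer owned by the model: it arrives on the Input built by the caller and is attached to the state before policy.select. A rough, illustrative sketch of the resulting call pattern, using only names visible in this diff (decision_stream, state and the surrounding agent loop are hypothetical and not part of the commit):

memory = None                                          # TensorDict | None

for machine, now, graph in decision_stream():          # hypothetical source of decision points
    parameters = Input(machine=machine, now=now, graph=graph, memory=memory)
    record = model(state, parameters)                  # Model.__call__ above
    memory = record.memory                             # fed back on the next call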
3 changes: 3 additions & 0 deletions diploma_thesis/agents/machine/utils/input.py
@@ -4,9 +4,12 @@
from agents.base import Graph
from environment import Machine

from tensordict import TensorDict


@dataclass
class Input:
machine: Machine
now: float
graph: Graph | None
memory: TensorDict | None
2 changes: 2 additions & 0 deletions diploma_thesis/agents/utils/nn/layers/cli.py
@@ -9,6 +9,7 @@ def from_cli(*args, **kwargs):
from .shared import Shared
from .graph_model import GraphModel
from .partial_instance_norm_1d import PartialInstanceNorm1d
from .recurrent import Recurrent
from .output import Output

from utils import from_cli as from_cli_
@@ -36,6 +37,7 @@ def from_cli(*args, **kwargs):

'select_target': SelectTarget,
'shared': Shared,
'recurrent': Recurrent,
'output': Output
}

2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/nn/layers/output.py
@@ -26,7 +26,7 @@ def forward(self, *args):
return result

def __extract_values__(self, keys, leaf_to_arg: dict[str, torch.Tensor]):
result = dict()
result = TensorDict({}, batch_size=[])

for key, nested in keys.items():
if isinstance(nested, dict):
80 changes: 80 additions & 0 deletions diploma_thesis/agents/utils/nn/layers/recurrent.py
@@ -0,0 +1,80 @@

import torch
import uuid

from .layer import *
from torch import nn


class Recurrent(Layer):

def __init__(self, kind: str, memory_key: str, parameters: dict, signature: str):
super().__init__(signature=signature)

self.kind = kind
self.memory_key = memory_key
self.parameters = parameters
self.model = None

self.__build__()

def forward(self, x, memory):
match self.kind:
case 'lstm':
if memory is None:
output, (hidden_state, cell_state) = self.model(x)

return output, hidden_state, cell_state

rnn_memory = memory[self.memory_key]

hidden_state = rnn_memory['hidden_state'].unsqueeze(0)
cell_state = rnn_memory['cell_state'].unsqueeze(0)

if hidden_state.dim() == 2:
output, (hidden_state, cell_state) = self.model(x, (hidden_state, cell_state))

return output, hidden_state, cell_state

x = x.unsqueeze(0)

output, (hidden_state, cell_state) = self.model(x, (hidden_state, cell_state))

return output.squeeze(0), hidden_state, cell_state
case 'rnn' | 'gru':
if memory is None:
return self.model(x)

rnn_memory = memory[self.memory_key].unsqueeze(0)

if rnn_memory.dim() == 2:
return self.model(x, rnn_memory)

x = x.unsqueeze(0)

x, hidden_state = self.model(x, rnn_memory)

return x.squeeze(0), hidden_state
case _:
return None

def __build__(self):
match self.kind:
case 'lstm':
self.model = nn.LSTM(**self.parameters)
case 'gru':
self.model = nn.GRU(**self.parameters)
case 'rnn':
self.model = nn.RNN(**self.parameters)
case _:
raise ValueError(f'Unknown recurrent layer kind: {self.kind}')

@classmethod
def from_cli(cls, parameters: dict) -> 'Layer':
kind = parameters['kind']
memory_key = parameters['memory_key']
signature = parameters['signature']

del parameters['kind'], parameters['memory_key'], parameters['signature']

return Recurrent(kind, memory_key=memory_key, parameters=parameters, signature=signature)
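For reference, a self-contained sketch of the state-threading idea this layer implements, written against plain torch/tensordict rather than the repository's Layer machinery (the shapes and the 'lstm' key are illustrative assumptions):

import torch
from torch import nn
from tensordict import TensorDict

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 5, 8)                                   # (batch, sequence, features)

# First step: no memory yet, the LSTM starts from zero states.
output, (hidden_state, cell_state) = lstm(x)

# Persist the state under a key, mirroring `memory_key` above.
memory = TensorDict({'lstm': {'hidden_state': hidden_state.squeeze(0),
                              'cell_state': cell_state.squeeze(0)}}, batch_size=[])

# Later step: restore the state and continue from where the sequence stopped.
hidden_state = memory['lstm', 'hidden_state'].unsqueeze(0)
cell_state = memory['lstm', 'cell_state'].unsqueeze(0)
output, _ = lstm(torch.randn(4, 1, 8), (hidden_state, cell_state))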
34 changes: 23 additions & 11 deletions diploma_thesis/agents/utils/policy/action_policy.py
@@ -64,13 +64,17 @@ def encode(self, state: State):
return output

def post_encode(self, state: State, output: TensorDict):
values, actions = self.__fetch_values_and_actions__(output)

return self.__estimate_policy__(values, actions)
return self.__estimate_policy__(output)

def select(self, state: State) -> Record:
value, actions = self.__call__(state)
output = self.__call__(state)
value, actions, memory = self.__fetch_values__(output)
value, actions = value.squeeze(), actions.squeeze()

if memory is not None:
for key, item in memory.items(include_nested=True, leaves_only=True):
memory[key] = item.squeeze()

action, policy = self.action_selector(actions)
action = action if torch.is_tensor(action) else torch.tensor(action, dtype=torch.long)

@@ -80,28 +84,36 @@ def select(self, state: State) -> Record:
Keys.ACTIONS: actions
}, batch_size=[])

return Record(state, action, info, batch_size=[]).detach().cpu()
return Record(state, action, memory, info, batch_size=[]).detach().cpu()

def __estimate_policy__(self, value, actions):
def __estimate_policy__(self, output):
match self.policy_estimation_method:
case PolicyEstimationMethod.INDEPENDENT:
return value, actions
return output
case PolicyEstimationMethod.DUELING_ARCHITECTURE:
actions = output[Keys.ACTIONS]
value = output.get(Keys.VALUE, actions)

if isinstance(actions, tuple):
actions, lengths = actions

return value, value + actions - actions.sum(dim=-1) / lengths
output[Keys.ACTIONS] = value + actions - actions.sum(dim=-1) / lengths

return output
else:
return value, value + actions - actions.mean(dim=-1, keepdim=True)
output[Keys.ACTIONS] = value + actions - actions.mean(dim=-1, keepdim=True)

return output
case _:
raise ValueError(f"Policy estimation method {self.policy_estimation_method} is not supported")

@staticmethod
def __fetch_values_and_actions__(output: TensorDict):
def __fetch_values__(output: TensorDict):
actions = output[Keys.ACTIONS]
values = output.get(Keys.VALUE, actions)
memory = output.get(Keys.MEMORY, None)

return values, actions
return values, actions, memory

@staticmethod
def base_parameters_from_cli(parameters: Dict):
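The dueling branch now writes the combined estimate back into the output TensorDict instead of returning a tuple. On plain tensors the aggregation it applies is simply Q(s, a) = V(s) + A(s, a) - mean_a A(s, a); a minimal stand-alone illustration with made-up values:

import torch

value = torch.tensor([[0.5], [1.0]])                       # V(s), shape (batch, 1)
advantages = torch.tensor([[1.0, -1.0, 0.0],
                           [2.0,  0.0, 1.0]])              # A(s, ·), shape (batch, actions)

q = value + advantages - advantages.mean(dim=-1, keepdim=True)

The tuple branch does the same with a sum divided by the true action count, because padded action sets contain NaN entries that must be excluded from the mean; see the flexible-action sketch further down.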
8 changes: 4 additions & 4 deletions diploma_thesis/agents/utils/policy/discrete_action.py
@@ -19,12 +19,12 @@ def configure(self, configuration: RunConfiguration):
self.value_layer.to(configuration.device)

def post_encode(self, state, output):
values, actions = self.__fetch_values_and_actions__(output)
value, action, _ = self.__fetch_values__(output)

actions = self.action_layer(actions)
values = self.value_layer(values)
output[Keys.VALUE] = self.value_layer(value)
output[Keys.ACTIONS] = self.action_layer(action)

return self.__estimate_policy__(values, actions)
return self.__estimate_policy__(output)

@classmethod
def from_cli(cls, parameters: Dict) -> 'Policy':
22 changes: 15 additions & 7 deletions diploma_thesis/agents/utils/policy/flexible_action.py
@@ -17,7 +17,7 @@ def configure(self, configuration: RunConfiguration):
super().configure(configuration)

def post_encode(self, state: State, outputs):
values, actions = self.__fetch_values_and_actions__(outputs)
values, actions, _ = self.__fetch_values__(outputs)

# Unpack node embeddings obtained from graph batch
if state.graph is not None and isinstance(state.graph, pyg.data.Batch):
@@ -40,11 +40,15 @@ def post_encode(self, state: State, outputs):
actions = torch.nn.utils.rnn.pad_sequence(result, batch_first=True, padding_value=torch.nan)
lengths = torch.tensor(lengths)

return self.__estimate_policy__(values, (actions, lengths))
outputs[Keys.ACTIONS] = (actions, lengths)

return self.__estimate_policy__(values, actions)
return self.__estimate_policy__(outputs)

return self.__estimate_policy__(outputs)

def __estimate_policy__(self, output):
value, actions, _ = self.__fetch_values__(output)

def __estimate_policy__(self, value, actions):
if isinstance(actions, tuple):
# Encode as logits with zero probability
min_value = torch.finfo(torch.float32).min
@@ -56,11 +60,15 @@ def __estimate_policy__(self, value, actions):
actions, lengths = actions
means = torch.nan_to_num(actions, nan=0.0).sum(dim=-1) / lengths

return value, post_process(value + actions - means)
output[Keys.ACTIONS] = post_process(value + actions - means)

return output
case _:
return value, post_process(actions[0])
output[Keys.ACTIONS] = post_process(actions[0])

return output

return super().__estimate_policy__(value, actions)
return super().__estimate_policy__(output)

@classmethod
def from_cli(cls, parameters: Dict) -> 'Policy':
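A self-contained sketch of the padded-action path above: variable-length action sets are padded with NaN, the mean advantage is computed from the true lengths, and padded slots are pushed towards the minimum float so they receive zero probability. The final nan_to_num stands in for the diff's post_process masking, and all numbers are made up:

import torch

rows = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]
lengths = torch.tensor([3.0, 2.0])
value = torch.tensor([[0.1], [0.2]])

actions = torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=torch.nan)

means = torch.nan_to_num(actions, nan=0.0).sum(dim=-1, keepdim=True) / lengths.unsqueeze(-1)
q = value + actions - means

min_value = torch.finfo(torch.float32).min
q = q.nan_to_num(nan=min_value)                            # padded entries become ~log(0) logits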
6 changes: 2 additions & 4 deletions diploma_thesis/agents/utils/policy/policy.py
@@ -25,12 +25,14 @@
class Record:
state: State
action: Action
memory: TensorDict | None
info: TensorDict = field(default_factory=lambda: TensorDict({}, batch_size=[]))


class Keys(StrEnum):
ACTIONS = 'actions'
VALUE = 'value'
MEMORY = 'memory'
ACTOR_VALUE = 'actor_value'
POLICY = 'policy'

@@ -67,10 +69,6 @@ def encode(self, state: State):
def post_encode(self, state: State, output: TensorDict):
pass

@property
def is_recurrent(self):
return False

def configure(self, configuration: RunConfiguration):
pass

6 changes: 4 additions & 2 deletions diploma_thesis/agents/utils/rl/ddqn.py
@@ -8,9 +8,11 @@
class DoubleDeepQTrainer(DeepQTrainer):

def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
_, actions = model(batch.next_state)
actions = self.__get_action_values__(model, batch.next_state, None)

best_actions = actions.max(dim=-1).indices
target = self.target_model(batch.next_state)[1][range(batch.shape[0]), best_actions]

target = self.__get_action_values__(self.target_model, batch.next_state, best_actions)

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done.int())

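A plain-tensor illustration of the decoupling in estimate_q above, with the two network calls replaced by fixed Q-value tables (numbers are made up): the online network picks the arg-max action, the target network evaluates it, which is what suppresses the over-estimation a single max would produce.

import torch

online_next = torch.tensor([[1.0, 2.0, 0.5]])              # online Q(s', ·)
target_next = torch.tensor([[0.9, 1.5, 3.0]])              # target Q(s', ·)
reward = torch.tensor([1.0])
done = torch.tensor([False])
gamma = 0.99

best_actions = online_next.max(dim=-1).indices             # online net chooses: action 1
target = target_next[torch.arange(1), best_actions]        # target net evaluates: 1.5, not 3.0
q = reward + gamma * target * (1 - done.int())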
15 changes: 12 additions & 3 deletions diploma_thesis/agents/utils/rl/dqn.py
@@ -45,8 +45,7 @@ def __train__(self, model: Policy):
with torch.no_grad():
q_values = self.estimate_q(model, batch)

_, actions = model(batch.state)
actions = actions[range(batch.shape[0]), batch.action]
actions = self.__get_action_values__(model, batch.state, batch.action)

loss = self.loss(actions, q_values)
td_error = torch.square(actions - q_values)
@@ -62,13 +61,23 @@ def __train__(self, model: Policy):
self.storage.update_priority(info['index'], td_error)

def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
_, target = self.target_model(batch.next_state)
target = self.__get_action_values__(self.target_model, batch.next_state, None)
target = target.max(dim=1).values

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done.int())

return q

@staticmethod
def __get_action_values__(model: Policy, state, actions):
output = model(state)
_, action_values, _ = model.__fetch_values__(output)

if actions is None:
return action_values

return action_values[range(actions.shape[0]), actions]

@property
def target_model(self):
return self._target_model.module
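The new __get_action_values__ helper just unpacks the policy output and optionally gathers per-action values; a self-contained sketch of the two ways the trainer above uses it, with raw tensors in place of the policy (all numbers are made up):

import torch

q_values = torch.tensor([[1.0, 2.0, 0.5],
                         [0.3, 0.1, 0.9]])                 # online Q(s, ·)
target_q = torch.tensor([[1.1, 1.9, 0.4],
                         [0.2, 0.2, 1.0]])                 # target Q(s', ·)
action = torch.tensor([1, 2])
reward = torch.tensor([0.0, 1.0])
done = torch.tensor([False, True])
gamma = 0.99

chosen = q_values[torch.arange(action.shape[0]), action]   # actions given: gather Q(s, a)
target = target_q.max(dim=1).values                        # actions omitted: full row, then max
q = reward + gamma * target * (1 - done.int())             # zero bootstrap on terminal steps
td_error = torch.square(chosen - q)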
8 changes: 4 additions & 4 deletions diploma_thesis/agents/utils/rl/p3or.py
@@ -44,14 +44,14 @@ def __train__(self, model: Policy):
return

def __auxiliary_step__(self, model: Policy, batch: Batch):
output = model.encode(batch.state)
output = model(batch.state)

assert Keys.ACTOR_VALUE in output, (f"Actor value not found in output. It should be a value "
f"representing value estimate for actor head")
assert Keys.ACTOR_VALUE in output.keys(), (f"Actor value not found in output. It should be a value "
f"representing value estimate for actor head")

actor_values = output[Keys.ACTOR_VALUE]

_, actions = model.post_encode(batch.state, output)
_, actions, _ = model.__fetch_values__(output)

loss = self.configuration.value_loss(actor_values.view(-1), batch.info[Record.RETURN_KEY])
loss += self.configuration.trpo_penalty * self.trpo_loss(actions, batch.info[Record.POLICY_KEY])
5 changes: 4 additions & 1 deletion diploma_thesis/agents/utils/rl/reinforce.py
@@ -65,7 +65,10 @@ def __train__(self, model: Policy):
baseline = torch.squeeze(baseline)

# Perform policy step
loss = self.loss(model(batch.state)[1], batch.action)
output = model(batch.state)
_, actions, _ = model.__fetch_values__(output)

loss = self.loss(actions, batch.action)

if loss.numel() == 1:
raise ValueError('Loss should not have reduction to single value')