Fixes for flexible action model
yura-hb committed Mar 18, 2024
1 parent 3b4b876 commit 6a1019d
Showing 25 changed files with 148 additions and 111 deletions.
16 changes: 6 additions & 10 deletions diploma_thesis/agents/machine/state/encoder.py
@@ -28,20 +28,16 @@ def __localize__(self, parameters: StateEncoder.Input, graph: Graph):
return super().__localize_with_job_ids__(graph, job_ids)

def __post_encode__(self, graph: pyg.data.HeteroData, parameters: StateEncoder.Input) -> pyg.data.HeteroData:
queued_jobs = torch.hstack(list(set([job.id for job in parameters.machine.queue])))
is_in_queue = torch.isin(graph[Graph.JOB_INDEX_MAP][:, 0].view(-1), queued_jobs, assume_unique=True)
job_index = graph[Graph.JOB_INDEX_MAP]

queued_jobs = torch.hstack([job.id for job in parameters.machine.queue])
# TODO: assume_unique=False doesn't work well on MPS
is_in_queue = torch.isin(job_index[:, 0].view(-1), queued_jobs, assume_unique=False)

index = torch.hstack([parameters.machine.work_center_idx, parameters.machine.machine_idx])
is_target = torch.all(graph[Graph.JOB_INDEX_MAP][:, [2, 3]] == index, dim=1)
is_target = torch.all(job_index[:, [2, 3]] == index, dim=1)

graph[Graph.OPERATION_KEY][Graph.TARGET_KEY] = torch.logical_and(is_in_queue.view(-1), is_target.view(-1))

lhs = set(graph[Graph.JOB_INDEX_MAP][graph[Graph.OPERATION_KEY][Graph.TARGET_KEY], 0].view(-1).tolist())
rhs = set([job.id.item() for job in parameters.machine.queue])

print(lhs, rhs)

assert lhs == rhs

return graph
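Aside: a minimal sketch of the target-mask logic in __post_encode__ above, assuming JOB_INDEX_MAP rows hold [job_id, step_idx, work_center_idx, machine_idx]; the second column is an assumption, only columns 0, 2 and 3 matter here.

import torch

# Toy index: one row per operation node, columns [job_id, step_idx, work_center_idx, machine_idx].
job_index = torch.tensor([
    [0, 0, 1, 2],
    [1, 0, 1, 2],
    [2, 0, 0, 1],
])
queued_job_ids = torch.tensor([0, 2])   # ids of jobs currently in the machine queue
machine_key = torch.tensor([1, 2])      # [work_center_idx, machine_idx] of the deciding machine

# Operations whose job is queued at this machine ...
is_in_queue = torch.isin(job_index[:, 0].view(-1), queued_job_ids, assume_unique=False)
# ... and that belong to exactly this (work center, machine) pair.
is_target = torch.all(job_index[:, [2, 3]] == machine_key, dim=1)

target_mask = torch.logical_and(is_in_queue.view(-1), is_target.view(-1))
print(target_mask)   # tensor([ True, False, False])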

6 changes: 4 additions & 2 deletions diploma_thesis/agents/utils/policy/action_policy.py
@@ -63,9 +63,8 @@ def forward(self, state: State):
state = state.to(self.run_configuration.device)

values, actions = self.encode(state)
values, actions = self.post_encode(state, values, actions)

return self.__estimate_policy__(values, actions)
return self.post_encode(state, values, actions)

def encode(self, state: State, return_values: bool = True, return_actions: bool = True):
values, actions = None, None
@@ -80,6 +79,9 @@ def encode(self, state: State, return_values: bool = True, return_actions: bool

return values, actions

def post_encode(self, state: State, values: torch.FloatTensor, actions: torch.FloatTensor):
return self.__estimate_policy__(values, actions)

def select(self, state: State) -> Record:
value, actions = self.__call__(state)
value, actions = value.squeeze(), actions.squeeze()
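Aside: a stripped-down sketch of the control flow this refactor establishes — forward now ends in post_encode, and subclasses apply their head layers before deferring to the base implementation. The class names and the softmax in __estimate_policy__ are illustrative, not the project's actual code.

import torch
from torch import nn


class SketchPolicy(nn.Module):
    # Template: encode -> post_encode -> __estimate_policy__.

    def forward(self, state):
        values, actions = self.encode(state)
        # post_encode now owns the final policy estimation step.
        return self.post_encode(state, values, actions)

    def encode(self, state):
        return state, state   # placeholder for the real encoder

    def post_encode(self, state, values, actions):
        return self.__estimate_policy__(values, actions)

    def __estimate_policy__(self, values, actions):
        return values, torch.softmax(actions, dim=-1)


class SketchDiscretePolicy(SketchPolicy):
    def __init__(self, hidden_dim, n_actions):
        super().__init__()
        self.value_layer = nn.Linear(hidden_dim, 1)
        self.action_layer = nn.Linear(hidden_dim, n_actions)

    def post_encode(self, state, values, actions):
        values = self.value_layer(values)
        actions = self.action_layer(actions)
        # Defer to the base class instead of returning the raw tensors.
        return super().post_encode(state, values, actions)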
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/policy/discrete_action.py
@@ -22,7 +22,7 @@ def post_encode(self, state, value: torch.FloatTensor, actions: torch.FloatTenso
value = self.value_layer(value)
actions = self.action_layer(actions)

return value, actions
return super().post_encode(state, value, actions)

@classmethod
def from_cli(cls, parameters: Dict) -> 'Policy':
19 changes: 16 additions & 3 deletions diploma_thesis/agents/utils/policy/flexible_action.py
@@ -6,11 +6,24 @@

class FlexibleAction(ActionPolicy):

def __init__(self, base_parameters):
super().__init__(**base_parameters)

self.action_layer = self.make_linear_layer(1)

@property
def is_recurrent(self):
return False

def configure(self, configuration: RunConfiguration):
super().configure(configuration)

self.action_layer.to(configuration.device)

def post_encode(self, state: State, value: torch.FloatTensor, actions: torch.FloatTensor):
actions = self.action_layer(actions)

# Unpack node embeddings obtained from graph batch
if state.graph is not None and isinstance(state.graph, pyg.data.Batch):
result = []
lengths = []
@@ -31,9 +44,9 @@ def post_encode(self, state: State, value: torch.FloatTensor, actions: torch.Flo
actions = torch.nn.utils.rnn.pad_sequence(result, batch_first=True, padding_value=torch.nan)
lengths = torch.tensor(lengths)

return value, (actions, lengths)
return super().post_encode(state, value, (actions, lengths))

return value, actions
return super().post_encode(state, value, actions)

def __estimate_policy__(self, value, actions):
if isinstance(actions, tuple):
@@ -55,5 +68,5 @@ def __estimate_policy__(self, value, actions):

@classmethod
def from_cli(cls, parameters: Dict) -> 'Policy':
return FlexibleAction(**cls.base_parameters_from_cli(parameters))
return FlexibleAction(cls.base_parameters_from_cli(parameters))
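Aside: a toy sketch of the unpack-and-pad step FlexibleAction performs when the state carries a pyg.data.Batch — per-node scores are split back per graph and padded with NaN so downstream code can mask them; the score head and graph sizes here are made up.

import torch
import torch_geometric as pyg

# Two toy graphs with 3 and 2 operation nodes.
graphs = [pyg.data.Data(x=torch.randn(3, 8)), pyg.data.Data(x=torch.randn(2, 8))]
batch = pyg.data.Batch.from_data_list(graphs)

action_layer = torch.nn.Linear(8, 1)          # per-node action score, as in make_linear_layer(1)
scores = action_layer(batch.x).view(-1)       # flat scores over all nodes in the batch

# Split the flat vector back into one chunk per graph ...
per_graph = [scores[batch.batch == i] for i in range(batch.num_graphs)]
lengths = torch.tensor([chunk.shape[0] for chunk in per_graph])

# ... and pad to a dense (num_graphs, max_nodes) tensor, NaN marking the padding.
padded = torch.nn.utils.rnn.pad_sequence(per_graph, batch_first=True, padding_value=torch.nan)

print(padded.shape, lengths)                  # torch.Size([2, 3]) tensor([3, 2])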

2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/policy/policy.py
@@ -66,7 +66,7 @@ def configure(self, configuration: RunConfiguration):
pass

def make_linear_layer(self, output_dim):
return Linear(output_dim, noise_parameters=self.noise_parameters)
return Linear(output_dim, noise_parameters=self.noise_parameters, activation='none', dropout=0)

def clone(self):
return copy.deepcopy(self)
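Aside: make_linear_layer now requests a head with no activation and no dropout; outside this project's Linear wrapper, the closest plain-PyTorch equivalent is a bare, lazily shaped projection.

import torch

def make_plain_head(output_dim: int) -> torch.nn.Module:
    # Bare projection: no non-linearity, no dropout, input size inferred on first use.
    return torch.nn.LazyLinear(output_dim)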
4 changes: 3 additions & 1 deletion diploma_thesis/agents/utils/rl/__init__.py
@@ -10,12 +10,14 @@
from .reinforce import Reinforce
from .ppo import PPO
from .rl import RLTrainer
from .p3or import P3OR

key_to_class = {
'dqn': DeepQTrainer,
'ddqn': DoubleDeepQTrainer,
'reinforce': Reinforce,
'ppo': PPO
'ppo': PPO,
'p3or': P3OR
}


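Aside: the key_to_class registry above is the usual CLI dispatch pattern; a hypothetical factory built on top of it could look like this (the 'kind' and 'parameters' keys are assumptions).

from typing import Dict

class DeepQTrainer:                      # stand-ins for the registered trainer classes
    def __init__(self, **kwargs): ...

class P3OR:
    def __init__(self, **kwargs): ...

key_to_class = {'dqn': DeepQTrainer, 'p3or': P3OR}

def trainer_from_cli(parameters: Dict):
    # Pick the trainer class by its CLI key and construct it with the remaining parameters.
    cls = key_to_class[parameters['kind']]
    return cls(**parameters.get('parameters', {}))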
11 changes: 2 additions & 9 deletions diploma_thesis/agents/utils/rl/ddqn.py
@@ -1,24 +1,17 @@
import tensordict
import torch

from .dqn import DeepQTrainer
from agents.utils.memory import Record
from agents.utils.policy import Policy
from .dqn import DeepQTrainer


class DoubleDeepQTrainer(DeepQTrainer):

def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
_, actions = model(batch.next_state)
orig_q = actions[range(batch.shape[0]), batch.action]

best_actions = actions.max(dim=-1).indices

target = self.target_model(batch.next_state)[1][range(batch.shape[0]), best_actions]

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done.int())
actions[range(batch.shape[0]), batch.action] = q

td_error = torch.square(orig_q - q)

return actions, td_error
return q
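Aside: after this change estimate_q returns only the bootstrapped target; the Double DQN rule it implements, sketched with plain tensors (online_q and target_q stand in for model and self.target_model evaluated on the next state).

import torch

def double_dqn_target(online_q: torch.Tensor,   # Q(s', .) from the online network
                      target_q: torch.Tensor,   # Q(s', .) from the target network
                      reward: torch.Tensor,
                      done: torch.Tensor,
                      discount: float) -> torch.Tensor:
    batch_range = torch.arange(online_q.shape[0])
    # The online network selects the action, the target network evaluates it.
    best_actions = online_q.argmax(dim=-1)
    bootstrap = target_q[batch_range, best_actions]
    return reward + discount * bootstrap * (1 - done.int())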
17 changes: 6 additions & 11 deletions diploma_thesis/agents/utils/rl/dqn.py
@@ -43,11 +43,15 @@ def __train__(self, model: Policy):
return

with torch.no_grad():
q_values, td_error = self.estimate_q(model, batch)
q_values = self.estimate_q(model, batch)

_, actions = model(batch.state)
actions = actions[range(batch.shape[0]), batch.action]

loss = self.loss(actions, q_values)

td_error = torch.square(actions - q_values)

self.step(loss, self.optimizer)

self.record_loss(loss)
@@ -59,21 +63,12 @@ def __train__(self, model: Policy):
self.storage.update_priority(info['index'], td_error)

def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
# Note:
# The idea is that we compute the Q-values only for performed actions. Other actions wouldn't be updated,
# because there will be zero loss and so zero gradient
_, actions = model(batch.next_state)
orig_q = actions.clone()[range(batch.shape[0]), batch.action]

_, target = self.target_model(batch.next_state)
target = target.max(dim=1).values

q = batch.reward + self.return_estimator.discount_factor * target * (1 - batch.done.int())
actions[range(batch.shape[0]), batch.action] = q

td_error = torch.square(orig_q - q)

return actions, td_error
return q

@property
def target_model(self):
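Aside: with this change __train__ gathers Q(s, a) for the performed actions itself and derives the TD error from the same values it feeds to the loss; a self-contained sketch with random tensors (the Huber loss is an assumption — the real criterion comes from self.loss).

import torch

batch_size, n_actions, discount = 4, 3, 0.99

q_state = torch.randn(batch_size, n_actions, requires_grad=True)   # Q(s, .), online network
q_target_next = torch.randn(batch_size, n_actions)                 # Q(s', .), target network
action = torch.randint(0, n_actions, (batch_size,))
reward = torch.randn(batch_size)
done = torch.zeros(batch_size, dtype=torch.bool)

# Vanilla DQN target, computed without gradients.
with torch.no_grad():
    q_values = reward + discount * q_target_next.max(dim=1).values * (1 - done.int())

# Only the performed actions enter the loss; the TD error feeds prioritized replay.
taken = q_state[torch.arange(batch_size), action]
loss = torch.nn.functional.smooth_l1_loss(taken, q_values)
td_error = torch.square(taken - q_values).detach()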
13 changes: 9 additions & 4 deletions diploma_thesis/agents/utils/rl/p3or.py
@@ -24,8 +24,8 @@ def __init__(self, configuration: Configuration, *args, **kwargs):
self.trpo_loss = Loss(configuration=Loss.Configuration(kind='cross_entropy', parameters=dict()))
self.auxiliary_head = None

def __configure__(self, model: Policy, configuration: RunConfiguration):
super(RLTrainer).__configure__(model, configuration)
def configure(self, model: Policy, configuration: RunConfiguration):
super().configure(model, configuration)

self.auxiliary_head = model.make_linear_layer(1).to(configuration.device)

@@ -39,17 +39,22 @@ def __train__(self, model: Policy):
for minibatch in generator:
self.__step__(minibatch, model, self.configuration)

# Load batch
batch = batch()

# Auxiliary step
self.__auxiliary_step__(model, batch)
except NotReadyException:
return

def __auxiliary_step__(self, model: Policy, batch: Batch):
actions = model.encode(batch.state, return_values=False)
actions = model.post_encode(values=None, actions=actions)
values, actions = model.encode(batch.state)

# TODO: Aggregate Q values
values = self.auxiliary_head(actions)

_, actions = model.post_encode(batch.state, values, actions)

loss = self.configuration.value_loss(values.view(-1), batch.info[Record.RETURN_KEY])
loss += self.configuration.trpo_penalty * self.trpo_loss(actions, batch.info[Record.POLICY_KEY])

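Aside: the auxiliary step resembles the auxiliary phase of Phasic Policy Gradient — a value head is fitted to stored returns while a cross-entropy penalty keeps the policy close to the one recorded in the batch. A rough sketch under that reading; every name here is illustrative, not the project's API.

import torch
import torch.nn.functional as F

def auxiliary_loss(action_embeddings: torch.Tensor,  # per-action features from the encoder
                   new_logits: torch.Tensor,         # current policy logits over actions
                   old_policy: torch.Tensor,         # behaviour policy probabilities from the batch
                   returns: torch.Tensor,
                   auxiliary_head: torch.nn.Module,
                   trpo_penalty: float) -> torch.Tensor:
    # Fit the auxiliary value head to the observed returns.
    values = auxiliary_head(action_embeddings).view(-1)
    value_loss = F.mse_loss(values, returns)

    # Keep the current policy close to the behaviour policy (cross-entropy, as in the hunk).
    distill_loss = F.cross_entropy(new_logits, old_policy)

    return value_loss + trpo_penalty * distill_loss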
4 changes: 1 addition & 3 deletions diploma_thesis/agents/utils/rl/ppo.py
@@ -8,12 +8,10 @@

@dataclass
class Configuration(PPOConfiguration):
epochs: int

@staticmethod
def from_cli(parameters: Dict):
return Configuration(**PPOConfiguration.base_parameters_from_cli(parameters),
epochs=parameters.get('epochs', 1))
return Configuration(**PPOConfiguration.base_parameters_from_cli(parameters))


class PPO(PPOMixin):
7 changes: 1 addition & 6 deletions diploma_thesis/agents/utils/rl/rl.py
@@ -68,14 +68,8 @@ def train_step(self, model: Policy):
if self.train_schedule != TrainSchedule.ON_TIMELINE:
return

import time

start = time.time()

self.__train__(model)

print(f"Training step took {time.time() - start} seconds")

def __train__(self, model: Policy):
pass

@@ -119,3 +113,4 @@ def step(loss, optimizer):
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
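Aside: the new trailing zero_grad in step mirrors this minimal helper — gradients are cleared before the backward pass and again right after the update, so no stale gradients survive between training phases.

import torch

def step(loss: torch.Tensor, optimizer: torch.optim.Optimizer):
    optimizer.zero_grad()   # drop gradients left over from any previous phase
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()   # release gradient buffers immediately after the update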
47 changes: 30 additions & 17 deletions diploma_thesis/agents/utils/rl/storage.py
@@ -1,7 +1,7 @@

import torch

from typing import List, Tuple, Generator
from typing import List, Tuple, Generator, Callable

from agents.base.state import Graph
from agents.utils.memory import Memory, Record
@@ -23,6 +23,9 @@ def __init__(self, is_episodic: bool, memory: Memory, return_estimator: ReturnEs
def store(self, sample: TrainingSample):
records = self.__prepare__(sample)

if records[0].state.graph is None:
print(records[0].state.graph)

for record in records:
record.info['episode'] = sample.episode_id

@@ -46,7 +49,7 @@ def sample(self, update_returns: bool, device: torch.device, batch_graphs: bool

return self.__process_batched_data__(batch, update_returns, batch_graphs, device), info

def sample_minibatches(self, update_returns, device, n, sample_ratio) -> Tuple[Record, Generator]:
def sample_minibatches(self, update_returns, device, n, sample_ratio) -> Tuple[Callable, Generator]:
batch, info = self.sample(update_returns, device, batch_graphs=False)

def generator(batch):
@@ -56,20 +59,27 @@ def generator(batch):
mask_ = mask.uniform_() < sample_ratio
idx = mask_.nonzero()

sub_batch = batch[mask_]
minibatch = batch[mask_]

if batch[0].state.graph is not None:
sub_batch.state.graph = Batch.from_data_list(
[batch.state.graph[index] for index in idx]
).to(device)
minibatch.state.graph = Batch.from_data_list([
batch.state.graph[index] for index in idx
]).to(device)

sub_batch.next_state.graph = Batch.from_data_list(
minibatch.next_state.graph = Batch.from_data_list(
[batch.next_state.graph[index] for index in idx]
).to(device)

yield sub_batch
yield minibatch

def load_batch():
if batch[0].state.graph is not None:
batch.state.graph = Batch.from_data_list(batch.state.graph).to(device)
batch.next_state.graph = Batch.from_data_list(batch.next_state.graph).to(device)

return batch, generator(batch)
return batch

return load_batch, generator(batch)

def update_priority(self, indices: torch.LongTensor, priorities: torch.FloatTensor):
self.memory.update_priority(indices, priorities)
@@ -123,17 +133,16 @@ def __merge_batched_data__(self, batch, batch_graphs, device):
state_graph = []
next_state_graph = []

# Elements can be sampled with repetition
for element in batch:
state_graph += [element.state.graph]
next_state_graph += [element.next_state.graph]
state_graph += [element.state.graph[0]]
next_state_graph += [element.next_state.graph[0]]

for element in batch:
element.state.graph, element.next_state.graph = None, None

result = torch.cat(batch, dim=0)

for index, element in enumerate(batch):
element.state.graph, element.next_state.graph = state_graph[index], next_state_graph[index]

result.state.graph = self.__collate_graphs__(state_graph, batch_graphs, device)
result.next_state.graph = self.__collate_graphs__(next_state_graph, batch_graphs, device)
else:
@@ -156,8 +165,6 @@ def __collate_variable_length_info_values__(self, batch):
for key in keys:
result[key] += torch.atleast_2d(element.info[key])

del element.info[key]

for key in keys:
match key:
case Record.POLICY_KEY:
@@ -167,10 +174,16 @@

result[key] = torch.nn.utils.rnn.pad_sequence(result[key], batch_first=True, padding_value=fill_value)

# In some cases batch can be sampled with repetition
for element in batch:
for key in keys:
if key in element.info.keys():
del element.info[key]

return result

def __collate_graphs__(self, records: List[Graph], batch_graphs, device: torch.device):
graphs = [record[0].to_pyg_graph() for record in records]
graphs = [record.to_pyg_graph() for record in records]

if not batch_graphs:
return graphs
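Aside: a toy version of the minibatch sampling in generator above — draw a Bernoulli mask over the batch and collate only the selected PyG graphs, while the full batch is collated lazily through the load_batch closure. Graph sizes here are arbitrary.

import torch
import torch_geometric as pyg

graphs = [pyg.data.Data(x=torch.randn(n, 4)) for n in (3, 2, 5, 4)]
sample_ratio = 0.5

# Mirror `mask.uniform_() < sample_ratio`: each element is kept independently.
mask = torch.rand(len(graphs)) < sample_ratio
idx = mask.nonzero().view(-1)

if idx.numel() > 0:
    # Only the sampled graphs are batched for this minibatch step.
    minibatch_graph = pyg.data.Batch.from_data_list([graphs[i] for i in idx.tolist()])
    print(minibatch_graph.num_graphs, minibatch_graph.x.shape)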
3 changes: 2 additions & 1 deletion diploma_thesis/agents/utils/rl/utils/ppo_mixin.py
@@ -13,6 +13,7 @@ class PPOConfiguration:
entropy_regularization: float
update_advantages: bool
rollback_ratio: float
epochs: int

@staticmethod
def base_parameters_from_cli(parameters: Dict):
@@ -23,6 +24,7 @@ def base_parameters_from_cli(parameters: Dict):
entropy_regularization=parameters.get('entropy_regularization', 0.0),
update_advantages=parameters.get('update_advantages', True),
rollback_ratio=parameters.get('rollback_ratio', 0.1),
epochs=parameters.get('epochs', 1)
)


@@ -35,7 +37,6 @@ def __step__(self, batch: Record, model: Policy, configuration: PPOConfiguration
value, logits = model(batch.state)

loss = self.actor_loss(batch, logits, configuration, self.run_configuration.device)
# Maximization of negative value is equivalent to minimization
loss -= configuration.value_loss(value.view(-1), batch.info[Record.RETURN_KEY])

# Want to maximize
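Aside: a compact sketch of the objective __step__ assembles — a clipped surrogate actor term, the value loss subtracted because the sum is treated as something to maximize, and an optional entropy bonus. Coefficients, clipping details and the sampled entropy estimate are assumptions, not the project's exact formulation.

import torch

def ppo_objective(new_log_prob, old_log_prob, advantage, value, returns,
                  clip_range=0.2, entropy_regularization=0.0):
    ratio = torch.exp(new_log_prob - old_log_prob)
    # Clipped surrogate objective (to be maximized).
    surrogate = torch.minimum(ratio * advantage,
                              torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantage)
    actor = surrogate.mean()

    value_loss = torch.nn.functional.mse_loss(value.view(-1), returns)
    entropy = -new_log_prob.mean()            # sampled entropy estimate

    # "Want to maximize": actor and entropy enter positively, the value loss negatively.
    objective = actor - value_loss + entropy_regularization * entropy
    return -objective                         # negate once if the optimizer minimizes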