Skip to content

Commit

Permalink
Fix graph node storage. Other small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
yura-hb committed Mar 7, 2024
1 parent 273e04c commit 094808e
Show file tree
Hide file tree
Showing 16 changed files with 126 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@ def encode(self, parameters: StateEncoder.Input) -> State:

graph = parameters.graph

job_ids = graph.data[Graph.JOB_INDEX_KEY][0, :].unique()
job_ids = graph.data[Graph.JOB_INDEX_KEY][:, 0].unique()

processing_times = []

for job_id in job_ids:
processing_times += [parameters.machine.shop_floor.job(job_id).processing_times]
processing_times += [parameters.machine.shop_floor.job(job_id).processing_times.view(-1)]

processing_times = torch.cat(processing_times, dim=0)
processing_times = torch.cat(processing_times, dim=0).view(-1, 1)
graph.data[Graph.OPERATION_KEY].x = processing_times

# TODO: - Remove
graph.data = graph.data.to_homogeneous()

return self.State(parameters.graph, batch_size=[])

@staticmethod
Expand Down
38 changes: 33 additions & 5 deletions diploma_thesis/agents/utils/nn/layers/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,54 @@

import torch_geometric as pyg

# if len(self.configuration.graph) > 0:
# self.graph_encoder = pyg_nn.Sequential('x, edge_index', [
# (layer, layer.signature) if isinstance(layer, GraphLayer) else layer
# for layer in self.configuration.graph
# ])
# else:
# self.graph_encoder = None

#
# if not self.is_configured:
# if isinstance(data, HeteroData):
# self.graph_encoder = to_hetero(self.graph_encoder, data.metadata())
#
# encoded_graph = self.graph_encoder(data.x_dict, data.edge_index_dict)


class GraphLayer(Layer):

def __init__(self, kind: str, parameters: Dict):
def __init__(self, configuration: Dict):
super().__init__()

self.kind = kind
self.configuration = configuration
self.layer = self.__build__()

def __build__(self):
match self.kind:
case 'SageConv':
return pyg.nn.SAGEConv(out_channels=1, in_channels=1)
return pyg.nn.SAGEConv(in_channels=-1, **self.configuration)
case 'GIN':
return pyg.nn.GIN(in_channels=-1, **self.configuration)
case 'GAT':
return pyg.nn.GAT(in_channels=-1, **self.configuration)
case _:
raise ValueError(f"Unknown graph layer {self.kind}")

@property
def signature(self):
match self.kind:
case 'SageConv' | 'GIN' | 'GAT':
return 'x, edge_index -> x'
case _:
raise ValueError(f"Unknown graph layer {self.kind}")

def forward(self, data: pyg.data.Data) -> pyg.data.Data:
return self.layer(data)
def forward(self, x, edge_index) -> pyg.data.Data:
return self.layer(x, edge_index)

@classmethod
def from_cli(cls, parameters: dict) -> 'Layer':
return cls(parameters['kind'], parameters=parameters)
return cls(parameters['kind'], configuration=parameters['parameters'])

4 changes: 2 additions & 2 deletions diploma_thesis/agents/utils/nn/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import torch
from torch import nn
from torch_geometric import nn as pyg_nn

from torch_geometric.data import HeteroData
from torch_geometric.nn import to_hetero
from agents.base.state import TensorState, GraphState
from .layers import Layer, from_cli as layer_from_cli, Merge
from .layers import Layer, GraphLayer, from_cli as layer_from_cli, Merge


class NeuralNetwork(nn.Module):
Expand All @@ -17,7 +18,6 @@ class NeuralNetwork(nn.Module):

@dataclass
class Configuration:

graph: list[Layer]
state: list[Layer]
merge: Layer
Expand Down
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/return_estimator/gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def update_returns(self, records: List[Record]) -> List[Record]:

for i in reversed(range(len(records))):
next_value = 0 if i == len(records) - 1 else records[i + 1].info[Record.ADVANTAGE_KEY]
value = records[i].info[Record.VALUES_KEY]
value = records[i].info[Record.VALUES_KEY][records[i].action]
advantage = records[i].reward + self._discount_factor * next_value - value
next_advantage = 0 if i == len(records) - 1 else records[i + 1].info[Record.ADVANTAGE_KEY]

Expand Down
1 change: 0 additions & 1 deletion diploma_thesis/agents/utils/return_estimator/n_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def update_returns(self, records: List[Record]) -> List[Record]:
for j in range(n):
g += td_errors[i + j] * lambdas[j] * weights[j] * self.configuration.discount_factor ** j

records[i].info[Record.REWARD_KEY] = records[i].reward
records[i].reward = g

return records
Expand Down
3 changes: 3 additions & 0 deletions diploma_thesis/agents/utils/rl/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from dataclasses import dataclass
from typing import Dict

import torch

from agents.utils.memory import NotReadyException
from .rl import *

Expand Down Expand Up @@ -54,6 +56,7 @@ def __step__(self, model: Policy):

advantages = batch.info[Record.ADVANTAGE_KEY]
value, logits = model.predict(batch.state)
value = value[torch.arange(batch.shape[0]), batch.action]
distribution = torch.distributions.Categorical(logits=logits)

loss = 0
Expand Down
11 changes: 8 additions & 3 deletions diploma_thesis/agents/utils/rl/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class TrainSchedule(StrEnum):
ON_TIMELINE = 'on_timeline'
ON_STORE = 'on_store'
ON_STORED_DATA_EXCLUSIVELY = 'on_stored_data_exclusively'

@staticmethod
def from_cli(parameters: dict):
Expand Down Expand Up @@ -69,16 +70,20 @@ def train_step(self, model: Policy):
def loss_record(self) -> pd.DataFrame:
return pd.DataFrame(self.loss_cache)

def store(self, sample: TrainingSample):
def store(self, sample: TrainingSample, model: Policy):
records = self.__prepare__(sample)
records.info['episode'] = torch.full(records.reward.shape, sample.episode_id, device=records.reward.device)

if self.train_schedule == TrainSchedule.ON_STORED_DATA_EXCLUSIVELY:
self.memory.clear()

self.memory.store(records)

if self.train_schedule != TrainSchedule.ON_STORE:
if (self.train_schedule != TrainSchedule.ON_STORE and
self.train_schedule != TrainSchedule.ON_STORED_DATA_EXCLUSIVELY):
return

self.__train__(sample.model)
self.__train__(model)

def clear(self):
self.loss_cache = []
Expand Down
7 changes: 1 addition & 6 deletions diploma_thesis/configuration/jsp_stream_experiment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@ marl_ddqn: &marl_ddqn
template: *template
mods:
- 'agent/dqn/ddqn.yml'
- 'util/agent/centralized.yml'

marl_dueling_ddqn: &marl_dueling_ddqn
base_path: 'configuration/mods/machine/dqn.yml'
template: *template
mods:
- 'agent/dqn/ddqn.yml'
- 'agent/dqn/dueling.yml'
- 'util/agent/centralized.yml'

marl_dueling_ddqn_pr: &marl_dueling_ddqn_pr
base_path: 'configuration/mods/machine/dqn.yml'
Expand All @@ -55,15 +53,13 @@ marl_dueling_ddqn_pr: &marl_dueling_ddqn_pr
- 'agent/dqn/ddqn.yml'
- 'agent/dqn/dueling.yml'
- 'agent/dqn/prioritized.yml'
- 'util/agent/centralized.yml'

marl_dueling_ddqn_n_step: &marl_dueling_ddqn_n_step
base_path: 'configuration/mods/machine/dqn.yml'
template: *template
mods:
- 'agent/dqn/ddqn.yml'
- 'agent/dqn/dueling.yml'
- 'util/agent/centralized.yml'
- 'agent/dqn/3_step.yml'

reinforce: &reinforce
Expand Down Expand Up @@ -172,6 +168,7 @@ task:
kind: 'multi_task'
n_workers: 10
debug: False
store_run_statistics: False
output_dir: 'results/jsp/marl_direct'

tasks:
Expand Down Expand Up @@ -586,7 +583,6 @@ task:
machine_agent:
parameters:
- *dueling_ddqn_pr_n_step
- *marl_dueling_ddqn_n_step
- *reinforce
- *ppo

Expand Down Expand Up @@ -855,7 +851,6 @@ task:
- *dueling_ddqn_pr
- *reinforce
- *ppo
- *marl_dueling_ddqn_n_step
- *dueling_ddqn_pr_n_step

tape:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ graph:
parameters:
kind: 'SageConv'
parameters:
out_channels: 64
out_channels: 32
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@

parameters:
timeline:
warmup: []
warmup:
- 0
- 0
- 0
- 0
- 0
13 changes: 6 additions & 7 deletions diploma_thesis/configuration/simulation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ task:
output_dir: 'tmp'
log_stdout: True
log_run: True
store_run_statistics: False

machine_agent:
kind: 'mod'
parameters:
base_path: 'configuration/mods/machine/dqn.yml'
base_path: 'configuration/mods/machine/ppo.yml'
template: 'marl_direct'
mods:
- 'agent/dqn/ddqn.yml'
- 'agent/dqn/dueling.yml'
- 'util/agent/multi_agent.yml'
# - 'agent/dqn/ddqn.yml'
# - 'agent/dqn/dueling.yml'
# - 'util/agent/multi_agent.yml'
- 'util/rules/all_rules.yml'

work_center_agent:
Expand All @@ -30,8 +31,6 @@ task:
tape:
machine_reward:
kind: 'surrogate_tardiness'
parameters:
span: 256

work_center_reward:
kind: 'no'
Expand All @@ -53,7 +52,7 @@ task:
is_work_center_set_in_shop_floor_connected: True

simulator:
kind: 'episodic'
kind: 'td'
parameters:
memory: 5

Expand Down
34 changes: 17 additions & 17 deletions diploma_thesis/simulator/graph/util/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def __construct_initial_graph__(self):
t = torch.tensor([], dtype=torch.long)

# Nodes
result.data[Graph.OPERATION_KEY].x = t.view(1, 0)
result.data[Graph.GROUP_KEY].x = t.view(1, 0)
result.data[Graph.MACHINE_KEY].x = t.view(1, 0)
result.data[Graph.WORK_CENTER_KEY].x = t.view(1, 0)
result.data[Graph.OPERATION_KEY].x = t.to(torch.float32).view(0, 1)
result.data[Graph.GROUP_KEY].x = t.to(torch.float32).view(0, 1)
result.data[Graph.MACHINE_KEY].x = t.to(torch.float32).view(0, 1)
result.data[Graph.WORK_CENTER_KEY].x = t.to(torch.float32).view(0, 1)

# Indices
result.data[Graph.JOB_INDEX_KEY] = t.view(4, 0)
result.data[Graph.JOB_INDEX_KEY] = t.view(0, 4)

def __init_rel__(key):
result.data[*key].edge_index = t.view(2, 0)
Expand Down Expand Up @@ -87,7 +87,7 @@ def __update_job_index__(self, result: Graph, source: Graph, job_operation_map:

n_all_ops = 0

result.data[Graph.JOB_INDEX_KEY] = torch.tensor([], dtype=torch.long).view(4, 0)
result.data[Graph.JOB_INDEX_KEY] = torch.tensor([], dtype=torch.long).view(0, 4)

for job_id in job_ids:
if job_id not in job_operation_map:
Expand All @@ -110,9 +110,9 @@ def __update_job_index__(self, result: Graph, source: Graph, job_operation_map:
]
)

result.data[Graph.JOB_INDEX_KEY] = torch.cat([result.data[Graph.JOB_INDEX_KEY], index], dim=1)
result.data[Graph.JOB_INDEX_KEY] = torch.cat([result.data[Graph.JOB_INDEX_KEY], index.T], dim=0)

result.data[Graph.OPERATION_KEY].x = torch.zeros(n_all_ops, dtype=torch.long).view(1, -1)
result.data[Graph.OPERATION_KEY].x = torch.zeros(n_all_ops, dtype=torch.float32).view(-1, 1)

def __update_forward_graph__(
self, result: Graph, source: Graph, shop_floor: ShopFloor, job_operation_map: JOB_OPERATION_MAP_TYPE
Expand Down Expand Up @@ -147,7 +147,7 @@ def __update_forward_graph__(
n_steps = len(job.step_idx)

forward_graph = source.data[Graph.JOB_KEY][job_id][Graph.FORWARD_GRAPH_KEY]
local_operation_to_global_map = torch.arange(min_operation_id, max_operation_id + 1)
local_operation_to_global_map = torch.arange(min_operation_id, max_operation_id + 1, dtype=torch.long)

if has_groups:
# Group nodes are encoded as -step_id -1, which means that they access items in the end of the list
Expand All @@ -170,12 +170,12 @@ def __update_forward_graph__(
s = result.data[Graph.OPERATION_KEY, Graph.FORWARD_RELATION_KEY, Graph.OPERATION_KEY]
s.edge_index = torch.cat([s.edge_index, local_operation_to_global_map[forward_graph]], dim=1)

result.data[Graph.GROUP_KEY].x = torch.zeros(n_groups, dtype=torch.long).view(1, -1)
result.data[Graph.GROUP_KEY].x = torch.zeros(n_groups, dtype=torch.float32).view(1, -1)

def __append_machine_nodes__(self, result: Graph, source: Graph, shop_floor: ShopFloor):
result.data[Graph.MACHINE_INDEX_KEY] = source.data[Graph.MACHINE_INDEX_KEY]
result.data[Graph.MACHINE_KEY].x = torch.zeros(len(shop_floor.machines), dtype=torch.long).view(1, -1)
result.data[Graph.WORK_CENTER_KEY].x = torch.zeros(len(shop_floor.work_centers), dtype=torch.long).view(1, -1)
result.data[Graph.MACHINE_INDEX_KEY] = source.data[Graph.MACHINE_INDEX_KEY].view(-1, 2)
result.data[Graph.MACHINE_KEY].x = torch.zeros(len(shop_floor.machines), dtype=torch.float32).view(-1, 1)
result.data[Graph.WORK_CENTER_KEY].x = torch.zeros(len(shop_floor.work_centers), dtype=torch.float32).view(-1, 1)

if self.configuration.is_machine_set_in_work_center_connected:
for work_center in shop_floor.work_centers:
Expand All @@ -184,10 +184,10 @@ def __append_machine_nodes__(self, result: Graph, source: Graph, shop_floor: Sho
torch.arange(len(work_center.machines))
])

machine_node_id = self.__get_node_ids__(machine_index, result.data[Graph.MACHINE_INDEX_KEY])
machine_node_id = self.__get_node_ids__(machine_index, result.data[Graph.MACHINE_INDEX_KEY].T)

edges = itertools.combinations(machine_node_id, 2)
edges = torch.tensor(list(edges)).view(-1, 2).T
edges = torch.tensor(list(edges), dtype=torch.long).view(-1, 2).T

s = result.data[Graph.MACHINE_KEY, Graph.IN_WORK_CENTER_RELATION_KEY, Graph.MACHINE_KEY]
s.edge_index = torch.cat([s.edge_index, edges], dim=1)
Expand All @@ -200,15 +200,15 @@ def __append_machine_nodes__(self, result: Graph, source: Graph, shop_floor: Sho

if self.configuration.is_work_center_set_in_shop_floor_connected:
edges = itertools.combinations(range(len(shop_floor.work_centers)), 2)
edges = torch.tensor(list(edges)).T
edges = torch.tensor(list(edges), dtype=torch.long).T

s = result.data[Graph.WORK_CENTER_KEY, Graph.IN_SHOP_FLOOR_RELATION_KEY, Graph.WORK_CENTER_KEY]
s.edge_index = edges

def __update_schedule_graphs__(self, result: Graph, source: Graph):
self.__reset_schedule_graph__(result)

index = result.data[Graph.JOB_INDEX_KEY][[0, 1], :]
index = result.data[Graph.JOB_INDEX_KEY][:, [0, 1]].view(2, -1)

for machine_id in range(len(result.data[Graph.MACHINE_KEY])):
has_graph = Graph.SCHEDULED_GRAPH_KEY in source.data[Graph.MACHINE_KEY][machine_id]
Expand Down
Loading

0 comments on commit 094808e

Please sign in to comment.