From 094808e4997184dbb4944561890f3c1e2b337397 Mon Sep 17 00:00:00 2001 From: Yury Hayeu Date: Thu, 7 Mar 2024 12:51:32 +0100 Subject: [PATCH] Fix graph node storage. Other small fixes --- .../machine/state/auxiliary_graph_encoder.py | 9 ++--- .../agents/utils/nn/layers/graph.py | 38 ++++++++++++++++--- .../agents/utils/nn/neural_network.py | 4 +- .../agents/utils/return_estimator/gae.py | 2 +- .../agents/utils/return_estimator/n_step.py | 1 - diploma_thesis/agents/utils/rl/ppo.py | 3 ++ diploma_thesis/agents/utils/rl/rl.py | 11 ++++-- .../configuration/jsp_stream_experiment.yml | 7 +--- .../machine/templates/auxiliary/model.yml | 2 +- .../mods/run/mods/timeline/no_warmup.yml | 7 +++- diploma_thesis/configuration/simulation.yml | 13 +++---- .../simulator/graph/util/encoder.py | 34 ++++++++--------- .../simulator/tape/queue/machine_queue.py | 18 ++++----- diploma_thesis/simulator/td.py | 8 ++-- diploma_thesis/workflow/multi_simulation.py | 27 ++++++------- diploma_thesis/workflow/simulation.py | 26 +++++++++++-- 16 files changed, 126 insertions(+), 84 deletions(-) diff --git a/diploma_thesis/agents/machine/state/auxiliary_graph_encoder.py b/diploma_thesis/agents/machine/state/auxiliary_graph_encoder.py index 35b5d67..910fd89 100644 --- a/diploma_thesis/agents/machine/state/auxiliary_graph_encoder.py +++ b/diploma_thesis/agents/machine/state/auxiliary_graph_encoder.py @@ -18,19 +18,16 @@ def encode(self, parameters: StateEncoder.Input) -> State: graph = parameters.graph - job_ids = graph.data[Graph.JOB_INDEX_KEY][0, :].unique() + job_ids = graph.data[Graph.JOB_INDEX_KEY][:, 0].unique() processing_times = [] for job_id in job_ids: - processing_times += [parameters.machine.shop_floor.job(job_id).processing_times] + processing_times += [parameters.machine.shop_floor.job(job_id).processing_times.view(-1)] - processing_times = torch.cat(processing_times, dim=0) + processing_times = torch.cat(processing_times, dim=0).view(-1, 1) graph.data[Graph.OPERATION_KEY].x = processing_times - # TODO: - Remove - graph.data = graph.data.to_homogeneous() - return self.State(parameters.graph, batch_size=[]) @staticmethod diff --git a/diploma_thesis/agents/utils/nn/layers/graph.py b/diploma_thesis/agents/utils/nn/layers/graph.py index 1bf46e5..8fada76 100644 --- a/diploma_thesis/agents/utils/nn/layers/graph.py +++ b/diploma_thesis/agents/utils/nn/layers/graph.py @@ -4,26 +4,54 @@ import torch_geometric as pyg +# if len(self.configuration.graph) > 0: +# self.graph_encoder = pyg_nn.Sequential('x, edge_index', [ +# (layer, layer.signature) if isinstance(layer, GraphLayer) else layer +# for layer in self.configuration.graph +# ]) +# else: +# self.graph_encoder = None + +# +# if not self.is_configured: +# if isinstance(data, HeteroData): +# self.graph_encoder = to_hetero(self.graph_encoder, data.metadata()) +# +# encoded_graph = self.graph_encoder(data.x_dict, data.edge_index_dict) + class GraphLayer(Layer): - def __init__(self, kind: str, parameters: Dict): + def __init__(self, configuration: Dict): super().__init__() self.kind = kind + self.configuration = configuration self.layer = self.__build__() def __build__(self): match self.kind: case 'SageConv': - return pyg.nn.SAGEConv(out_channels=1, in_channels=1) + return pyg.nn.SAGEConv(in_channels=-1, **self.configuration) + case 'GIN': + return pyg.nn.GIN(in_channels=-1, **self.configuration) + case 'GAT': + return pyg.nn.GAT(in_channels=-1, **self.configuration) + case _: + raise ValueError(f"Unknown graph layer {self.kind}") + + @property + def signature(self): + match self.kind: + case 'SageConv' | 'GIN' | 'GAT': + return 'x, edge_index -> x' case _: raise ValueError(f"Unknown graph layer {self.kind}") - def forward(self, data: pyg.data.Data) -> pyg.data.Data: - return self.layer(data) + def forward(self, x, edge_index) -> pyg.data.Data: + return self.layer(x, edge_index) @classmethod def from_cli(cls, parameters: dict) -> 'Layer': - return cls(parameters['kind'], parameters=parameters) + return cls(parameters['kind'], configuration=parameters['parameters']) diff --git a/diploma_thesis/agents/utils/nn/neural_network.py b/diploma_thesis/agents/utils/nn/neural_network.py index e34e7ef..957f84b 100644 --- a/diploma_thesis/agents/utils/nn/neural_network.py +++ b/diploma_thesis/agents/utils/nn/neural_network.py @@ -3,11 +3,12 @@ import torch from torch import nn +from torch_geometric import nn as pyg_nn from torch_geometric.data import HeteroData from torch_geometric.nn import to_hetero from agents.base.state import TensorState, GraphState -from .layers import Layer, from_cli as layer_from_cli, Merge +from .layers import Layer, GraphLayer, from_cli as layer_from_cli, Merge class NeuralNetwork(nn.Module): @@ -17,7 +18,6 @@ class NeuralNetwork(nn.Module): @dataclass class Configuration: - graph: list[Layer] state: list[Layer] merge: Layer diff --git a/diploma_thesis/agents/utils/return_estimator/gae.py b/diploma_thesis/agents/utils/return_estimator/gae.py index 7b7fc21..db4cb7a 100644 --- a/diploma_thesis/agents/utils/return_estimator/gae.py +++ b/diploma_thesis/agents/utils/return_estimator/gae.py @@ -19,7 +19,7 @@ def update_returns(self, records: List[Record]) -> List[Record]: for i in reversed(range(len(records))): next_value = 0 if i == len(records) - 1 else records[i + 1].info[Record.ADVANTAGE_KEY] - value = records[i].info[Record.VALUES_KEY] + value = records[i].info[Record.VALUES_KEY][records[i].action] advantage = records[i].reward + self._discount_factor * next_value - value next_advantage = 0 if i == len(records) - 1 else records[i + 1].info[Record.ADVANTAGE_KEY] diff --git a/diploma_thesis/agents/utils/return_estimator/n_step.py b/diploma_thesis/agents/utils/return_estimator/n_step.py index 7ba7cd2..944ae06 100644 --- a/diploma_thesis/agents/utils/return_estimator/n_step.py +++ b/diploma_thesis/agents/utils/return_estimator/n_step.py @@ -83,7 +83,6 @@ def update_returns(self, records: List[Record]) -> List[Record]: for j in range(n): g += td_errors[i + j] * lambdas[j] * weights[j] * self.configuration.discount_factor ** j - records[i].info[Record.REWARD_KEY] = records[i].reward records[i].reward = g return records diff --git a/diploma_thesis/agents/utils/rl/ppo.py b/diploma_thesis/agents/utils/rl/ppo.py index c7b84fa..bdf096e 100644 --- a/diploma_thesis/agents/utils/rl/ppo.py +++ b/diploma_thesis/agents/utils/rl/ppo.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import Dict +import torch + from agents.utils.memory import NotReadyException from .rl import * @@ -54,6 +56,7 @@ def __step__(self, model: Policy): advantages = batch.info[Record.ADVANTAGE_KEY] value, logits = model.predict(batch.state) + value = value[torch.arange(batch.shape[0]), batch.action] distribution = torch.distributions.Categorical(logits=logits) loss = 0 diff --git a/diploma_thesis/agents/utils/rl/rl.py b/diploma_thesis/agents/utils/rl/rl.py index a5d9d24..6c9cf5a 100644 --- a/diploma_thesis/agents/utils/rl/rl.py +++ b/diploma_thesis/agents/utils/rl/rl.py @@ -19,6 +19,7 @@ class TrainSchedule(StrEnum): ON_TIMELINE = 'on_timeline' ON_STORE = 'on_store' + ON_STORED_DATA_EXCLUSIVELY = 'on_stored_data_exclusively' @staticmethod def from_cli(parameters: dict): @@ -69,16 +70,20 @@ def train_step(self, model: Policy): def loss_record(self) -> pd.DataFrame: return pd.DataFrame(self.loss_cache) - def store(self, sample: TrainingSample): + def store(self, sample: TrainingSample, model: Policy): records = self.__prepare__(sample) records.info['episode'] = torch.full(records.reward.shape, sample.episode_id, device=records.reward.device) + if self.train_schedule == TrainSchedule.ON_STORED_DATA_EXCLUSIVELY: + self.memory.clear() + self.memory.store(records) - if self.train_schedule != TrainSchedule.ON_STORE: + if (self.train_schedule != TrainSchedule.ON_STORE and + self.train_schedule != TrainSchedule.ON_STORED_DATA_EXCLUSIVELY): return - self.__train__(sample.model) + self.__train__(model) def clear(self): self.loss_cache = [] diff --git a/diploma_thesis/configuration/jsp_stream_experiment.yml b/diploma_thesis/configuration/jsp_stream_experiment.yml index 99fe70d..5316cd2 100644 --- a/diploma_thesis/configuration/jsp_stream_experiment.yml +++ b/diploma_thesis/configuration/jsp_stream_experiment.yml @@ -38,7 +38,6 @@ marl_ddqn: &marl_ddqn template: *template mods: - 'agent/dqn/ddqn.yml' - - 'util/agent/centralized.yml' marl_dueling_ddqn: &marl_dueling_ddqn base_path: 'configuration/mods/machine/dqn.yml' @@ -46,7 +45,6 @@ marl_dueling_ddqn: &marl_dueling_ddqn mods: - 'agent/dqn/ddqn.yml' - 'agent/dqn/dueling.yml' - - 'util/agent/centralized.yml' marl_dueling_ddqn_pr: &marl_dueling_ddqn_pr base_path: 'configuration/mods/machine/dqn.yml' @@ -55,7 +53,6 @@ marl_dueling_ddqn_pr: &marl_dueling_ddqn_pr - 'agent/dqn/ddqn.yml' - 'agent/dqn/dueling.yml' - 'agent/dqn/prioritized.yml' - - 'util/agent/centralized.yml' marl_dueling_ddqn_n_step: &marl_dueling_ddqn_n_step base_path: 'configuration/mods/machine/dqn.yml' @@ -63,7 +60,6 @@ marl_dueling_ddqn_n_step: &marl_dueling_ddqn_n_step mods: - 'agent/dqn/ddqn.yml' - 'agent/dqn/dueling.yml' - - 'util/agent/centralized.yml' - 'agent/dqn/3_step.yml' reinforce: &reinforce @@ -172,6 +168,7 @@ task: kind: 'multi_task' n_workers: 10 debug: False + store_run_statistics: False output_dir: 'results/jsp/marl_direct' tasks: @@ -586,7 +583,6 @@ task: machine_agent: parameters: - *dueling_ddqn_pr_n_step - - *marl_dueling_ddqn_n_step - *reinforce - *ppo @@ -855,7 +851,6 @@ task: - *dueling_ddqn_pr - *reinforce - *ppo - - *marl_dueling_ddqn_n_step - *dueling_ddqn_pr_n_step tape: diff --git a/diploma_thesis/configuration/mods/machine/templates/auxiliary/model.yml b/diploma_thesis/configuration/mods/machine/templates/auxiliary/model.yml index c31e157..ff2ffd1 100644 --- a/diploma_thesis/configuration/mods/machine/templates/auxiliary/model.yml +++ b/diploma_thesis/configuration/mods/machine/templates/auxiliary/model.yml @@ -4,4 +4,4 @@ graph: parameters: kind: 'SageConv' parameters: - out_channels: 64 \ No newline at end of file + out_channels: 32 diff --git a/diploma_thesis/configuration/mods/run/mods/timeline/no_warmup.yml b/diploma_thesis/configuration/mods/run/mods/timeline/no_warmup.yml index b6aca4d..8e146ed 100644 --- a/diploma_thesis/configuration/mods/run/mods/timeline/no_warmup.yml +++ b/diploma_thesis/configuration/mods/run/mods/timeline/no_warmup.yml @@ -2,4 +2,9 @@ parameters: timeline: - warmup: [] + warmup: + - 0 + - 0 + - 0 + - 0 + - 0 diff --git a/diploma_thesis/configuration/simulation.yml b/diploma_thesis/configuration/simulation.yml index 0874b0a..331b888 100644 --- a/diploma_thesis/configuration/simulation.yml +++ b/diploma_thesis/configuration/simulation.yml @@ -5,16 +5,17 @@ task: output_dir: 'tmp' log_stdout: True log_run: True + store_run_statistics: False machine_agent: kind: 'mod' parameters: - base_path: 'configuration/mods/machine/dqn.yml' + base_path: 'configuration/mods/machine/ppo.yml' template: 'marl_direct' mods: - - 'agent/dqn/ddqn.yml' - - 'agent/dqn/dueling.yml' - - 'util/agent/multi_agent.yml' +# - 'agent/dqn/ddqn.yml' +# - 'agent/dqn/dueling.yml' +# - 'util/agent/multi_agent.yml' - 'util/rules/all_rules.yml' work_center_agent: @@ -30,8 +31,6 @@ task: tape: machine_reward: kind: 'surrogate_tardiness' - parameters: - span: 256 work_center_reward: kind: 'no' @@ -53,7 +52,7 @@ task: is_work_center_set_in_shop_floor_connected: True simulator: - kind: 'episodic' + kind: 'td' parameters: memory: 5 diff --git a/diploma_thesis/simulator/graph/util/encoder.py b/diploma_thesis/simulator/graph/util/encoder.py index 859e08f..08447d9 100644 --- a/diploma_thesis/simulator/graph/util/encoder.py +++ b/diploma_thesis/simulator/graph/util/encoder.py @@ -51,13 +51,13 @@ def __construct_initial_graph__(self): t = torch.tensor([], dtype=torch.long) # Nodes - result.data[Graph.OPERATION_KEY].x = t.view(1, 0) - result.data[Graph.GROUP_KEY].x = t.view(1, 0) - result.data[Graph.MACHINE_KEY].x = t.view(1, 0) - result.data[Graph.WORK_CENTER_KEY].x = t.view(1, 0) + result.data[Graph.OPERATION_KEY].x = t.to(torch.float32).view(0, 1) + result.data[Graph.GROUP_KEY].x = t.to(torch.float32).view(0, 1) + result.data[Graph.MACHINE_KEY].x = t.to(torch.float32).view(0, 1) + result.data[Graph.WORK_CENTER_KEY].x = t.to(torch.float32).view(0, 1) # Indices - result.data[Graph.JOB_INDEX_KEY] = t.view(4, 0) + result.data[Graph.JOB_INDEX_KEY] = t.view(0, 4) def __init_rel__(key): result.data[*key].edge_index = t.view(2, 0) @@ -87,7 +87,7 @@ def __update_job_index__(self, result: Graph, source: Graph, job_operation_map: n_all_ops = 0 - result.data[Graph.JOB_INDEX_KEY] = torch.tensor([], dtype=torch.long).view(4, 0) + result.data[Graph.JOB_INDEX_KEY] = torch.tensor([], dtype=torch.long).view(0, 4) for job_id in job_ids: if job_id not in job_operation_map: @@ -110,9 +110,9 @@ def __update_job_index__(self, result: Graph, source: Graph, job_operation_map: ] ) - result.data[Graph.JOB_INDEX_KEY] = torch.cat([result.data[Graph.JOB_INDEX_KEY], index], dim=1) + result.data[Graph.JOB_INDEX_KEY] = torch.cat([result.data[Graph.JOB_INDEX_KEY], index.T], dim=0) - result.data[Graph.OPERATION_KEY].x = torch.zeros(n_all_ops, dtype=torch.long).view(1, -1) + result.data[Graph.OPERATION_KEY].x = torch.zeros(n_all_ops, dtype=torch.float32).view(-1, 1) def __update_forward_graph__( self, result: Graph, source: Graph, shop_floor: ShopFloor, job_operation_map: JOB_OPERATION_MAP_TYPE @@ -147,7 +147,7 @@ def __update_forward_graph__( n_steps = len(job.step_idx) forward_graph = source.data[Graph.JOB_KEY][job_id][Graph.FORWARD_GRAPH_KEY] - local_operation_to_global_map = torch.arange(min_operation_id, max_operation_id + 1) + local_operation_to_global_map = torch.arange(min_operation_id, max_operation_id + 1, dtype=torch.long) if has_groups: # Group nodes are encoded as -step_id -1, which means that they access items in the end of the list @@ -170,12 +170,12 @@ def __update_forward_graph__( s = result.data[Graph.OPERATION_KEY, Graph.FORWARD_RELATION_KEY, Graph.OPERATION_KEY] s.edge_index = torch.cat([s.edge_index, local_operation_to_global_map[forward_graph]], dim=1) - result.data[Graph.GROUP_KEY].x = torch.zeros(n_groups, dtype=torch.long).view(1, -1) + result.data[Graph.GROUP_KEY].x = torch.zeros(n_groups, dtype=torch.float32).view(1, -1) def __append_machine_nodes__(self, result: Graph, source: Graph, shop_floor: ShopFloor): - result.data[Graph.MACHINE_INDEX_KEY] = source.data[Graph.MACHINE_INDEX_KEY] - result.data[Graph.MACHINE_KEY].x = torch.zeros(len(shop_floor.machines), dtype=torch.long).view(1, -1) - result.data[Graph.WORK_CENTER_KEY].x = torch.zeros(len(shop_floor.work_centers), dtype=torch.long).view(1, -1) + result.data[Graph.MACHINE_INDEX_KEY] = source.data[Graph.MACHINE_INDEX_KEY].view(-1, 2) + result.data[Graph.MACHINE_KEY].x = torch.zeros(len(shop_floor.machines), dtype=torch.float32).view(-1, 1) + result.data[Graph.WORK_CENTER_KEY].x = torch.zeros(len(shop_floor.work_centers), dtype=torch.float32).view(-1, 1) if self.configuration.is_machine_set_in_work_center_connected: for work_center in shop_floor.work_centers: @@ -184,10 +184,10 @@ def __append_machine_nodes__(self, result: Graph, source: Graph, shop_floor: Sho torch.arange(len(work_center.machines)) ]) - machine_node_id = self.__get_node_ids__(machine_index, result.data[Graph.MACHINE_INDEX_KEY]) + machine_node_id = self.__get_node_ids__(machine_index, result.data[Graph.MACHINE_INDEX_KEY].T) edges = itertools.combinations(machine_node_id, 2) - edges = torch.tensor(list(edges)).view(-1, 2).T + edges = torch.tensor(list(edges), dtype=torch.long).view(-1, 2).T s = result.data[Graph.MACHINE_KEY, Graph.IN_WORK_CENTER_RELATION_KEY, Graph.MACHINE_KEY] s.edge_index = torch.cat([s.edge_index, edges], dim=1) @@ -200,7 +200,7 @@ def __append_machine_nodes__(self, result: Graph, source: Graph, shop_floor: Sho if self.configuration.is_work_center_set_in_shop_floor_connected: edges = itertools.combinations(range(len(shop_floor.work_centers)), 2) - edges = torch.tensor(list(edges)).T + edges = torch.tensor(list(edges), dtype=torch.long).T s = result.data[Graph.WORK_CENTER_KEY, Graph.IN_SHOP_FLOOR_RELATION_KEY, Graph.WORK_CENTER_KEY] s.edge_index = edges @@ -208,7 +208,7 @@ def __append_machine_nodes__(self, result: Graph, source: Graph, shop_floor: Sho def __update_schedule_graphs__(self, result: Graph, source: Graph): self.__reset_schedule_graph__(result) - index = result.data[Graph.JOB_INDEX_KEY][[0, 1], :] + index = result.data[Graph.JOB_INDEX_KEY][:, [0, 1]].view(2, -1) for machine_id in range(len(result.data[Graph.MACHINE_KEY])): has_graph = Graph.SCHEDULED_GRAPH_KEY in source.data[Graph.MACHINE_KEY][machine_id] diff --git a/diploma_thesis/simulator/tape/queue/machine_queue.py b/diploma_thesis/simulator/tape/queue/machine_queue.py index 9063e2f..ff85d72 100644 --- a/diploma_thesis/simulator/tape/queue/machine_queue.py +++ b/diploma_thesis/simulator/tape/queue/machine_queue.py @@ -32,23 +32,19 @@ def register(self, context: Context, machine: Machine, record: MachineModel.Reco if record.result is None: mode = NextStateRecordMode.on_next_action - self.__record_next_state_on_action__(record.record.state, machine.key) + self.__record_next_state_on_action__(context, record.record.state, machine.key) self.__append_to_queue__(context, machine, record, mode) def did_produce(self, context: Context, machine: Machine, job: Job): record = self.queue[machine.key][-1] record.record.reward = self.reward.reward_after_production(record.context) - if record.record.reward is not None: - self.__emit_rewards__(context, machine.work_center_idx, machine.machine_idx) - return + if record.mode == NextStateRecordMode.on_produce: + state = self.simulator.encode_machine_state(context=context, machine=machine) - if record.mode != NextStateRecordMode.on_produce: - return + record.record.next_state = state - state = self.simulator.encode_machine_state(context=context, machine=machine) - - record.record.next_state = state + self.__emit_rewards__(context, machine.work_center_idx, machine.machine_idx) def did_complete(self, context: Context, job: Job): records = self.__fetch_records_from_job_path__(context, job) @@ -69,7 +65,7 @@ def did_complete(self, context: Context, job: Job): # Utility - def __record_next_state_on_action__(self, state, machine_key): + def __record_next_state_on_action__(self, context, state, machine_key): if len(self.queue[machine_key]) == 0: return @@ -80,6 +76,8 @@ def __record_next_state_on_action__(self, state, machine_key): record.record.next_state = state + self.__emit_rewards__(context, machine_key.work_center_id, machine_key.machine_id) + def __append_to_queue__( self, context: Context, machine: Machine, record: MachineModel.Record, mode: NextStateRecordMode ): diff --git a/diploma_thesis/simulator/td.py b/diploma_thesis/simulator/td.py index 0ab4fb0..5d9e3bd 100644 --- a/diploma_thesis/simulator/td.py +++ b/diploma_thesis/simulator/td.py @@ -38,9 +38,9 @@ def __store_or_forward_td__(self, context: Context, queue: Queue, agent, key, re if queue.group_len(context.shop_floor.id, key) > self.memory: # Pass a copy of records to avoid modification of the original - records = queue.pop_group(context.shop_floor.id, key) - records = torch.cat([record.view(-1) for record in records]) - records: List[Record] = list(records.clone().unbind(dim=0)) + original_records = queue.pop_group(context.shop_floor.id, key) + records = torch.cat([record.view(-1) for record in original_records]).clone() + records: List[Record] = list(records.unbind(dim=0)) if self.send_as_trajectory: agent.store(key, Trajectory(episode_id=self.episode, records=records)) @@ -49,7 +49,7 @@ def __store_or_forward_td__(self, context: Context, queue: Queue, agent, key, re else: agent.store(key, Slice(episode_id=context.shop_floor.id, records=records)) - queue.store_group(context.shop_floor.id, key, records[1:]) + queue.store_group(context.shop_floor.id, key, original_records[1:]) return diff --git a/diploma_thesis/workflow/multi_simulation.py b/diploma_thesis/workflow/multi_simulation.py index d3a706b..d428d3f 100644 --- a/diploma_thesis/workflow/multi_simulation.py +++ b/diploma_thesis/workflow/multi_simulation.py @@ -27,8 +27,11 @@ def workflow_id(self) -> str: def run(self): parameters = self.__fetch_tasks__() - parameters = self.__add_debug_info__(parameters) - parameters = self.__append_output_dir__(parameters) + parameters = self.__passthrough_parameters__(dict( + debug=False, + output_dir='', + store_run_statistics=False + ), parameters) parameters = self.__fix_names__(parameters) print(f'Running {len(parameters)} simulations') @@ -53,23 +56,15 @@ def __fetch_tasks__(self): return result - def __add_debug_info__(self, simulations: [Dict]): + def __passthrough_parameters__(self, values, simulations: [Dict]): result = simulations - if self.parameters.get('debug', False): - for index, _ in enumerate(result): - result[index]['debug'] = True + for key, default in values.items(): + value = self.parameters.get(key, default) - return result - - def __append_output_dir__(self, simulations: [Dict]): - result = simulations - - output_dir = self.parameters.get('output_dir') - - if output_dir: - for index, _ in enumerate(result): - result[index]['output_dir'] = os.path.join(output_dir, result[index]['output_dir']) + if value is not None: + for index, _ in enumerate(result): + result[index][key] = value return result diff --git a/diploma_thesis/workflow/simulation.py b/diploma_thesis/workflow/simulation.py index 3ddbf6a..c114930 100644 --- a/diploma_thesis/workflow/simulation.py +++ b/diploma_thesis/workflow/simulation.py @@ -32,6 +32,10 @@ def workflow_id(self) -> str: def log_stdout(self): return self.parameters.get('log_stdout', False) + @property + def store_run_statistics(self): + return self.parameters.get('store_run_statistics', False) + def run_log_file(self, output_dir): if not self.parameters.get('log_run', False): return None @@ -76,7 +80,10 @@ def __run__(self, simulator: Simulator, output_dir: str): if not self.is_debug: reward_cache = simulator.train(environment, config) - self.__store_simulations__(config.simulations, reward_cache, simulation_output_dir) + self.__store_simulations__(config.simulations, + reward_cache, + self.store_run_statistics, + simulation_output_dir) agent_output_dir = os.path.join(output_dir, 'agent') @@ -109,7 +116,10 @@ def __evaluate__(self, simulator: Simulator, output_dir: str): if not self.is_debug: simulator.evaluate(environment, config) - self.__store_simulations__(config.simulations, reward_cache=None, output_dir=simulation_output_dir) + self.__store_simulations__(config.simulations, + reward_cache=None, + store_run_statistics=self.store_run_statistics, + output_dir=simulation_output_dir) def __make_simulator__(self): machine = machine_from_cli(parameters=self.parameters['machine_agent']) @@ -127,16 +137,24 @@ def __make_simulator__(self): return simulator @staticmethod - def __store_simulations__(simulations: List[simulator.Simulation], reward_cache: RewardCache, output_dir: str): + def __store_simulations__(simulations: List[simulator.Simulation], + reward_cache: RewardCache, + store_run_statistics: bool, + output_dir: str): reward_cache = RewardCache(batch_size=[]) if reward_cache is None else reward_cache machine_reward, work_center_reward = Simulation.__process_reward_cache__(reward_cache) for simulation in simulations: path = os.path.join(output_dir, simulation.simulation_id) + if not os.path.exists(path): + os.makedirs(path) + if shop_floor := simulation.shop_floor: statistics = shop_floor.statistics - statistics.save(path) + + if store_run_statistics: + statistics.save(path) sh_id = shop_floor.id.item()