Skip to content

Commit

Permalink
Implement fixes for new neural network architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
yura-hb committed Mar 21, 2024
1 parent 8010449 commit ac7c51e
Show file tree
Hide file tree
Showing 40 changed files with 750 additions and 721 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def __encode__(self, parameters: StateEncoder.Input) -> State:
if j == 0:
moment = parameters.machine.shop_floor.now
else:
moment = completions_times[j-1].max()
moment = completions_times[j-1]

completions_times[j] = moment + job.processing_times[j]
completions_times[j] = moment + job.processing_times[j].min()

status = torch.ones_like(job.step_idx)
status = self.__fill_job_matrix__(job, status)
Expand Down
3 changes: 3 additions & 0 deletions diploma_thesis/agents/utils/nn/layers/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ def from_cli(*args, **kwargs):
from .common import Flatten, InstanceNorm, LayerNorm
from .activation import Activation
from .merge import Merge
from .shared import Shared
from .graph_model import GraphModel
from .partial_instance_norm_1d import PartialInstanceNorm1d
from .output import Output
Expand Down Expand Up @@ -33,6 +34,8 @@ def from_cli(*args, **kwargs):
'max_pool': MaxPool,
'mean_pool': MeanPool,

'select_target': SelectTarget,
'shared': Shared,
'output': Output
}

Expand Down
29 changes: 21 additions & 8 deletions diploma_thesis/agents/utils/nn/layers/graph_layer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Dict

import torch.nn
import torch_geometric as pyg

from agents.base.state import Graph
from .layer import *


Expand All @@ -12,20 +14,16 @@ class GraphLayer(Layer):
class BaseWrapper(GraphLayer):

def __init__(self, configuration):
super().__init__()

self.configuration = configuration
self._signature = configuration.get('signature')

if 'signature' in configuration:
self._signature = configuration['signature']

del configuration['signature']
else:
self._signature = None

super().__init__(self._signature)

@property
def signature(self):
return self._signature or 'x -> x'
return self._signature

@classmethod
def from_cli(cls, parameters: Dict):
Expand Down Expand Up @@ -111,3 +109,18 @@ class MeanPool(GraphFunctionWrapper):

def __init__(self, configuration):
super().__init__(pyg.nn.global_mean_pool, configuration)


class SelectTarget(BaseWrapper):

def __init__(self, configuration):
super().__init__(configuration)

def forward(self, graph: Graph | pyg.data.Batch, embeddings: torch.FloatTensor):
storage = graph

if isinstance(graph, Graph):
storage = graph.data

return embeddings[storage[Graph.TARGET_KEY]]

20 changes: 6 additions & 14 deletions diploma_thesis/agents/utils/nn/layers/graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@ def from_cli(parameters: Dict) -> 'GraphModel.Configuration':
(from_cli(layer), layer.get('parameters', {}).get('signature'))
for layer in parameters['layers']
],
signature=parameters.get('signature'),
signature=parameters['signature'],
hetero_aggregation=parameters.get('hetero_aggregation', 'mean'),
hetero_aggregation_key=parameters.get('hetero_aggregation_key', 'operation')
)

def __init__(self, configuration: Configuration):
super().__init__()
super().__init__(configuration.signature)

self.is_configured = False
self.model: pyg.nn.Sequential = None
self.configuration = configuration

self.__build__()

def forward(self, graph: Graph | pyg.data.Batch) -> torch.Tensor:
def forward(self, graph: Graph | pyg.data.Batch) -> Tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
batch: pyg.data.Batch = None

if isinstance(graph, Graph):
Expand All @@ -52,18 +52,10 @@ def forward(self, graph: Graph | pyg.data.Batch) -> torch.Tensor:

self.__configure_if_needed__(batch)

if isinstance(batch, pyg.data.HeteroData):
return self.forward_heterogeneous(graph, batch)
# if isinstance(batch, pyg.data.HeteroData):
# return self.forward_heterogeneous(graph, batch), batch.batch_dict

return self.forward_homogeneous(graph, batch)

def forward_homogeneous(self, graph, batch):
hidden = self.model(batch.x, batch.edge_index, batch.batch)

if hidden.shape[0] == batch.num_graphs:
return hidden
else:
return hidden[batch.target]
return self.model(batch.x, batch.edge_index, batch.batch), batch.batch

def forward_heterogeneous(self, graph, batch):
assert False, "Heterogeneous input is not yet supported"
Expand Down
29 changes: 29 additions & 0 deletions diploma_thesis/agents/utils/nn/layers/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch_geometric as pyg

from .layer import Layer
from .cli import from_cli


class Shared(Layer):

def __init__(self, values: list[str], input_args, layers: [Layer]):
values = ', '.join(values)
signature = f'{values} -> {values}'

super().__init__(signature=signature)
#
self.model = pyg.nn.Sequential(
input_args,
[(layer, layer.signature) for layer in layers]
)

def forward(self, *args) -> tuple:
return tuple([self.model(arg) for arg in args])

@classmethod
def from_cli(cls, parameters: dict) -> 'Layer':
return Shared(
values=parameters['values'],
input_args=parameters['input_args'],
layers=[from_cli(layer) for layer in parameters['layers']]
)
4 changes: 4 additions & 0 deletions diploma_thesis/agents/utils/nn/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def from_cli(parameters: dict):
def __init__(self, configuration: Configuration):
super().__init__()

self.is_configured = False

self.configuration = configuration

self.__build__()

def forward(self, state):
output = self.__forward__(state)

self.is_configured = True

return output

def to_noisy(self, noise_parameters):
Expand Down
2 changes: 2 additions & 0 deletions diploma_thesis/agents/utils/nn/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def __make_optimizer__(self, parameters):
cls = torch.optim.SGD
case 'asgd':
cls = torch.optim.ASGD
case 'rmsprop':
cls = torch.optim.RMSprop
case _:
raise ValueError(f'Unknown optimizer kind: {self.configuration.optimizer.kind}')

Expand Down
7 changes: 2 additions & 5 deletions diploma_thesis/agents/utils/policy/action_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ def __init__(self,

def __post_init__(self):
if self.noise_parameters is not None:
self.actor.to_noisy(self.noise_parameters)

if self.critic is not None:
self.critic.to_noisy(self.noise_parameters)
self.model.to_noisy(self.noise_parameters)

def configure(self, configuration: RunConfiguration):
self.run_configuration = configuration
Expand Down Expand Up @@ -104,7 +101,7 @@ def __fetch_values_and_actions__(output: TensorDict):
actions = output[Keys.ACTIONS]
values = output.get(Keys.VALUE, actions)

return actions, values
return values, actions

@staticmethod
def base_parameters_from_cli(parameters: Dict):
Expand Down
10 changes: 2 additions & 8 deletions diploma_thesis/agents/utils/policy/flexible_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,16 @@ class FlexibleAction(ActionPolicy):
def __init__(self, base_parameters):
super().__init__(**base_parameters)

self.action_layer = self.make_linear_layer(1)

@property
def is_recurrent(self):
return False

def configure(self, configuration: RunConfiguration):
super().configure(configuration)

self.action_layer.to(configuration.device)

def post_encode(self, state: State, outputs):
values, actions = self.__fetch_values_and_actions__(outputs)

actions = self.action_layer(actions)

# Unpack node embeddings obtained from graph batch
if state.graph is not None and isinstance(state.graph, pyg.data.Batch):
result = []
Expand All @@ -46,9 +40,9 @@ def post_encode(self, state: State, outputs):
actions = torch.nn.utils.rnn.pad_sequence(result, batch_first=True, padding_value=torch.nan)
lengths = torch.tensor(lengths)

return values, self.__estimate_policy__(values, (actions, lengths))
return self.__estimate_policy__(values, (actions, lengths))

return values, self.__estimate_policy__(values, actions)
return self.__estimate_policy__(values, actions)

def __estimate_policy__(self, value, actions):
if isinstance(actions, tuple):
Expand Down
Loading

0 comments on commit ac7c51e

Please sign in to comment.