Implement fixes based on new graph storage
yura-hb committed Mar 10, 2024
1 parent 9169506 commit be25a6f
Showing 18 changed files with 134 additions and 89 deletions.
2 changes: 1 addition & 1 deletion diploma_thesis/agents/base/__init__.py
@@ -1,6 +1,6 @@

from .agent import Agent
-from .encoder import Encoder
+from .encoder import Encoder, GraphEncoder
from .model import Model
from .state import GraphState, TensorState, Graph
from .rl_agent import RLAgent
50 changes: 50 additions & 0 deletions diploma_thesis/agents/base/encoder.py
@@ -1,7 +1,11 @@
from abc import abstractmethod
from typing import TypeVar, Generic

+from torch_geometric.data import Batch
+from torch_geometric.transforms import ToUndirected

from utils import Loggable
+from .state import GraphState, Graph

Input = TypeVar('Input')
State = TypeVar('State')
@@ -12,3 +16,49 @@ class Encoder(Loggable, Generic[Input, State]):
@abstractmethod
def encode(self, parameters: Input) -> State:
pass


class GraphEncoder(Encoder, Generic[Input, State]):

    def __init__(self, is_homogeneous: bool = False, is_undirected: bool = False, is_local: bool = False):
super().__init__()

self.is_homogeneous = is_homogeneous
self.is_undirected = is_undirected
self.is_local = is_local
self.to_undirected = ToUndirected()

def encode(self, parameters: Input) -> State:
if self.is_local:
parameters.graph = self.__localize__(parameters, parameters.graph)

result = self.__encode__(parameters)

if self.is_homogeneous:
result.graph.data = result.graph.data.to_homogeneous(node_attrs=['x'])

del result.graph.data[Graph.JOB_INDEX_MAP]
del result.graph.data[Graph.MACHINE_INDEX_KEY]

if self.is_undirected:
result.graph.data = self.to_undirected(result.graph.data)

result.graph.data = Batch.from_data_list([result.graph.data])

return result

@abstractmethod
def __encode__(self, parameters: Input) -> State | GraphState:
pass

@abstractmethod
def __localize__(self, parameters: Input, graph: Graph):
pass

@classmethod
def base_parameters_from_cli(cls, parameters):
return dict(
is_homogeneous=parameters.get('is_homogeneous', False),
is_undirected=parameters.get('is_undirected', False),
is_local=parameters.get('is_local', False)
)
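The three flags drive a fixed post-processing pipeline in `encode`: optionally localize the graph, optionally collapse it to a homogeneous one, optionally add reverse edges, then wrap it in a single-graph batch. A minimal standalone sketch of that pipeline on a toy `HeteroData` object (the `'operation'` node and edge types here are made up for illustration):

```python
import torch
from torch_geometric.data import Batch, HeteroData
from torch_geometric.transforms import ToUndirected

data = HeteroData()
data['operation'].x = torch.rand(4, 1)
data['operation', 'precedes', 'operation'].edge_index = torch.tensor([[0, 1, 2],
                                                                      [1, 2, 3]])

data = data.to_homogeneous(node_attrs=['x'])  # what is_homogeneous triggers
data = ToUndirected()(data)                   # what is_undirected triggers
batch = Batch.from_data_list([data])          # encode() always wraps the result
print(batch)
```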
22 changes: 7 additions & 15 deletions diploma_thesis/agents/machine/state/auxiliary_graph_encoder.py
@@ -1,16 +1,14 @@

from typing import Dict

from tensordict.prototype import tensorclass

-from agents.base.state import GraphState, Graph
+from agents.base.state import GraphState
from .encoder import *


class AuxiliaryGraphEncoder(GraphStateEncoder):

-    def __init__(self, is_homogeneous: bool):
-        super().__init__()
-
-        self.is_homogeneous = is_homogeneous

@tensorclass
class State(GraphState):
pass
@@ -31,14 +29,8 @@ def __encode__(self, parameters: StateEncoder.Input) -> State:
processing_times = torch.cat(processing_times, dim=0).view(-1, 1)
graph.data[Graph.OPERATION_KEY].x = processing_times

-        if self.is_homogeneous:
-            graph.data = graph.data.to_homogeneous(node_attrs=['x'])
-
-            del graph.data[Graph.JOB_INDEX_MAP]
-            del graph.data[Graph.MACHINE_INDEX_KEY]

return self.State(graph, batch_size=[])

-    @staticmethod
-    def from_cli(parameters: dict):
-        return AuxiliaryGraphEncoder(parameters.get('is_homogeneous', False))
+    @classmethod
+    def from_cli(cls, parameters: dict):
+        return AuxiliaryGraphEncoder(**cls.base_parameters_from_cli(parameters))
21 changes: 11 additions & 10 deletions diploma_thesis/agents/machine/state/encoder.py
@@ -5,7 +5,8 @@

from torch_geometric.data import Batch

-from agents.base import Encoder
+from agents.base import Encoder, GraphEncoder, Graph
+from agents.base.encoder import Input
from agents.machine import MachineInput

State = TypeVar('State')
@@ -20,15 +21,15 @@ def __to_list_of_tensors__(parameters: List) -> List[torch.FloatTensor]:
return [parameter if torch.is_tensor(parameter) else torch.tensor(parameter) for parameter in parameters]


-class GraphStateEncoder(StateEncoder, metaclass=ABCMeta):
+class GraphStateEncoder(GraphEncoder, metaclass=ABCMeta):

-    def encode(self, parameters: StateEncoder.Input) -> State:
-        result = self.__encode__(parameters)
-
-        result.graph.data = Batch.from_data_list([result.graph.data])
-
-        return result
-
-    @abstractmethod
-    def __encode__(self, parameters: Input) -> State:
-        pass
+    def __localize__(self, parameters: StateEncoder.Input, graph: Graph):
+        job_ids = graph.data[Graph.JOB_INDEX_MAP][:, 0]
+        queued_job_ids = torch.cat([job.id.view(-1) for job in parameters.machine.queue])
+        mask = torch.isin(job_ids, queued_job_ids, assume_unique=True)
+        idx = torch.nonzero(mask).view(-1)
+
+        graph.data = graph.data.subgraph({Graph.OPERATION_KEY: idx})
+        graph.data[Graph.JOB_INDEX_MAP] = graph.data[Graph.JOB_INDEX_MAP][mask]
+
+        return graph
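For reference, a self-contained sketch of the masking idea behind `__localize__`, on a toy graph with plain tensors in place of the project's `Graph` wrapper; `job_index` is a hypothetical stand-in for `Graph.JOB_INDEX_MAP`, whose first column is assumed to hold each operation's job id:

```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data['operation'].x = torch.rand(5, 1)
data['operation', 'precedes', 'operation'].edge_index = torch.tensor([[0, 1, 3],
                                                                      [1, 2, 4]])

job_index = torch.tensor([[0], [0], [0], [1], [1]])  # operation -> job id
queued_job_ids = torch.tensor([0])                   # jobs still in the queue

mask = torch.isin(job_index[:, 0], queued_job_ids)
idx = torch.nonzero(mask).view(-1)

local = data.subgraph({'operation': idx})  # keep only the masked operations
print(local['operation'].num_nodes)        # 3
```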
11 changes: 7 additions & 4 deletions diploma_thesis/agents/utils/nn/layers/graph_layer.py
@@ -49,8 +49,8 @@ def __init__(self, configuration):
def signature(self):
return self._signature or 'x, edge_index -> x'

-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def forward(self, x, edge_index):
+        return self.model(x, edge_index)

return Wrapper

@@ -62,7 +62,10 @@ class Wrapper(BaseWrapper):
def signature(self):
return self._signature or 'x, batch -> x'

-    def forward(self, *args, **kwargs):
-        return fn(*args, **kwargs)
+    def forward(self, x, batch):
+        if isinstance(batch, tuple):
+            return fn(x, batch=batch[0])
+
+        return fn(x, batch=batch)

return Wrapper
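The explicit `(x, batch)` signature replaces the old `*args` pass-through, and the tuple guard presumably covers callers that forward the batch assignment bundled with extra values. The underlying pooling call behaves like this minimal sketch with `global_mean_pool`:

```python
import torch
from torch_geometric.nn import global_mean_pool

x = torch.rand(6, 8)                      # six node embeddings
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # node-to-graph assignment, two graphs
out = global_mean_pool(x, batch=batch)    # one pooled vector per graph
print(out.shape)                          # torch.Size([2, 8])
```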
42 changes: 12 additions & 30 deletions diploma_thesis/agents/utils/nn/layers/graph_model.py
@@ -10,32 +10,14 @@
import torch_geometric as pyg


-class OutputType(StrEnum):
-    NODE = 'node'
-    EDGE = 'edge'
-    GLOBAL = 'global'


class GraphModel(Layer):

@dataclass
class Configuration:

-        @dataclass
-        class Output:
-            node_key: str | None
-            kind: OutputType
-
-            @staticmethod
-            def from_cli(parameters: Dict) -> 'GraphModel.Configuration.Output':
-                return GraphModel.Configuration.Output(
-                    node_key=parameters.get('node_key'),
-                    kind=OutputType(parameters['kind'])
-                )

layers: List[Tuple[Layer, str | None]]

hetero_aggregation: str = 'mean'
-        output: Output = None
+        hetero_aggregation_key: str = 'operation'

@staticmethod
def from_cli(parameters: Dict) -> 'GraphModel.Configuration':
@@ -47,7 +29,7 @@ def from_cli(parameters: Dict) -> 'GraphModel.Configuration':
for layer in parameters['layers']
],
hetero_aggregation=parameters.get('hetero_aggregation', 'mean'),
-            output=GraphModel.Configuration.Output.from_cli(parameters.get('output', {}))
+            hetero_aggregation_key=parameters.get('hetero_aggregation_key', 'operation')
)

def __init__(self, configuration: Configuration):
@@ -66,7 +48,7 @@ def forward(self, batch: pyg.data.Batch) -> torch.Tensor:
# Result is the dict for each edge_type
hidden = self.model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)

-            return hidden
+            return self.__process_heterogeneous_output__(hidden)

hidden = self.model(batch.x, batch.edge_index, batch.batch)

@@ -93,15 +75,15 @@ def __configure_if_needed__(self, graph: pyg.data.Data | pyg.data.HeteroData):

self.is_configured = True

-    def __process_heterogeneous_output__(self, graph: pyg.data.Data, output: torch.Tensor):
-        match self.configuration.output.kind:
-            case OutputType.EDGE:
-                return output
-            case OutputType.GLOBAL:
-                pass
-            case OutputType.NODE:
-                pass
+    def __process_heterogeneous_output__(self, output: Dict[Tuple[str, str, str], torch.Tensor]) -> torch.Tensor:
+        result = []
+
+        for key, embeddings in output.items():
+            if key[0] == self.configuration.hetero_aggregation_key:
+                result += [embeddings]
+
+        # TODO: Use aggregation
+        return torch.stack(result).mean(dim=0)

@classmethod
def from_cli(cls, parameters: dict) -> 'Layer':
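A toy illustration of the new aggregation, assuming the model returns a dict keyed by `(source, relation, destination)` edge types with equally shaped embeddings; only relations rooted at the aggregation key (`'operation'` by default) are kept and averaged element-wise:

```python
import torch

output = {
    ('operation', 'precedes', 'operation'): torch.ones(4, 8),
    ('operation', 'on', 'machine'): torch.zeros(4, 8),
    ('machine', 'processes', 'operation'): torch.full((4, 8), 2.0),
}

selected = [emb for key, emb in output.items() if key[0] == 'operation']
result = torch.stack(selected).mean(dim=0)  # mean over the two selected relations
print(result.shape, result[0, 0].item())    # torch.Size([4, 8]) 0.5
```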
5 changes: 2 additions & 3 deletions diploma_thesis/agents/utils/policy/discrete_action.py
@@ -1,11 +1,11 @@
import copy
+from enum import StrEnum
from typing import Dict

from agents.utils import NeuralNetwork, Phase
-from agents.utils.nn.layers.linear import Linear
from agents.utils.action import ActionSelector, from_cli as action_selector_from_cli
+from agents.utils.nn.layers.linear import Linear
from .policy import *
-from enum import StrEnum


class PolicyEstimationMethod(StrEnum):
@@ -53,7 +53,6 @@ def __call__(self, state: State, parameters: Input) -> Record:
return Record(state, action, info, batch_size=[])

def predict(self, state: State):
-        values = torch.tensor(0, dtype=torch.long)
actions = torch.tensor(0, dtype=torch.long)

if self.action_model is not None:
4 changes: 2 additions & 2 deletions diploma_thesis/agents/utils/rl/dqn.py
@@ -26,7 +26,7 @@ def from_cli(parameters: Dict):
def __init__(self, configuration: Configuration, *args, **kwargs):
super().__init__(*args, is_episodic=False, **kwargs)

-        self._target_model: AveragedModel | None = None
+        self._target_models: AveragedModel | None = None
self.configuration = configuration

def configure(self, model: Policy):
@@ -60,7 +60,7 @@ def __train__(self, model: Policy):
with torch.no_grad():
td_error += self.configuration.prior_eps

-        self.memory.update_priority(info['index'], td_error)
+        self.storage.update_priority(info['index'], td_error)

def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
# Note:
3 changes: 2 additions & 1 deletion diploma_thesis/agents/utils/rl/rl.py
@@ -49,9 +49,10 @@ def __init__(self,

self.loss = loss
self.optimizer = optimizer
-        self._is_configured = False
self.train_schedule = train_schedule
self.return_estimator = return_estimator

+        self._is_configured = False
+        self.storage = Storage(is_episodic, memory, return_estimator)
self.loss_cache = []

13 changes: 8 additions & 5 deletions diploma_thesis/agents/utils/rl/storage.py
Expand Up @@ -41,6 +41,9 @@ def sample(self, update_returns: bool = True):

return self.__process_batched_data__(batch, update_returns), info

+    def update_priority(self, indices: torch.LongTensor, priorities: torch.FloatTensor):
+        self.memory.update_priority(indices, priorities)
+
def clear(self):
self.memory.clear()

@@ -75,8 +78,8 @@ def __process_episodic_batched_data__(self, batch, update_returns: bool):
if isinstance(batch[0][0].state, GraphState) and isinstance(batch[0][0].next_state, GraphState):
records = reduce(lambda x, y: x + y, batch)

-            result.state.graph = Graph(self.__collate_graphs__([record.state.graph for record in records]))
-            result.next_state.graph = Graph(self.__collate_graphs__([record.next_state.graph for record in records]))
+            result.state.graph = self.__collate_graphs__([record.state.graph for record in records])
+            result.next_state.graph = self.__collate_graphs__([record.next_state.graph for record in records])

return result

@@ -87,12 +90,12 @@ def __process_batched_data__(self, batch, update_returns: bool):
result = torch.cat([element.view(-1) for element in batch], dim=0)

if isinstance(batch[0].state, GraphState) and isinstance(batch[0].next_state, GraphState):
-            result.state.graph.data = self.__collate_graphs__([record.state.graph for record in batch])
-            result.next_state.graph.data = self.__collate_graphs__([record.next_state.graph for record in batch])
+            result.state.graph = self.__collate_graphs__([record.state.graph for record in batch])
+            result.next_state.graph = self.__collate_graphs__([record.next_state.graph for record in batch])

return result

def __collate_graphs__(self, records: List[Graph]):
data = reduce(lambda x, y: x + y, (record.data.to_data_list() for record in records))

-        return Batch.from_data_list(data)
+        return Graph(Batch.from_data_list(data))
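The collation is a round-trip through PyG batching: unpack every stored `Batch` back into its graphs, concatenate the lists, and re-batch the union. A minimal sketch on plain `Data` objects with arbitrary shapes:

```python
from functools import reduce

import torch
from torch_geometric.data import Batch, Data

batches = [
    Batch.from_data_list([Data(x=torch.rand(2, 4))]),
    Batch.from_data_list([Data(x=torch.rand(3, 4))]),
]

# Same shape as __collate_graphs__: flatten to a list of graphs, then re-batch.
data = reduce(lambda x, y: x + y, (batch.to_data_list() for batch in batches))
merged = Batch.from_data_list(data)
print(merged.num_graphs)  # 2
```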
3 changes: 2 additions & 1 deletion diploma_thesis/cli.py
@@ -1,13 +1,14 @@

import argparse
-import torch
from typing import Dict

+import torch._dynamo
import yaml

from workflow import Workflow, Simulation, Tournament, MultiSimulation

torch.set_num_threads(1)
+torch._dynamo.config.suppress_errors = True


def make_workflow(configuration: Dict) -> Workflow:
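Setting `torch._dynamo.config.suppress_errors = True` makes `torch.compile` fall back to eager execution instead of raising when graph capture fails. A minimal sketch, assuming PyTorch 2.x:

```python
import torch
import torch._dynamo

torch._dynamo.config.suppress_errors = True  # log and fall back instead of raising

@torch.compile
def double(x):
    return x * 2

print(double(torch.ones(2)))  # tensor([2., 2.])
```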
4 changes: 2 additions & 2 deletions diploma_thesis/configuration/jsp_stream_experiment.yml
@@ -556,7 +556,7 @@ task:
kind: 'td'
parameters:
memory: 32
-send_as_trajectory: True
+emit_trajectory: True
next_state_record_mode: 'on_next_action'

graph:
@@ -644,7 +644,7 @@ task:
kind: 'td'
parameters:
memory: 32
-send_as_trajectory: True
+emit_trajectory: True
next_state_record_mode: 'on_next_action'

graph:
@@ -2,4 +2,6 @@
encoder:
kind: 'auxiliary'
parameters:
-is_homogeneous: True
+is_homogeneous: False
+is_undirected: True
+is_local: True
@@ -14,10 +14,6 @@ graph:

- kind: 'mean_pool'

-output:
-  kind: 'node'
-  node_key: 'operation'
-
output:
- kind: 'linear'
parameters:
Expand Down