WIP: Implement graph neural networks; fix replay buffer storage
yura-hb committed Mar 8, 2024
1 parent 3ee010d commit 504b3b4
Showing 24 changed files with 445 additions and 220 deletions.
20 changes: 13 additions & 7 deletions diploma_thesis/agents/machine/state/auxiliary_graph_encoder.py
@@ -1,18 +1,21 @@
import torch
from tensordict.prototype import tensorclass

from agents.base.state import GraphState, Graph

from tensordict.prototype import tensorclass
from .encoder import *


class AuxiliaryGraphEncoder(StateEncoder):
class AuxiliaryGraphEncoder(GraphStateEncoder):

def __init__(self, is_homogeneous: bool):
super().__init__()

self.is_homogeneous = is_homogeneous

@tensorclass
class State(GraphState):
pass

def encode(self, parameters: StateEncoder.Input) -> State:
def __encode__(self, parameters: StateEncoder.Input) -> State:
if parameters.graph is None:
raise ValueError("Graph is not provided")

@@ -28,8 +31,11 @@ def encode(self, parameters: StateEncoder.Input) -> State:
processing_times = torch.cat(processing_times, dim=0).view(-1, 1)
graph.data[Graph.OPERATION_KEY].x = processing_times

return self.State(parameters.graph, batch_size=[])
if self.is_homogeneous:
graph.data = graph.data.to_homogeneous(node_attrs=['x'])

return self.State(graph, batch_size=[])

@staticmethod
def from_cli(parameters: dict):
return AuxiliaryGraphEncoder()
return AuxiliaryGraphEncoder(parameters.get('is_homogeneous', False))
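For context, a minimal sketch of what the new is_homogeneous flag does to a heterogeneous graph before encoding; the node/edge type names and feature shapes below are illustrative assumptions, not the project's actual schema:

import torch
from torch_geometric.data import HeteroData

# Hypothetical heterogeneous graph with a single 'operation' node type.
data = HeteroData()
data['operation'].x = torch.rand(4, 1)  # e.g. one processing time per operation
data['operation', 'precedes', 'operation'].edge_index = torch.tensor([[0, 1, 2],
                                                                      [1, 2, 3]])

# What is_homogeneous=True triggers on graph.data inside __encode__.
homogeneous = data.to_homogeneous(node_attrs=['x'])
print(homogeneous)  # Data(x=[4, 1], edge_index=[2, 3], node_type=[4], edge_type=[3])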
18 changes: 17 additions & 1 deletion diploma_thesis/agents/machine/state/encoder.py
@@ -1,8 +1,10 @@
from abc import ABCMeta
from abc import ABCMeta, abstractmethod
from typing import List, TypeVar

import torch

from torch_geometric.data import Batch

from agents.base import Encoder
from agents.machine import MachineInput

@@ -16,3 +18,17 @@ class StateEncoder(Encoder[MachineInput, State], metaclass=ABCMeta):
@staticmethod
def __to_list_of_tensors__(parameters: List) -> List[torch.FloatTensor]:
return [parameter if torch.is_tensor(parameter) else torch.tensor(parameter) for parameter in parameters]


class GraphStateEncoder(StateEncoder, metaclass=ABCMeta):

def encode(self, parameters: StateEncoder.Input) -> State:
result = self.__encode__(parameters)

result.graph.data = Batch.from_data_list([result.graph.data])

return result

@abstractmethod
def __encode__(self, parameters: StateEncoder.Input) -> State:
pass
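A minimal sketch of the batching step that GraphStateEncoder.encode now applies to every encoded state; the single-graph input is an illustrative assumption:

import torch
from torch_geometric.data import Data, Batch

graph = Data(x=torch.rand(4, 1),
             edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))

# Equivalent of result.graph.data = Batch.from_data_list([result.graph.data]):
# the single graph is wrapped as a batch of size one.
batched = Batch.from_data_list([graph])
print(batched.num_graphs)  # 1
print(batched.batch)       # tensor([0, 0, 0, 0]), the graph index of each node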
61 changes: 48 additions & 13 deletions diploma_thesis/agents/utils/memory/memory.py
@@ -1,16 +1,27 @@

from abc import ABCMeta, abstractmethod
from typing import TypeVar
from dataclasses import field
from typing import TypeVar, List, Generic, Dict

import torch
from tensordict.prototype import tensorclass

from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer
from tensordict.prototype import tensorclass
from torchrl.data import ReplayBuffer
from dataclasses import dataclass
from torchrl.data import LazyMemmapStorage, ListStorage

State = TypeVar('State')
Action = TypeVar('Action')
Configuration = TypeVar('Configuration')
_Configuration = TypeVar('_Configuration')


@dataclass
class Configuration:
size: int
is_tensordict_storage: bool
batch_size: int
prefetch: int


@tensorclass
@@ -42,20 +53,33 @@ class NotReadyException(BaseException):
pass


class Memory(metaclass=ABCMeta):
class Memory(Generic[_Configuration], metaclass=ABCMeta):

def __init__(self, configuration: Configuration):
def __init__(self, configuration: _Configuration):
self.configuration = configuration
self.buffer: TensorDictReplayBuffer = self.__make_buffer__()
self.buffer: ReplayBuffer = self.__make_buffer__()

def store(self, records: List[Record]):
if self.configuration.is_tensordict_storage:
records = torch.cat(records, dim=0)

def store(self, record: Record):
self.buffer.extend(record)
self.buffer.extend(records)
else:
self.buffer.extend(records)

def sample(self, return_info: bool = False) -> Record:
if len(self.buffer) < self.buffer._batch_size:
def sample(self, return_info: bool = False) -> List[Record]:
if len(self.buffer) < self.configuration.batch_size:
raise NotReadyException()

return self.buffer.sample(return_info=return_info)
result = self.buffer.sample(return_info=return_info)

if self.configuration.is_tensordict_storage:
if not isinstance(result, list):
result = [result]

return list(record.unbind(dim=0) for record in result)

return result

def sample_n(self, batch_size: int) -> Record:
return self.buffer.sample(batch_size=batch_size)
@@ -64,7 +88,7 @@ def update_priority(self, indices: torch.LongTensor, priorities: torch.Tensor):
self.buffer.update_priority(indices, priorities)

@abstractmethod
def __make_buffer__(self) -> TensorDictReplayBuffer:
def __make_buffer__(self) -> ReplayBuffer:
pass

def clear(self):
@@ -73,6 +97,17 @@ def clear(self):
def __len__(self) -> int:
return len(self.buffer)

def __make_result_buffer__(self, params: Dict, regular_cls, tensordict_cls):
if self.configuration.is_tensordict_storage:
params['storage'] = LazyMemmapStorage(max_size=self.configuration.size)
cls = tensordict_cls
else:
params['storage'] = ListStorage(max_size=self.configuration.size)
params['collate_fn'] = lambda x: x
cls = regular_cls

return cls(**params)

# TorchRL buffer isn't yet pickable. Hence, we recreate it from the configuration

def __getstate__(self):
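A minimal sketch of the storage selection that the new __make_result_buffer__ helper performs, written as a free function for illustration; the sizes below are placeholder values:

from torchrl.data import ReplayBuffer, TensorDictReplayBuffer
from torchrl.data import LazyMemmapStorage, ListStorage

def make_buffer(is_tensordict_storage: bool, size: int = 128, batch_size: int = 2) -> ReplayBuffer:
    params = dict(batch_size=batch_size)

    if is_tensordict_storage:
        # Tensordict path: memory-mapped storage and a tensordict-aware buffer.
        params['storage'] = LazyMemmapStorage(max_size=size)
        cls = TensorDictReplayBuffer
    else:
        # Plain path: list storage with an identity collate_fn, so records are kept as-is.
        params['storage'] = ListStorage(max_size=size)
        params['collate_fn'] = lambda x: x
        cls = ReplayBuffer

    return cls(**params)

buffer = make_buffer(is_tensordict_storage=False)
buffer.extend([{'reward': 1.0}, {'reward': 0.5}])  # arbitrary records thanks to the identity collate
print(buffer.sample())                             # a list of two records, returned unchanged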
53 changes: 27 additions & 26 deletions diploma_thesis/agents/utils/memory/prioritized_replay_memory.py
@@ -2,43 +2,44 @@
from dataclasses import dataclass
from typing import Dict

from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer
from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, PrioritizedReplayBuffer, ListStorage

from .memory import *


class PrioritizedReplayMemory(Memory):
@dataclass
class Configuration:
alpha: float
beta: float
size: int
batch_size: int
prefetch: int = 1

@staticmethod
def from_cli(parameters: Dict):
return PrioritizedReplayMemory.Configuration(
alpha=parameters.get('alpha', 0.6),
beta=parameters.get('beta', 0.4),
size=parameters['size'],
batch_size=parameters['batch_size'],
prefetch=parameters.get('prefetch', 1)
)

def __make_buffer__(self) -> TensorDictPrioritizedReplayBuffer:
storage = LazyMemmapStorage(max_size=self.configuration.size)

return TensorDictPrioritizedReplayBuffer(
storage=storage,
@dataclass
class Configuration(Configuration):
alpha: float
beta: float

@staticmethod
def from_cli(parameters: Dict):
return Configuration(
alpha=parameters.get('alpha', 0.6),
beta=parameters.get('beta', 0.4),
size=parameters['size'],
is_tensordict_storage=parameters.get('is_tensordict_storage', False),
batch_size=parameters['batch_size'],
prefetch=parameters.get('prefetch', 1)
)


class PrioritizedReplayMemory(Memory[Configuration]):

def __make_buffer__(self) -> PrioritizedReplayBuffer | TensorDictPrioritizedReplayBuffer:
cls = None

params = dict(
batch_size=self.configuration.batch_size,
prefetch=self.configuration.prefetch,
alpha=self.configuration.alpha,
beta=self.configuration.beta
)

return self.__make_result_buffer__(params, PrioritizedReplayBuffer, TensorDictPrioritizedReplayBuffer)

@staticmethod
def from_cli(parameters: Dict) -> 'PrioritizedReplayMemory':
configuration = PrioritizedReplayMemory.Configuration.from_cli(parameters)
configuration = Configuration.from_cli(parameters)

return PrioritizedReplayMemory(configuration)
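An illustrative CLI-style configuration for the reworked PrioritizedReplayMemory; the import path and the concrete values are assumptions for the example, only the keys follow from from_cli above:

from agents.utils.memory.prioritized_replay_memory import PrioritizedReplayMemory  # hypothetical import path

memory = PrioritizedReplayMemory.from_cli({
    'alpha': 0.6,                    # priority exponent (matches the default above)
    'beta': 0.4,                     # importance-sampling exponent
    'size': 10_000,
    'batch_size': 64,
    'is_tensordict_storage': False,  # plain PrioritizedReplayBuffer over ListStorage
    'prefetch': 1,
})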
53 changes: 27 additions & 26 deletions diploma_thesis/agents/utils/memory/replay_memory.py
@@ -1,42 +1,43 @@

from dataclasses import dataclass

from torchrl.data import LazyMemmapStorage
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer, ListStorage

from .memory import *
from .sampler import *

from .memory import Configuration as MemoryConfiguration


class ReplayMemory(Memory):
@dataclass
class Configuration(MemoryConfiguration):
sampler: Sampler

@classmethod
def from_cli(cls, parameters: Dict):
return cls(
size=parameters['size'],
batch_size=parameters['batch_size'],
prefetch=parameters.get('prefetch', 0),
is_tensordict_storage=parameters.get('is_tensordict_storage', False),
sampler=Sampler.from_cli(parameters['sampler']) if 'sampler' in parameters else None
)

@dataclass
class Configuration:
size: int
batch_size: int
prefetch: int
sampler: Sampler

@staticmethod
def from_cli(parameters: Dict):
return ReplayMemory.Configuration(
size=parameters['size'],
batch_size=parameters['batch_size'],
prefetch=parameters.get('prefetch', 0),
sampler=Sampler.from_cli(parameters['sampler']) if 'sampler' in parameters else None
)
class ReplayMemory(Memory[Configuration]):

def __make_buffer__(self) -> TensorDictReplayBuffer:
storage = LazyMemmapStorage(max_size=self.configuration.size)
def __make_buffer__(self) -> ReplayBuffer | TensorDictReplayBuffer:
sampler = self.configuration.sampler.make() if self.configuration.sampler else RandomSampler()
cls = None

params = dict(
batch_size=self.configuration.batch_size,
prefetch=self.configuration.prefetch,
sampler=sampler
)

return TensorDictReplayBuffer(storage=storage,
batch_size=self.configuration.batch_size,
sampler=sampler,
prefetch=self.configuration.prefetch)
return self.__make_result_buffer__(params, ReplayBuffer, TensorDictReplayBuffer)

@staticmethod
def from_cli(parameters: Dict) -> 'ReplayMemory':
configuration = ReplayMemory.Configuration.from_cli(parameters)
configuration = Configuration.from_cli(parameters)

return ReplayMemory(configuration)

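A hypothetical round trip through the reworked ReplayMemory on the plain (non-tensordict) storage path, showing the new list-based store/sample API; the import path and record payloads are assumptions:

from agents.utils.memory.replay_memory import ReplayMemory  # hypothetical import path

memory = ReplayMemory.from_cli({
    'size': 1_000,
    'batch_size': 2,
    'is_tensordict_storage': False,  # ListStorage with an identity collate_fn
})

# store() now takes a list of records; with list storage they are kept as-is.
memory.store([{'state': 0, 'reward': 1.0}, {'state': 1, 'reward': 0.5}])

# sample() raises NotReadyException until at least batch_size records are stored.
batch = memory.sample()
print(len(batch))  # 2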
26 changes: 2 additions & 24 deletions diploma_thesis/agents/utils/nn/layers/__init__.py
@@ -1,27 +1,5 @@

from .layer import Layer

from .linear import Linear
from .common import Flatten, InstanceNorm, LayerNorm
from .activation import Activation
from .merge import Merge
from .graph import GraphLayer
from .partial_instance_norm_1d import PartialInstanceNorm1d

from utils import from_cli
from functools import partial

key_to_class = {
'linear': Linear,
'flatten': Flatten,
'activation': Activation,
'layer_norm': LayerNorm,
'instance_norm': InstanceNorm,
'partial_instance_norm': PartialInstanceNorm1d,
'noisy_linear': ...,
'graph': GraphLayer,
'merge': merge
}

from_cli = partial(from_cli, key_to_class=key_to_class)

from .linear import Linear
from .cli import from_cli
36 changes: 36 additions & 0 deletions diploma_thesis/agents/utils/nn/layers/cli.py
@@ -0,0 +1,36 @@

def from_cli(*args, **kwargs):
from .linear import Linear
from .common import Flatten, InstanceNorm, LayerNorm
from .activation import Activation
from .merge import Merge
from .graph_model import GraphModel
from .graph_layer import common_graph_layer, common_operation
from .partial_instance_norm_1d import PartialInstanceNorm1d

from utils import from_cli as from_cli_

import torch_geometric as pyg

key_to_class = {
'flatten': Flatten,
'merge': Merge,
'activation': Activation,

'linear': Linear,
'layer_norm': LayerNorm,
'instance_norm': InstanceNorm,
'partial_instance_norm': PartialInstanceNorm1d,
'noisy_linear': ...,

'graph_model': GraphModel,
'gin': common_graph_layer(pyg.nn.GINConv),
'sage': common_graph_layer(pyg.nn.SAGEConv),
'gat': common_graph_layer(pyg.nn.GATConv),
'gcn': common_graph_layer(pyg.nn.GCNConv),
'deep_gcn': common_graph_layer(pyg.nn.DeepGCNLayer),

'add_pool': common_operation(pyg.nn.global_add_pool),
}

return from_cli_(*args, **kwargs, key_to_class=key_to_class)
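A minimal sketch of the PyG building blocks that the new registry keys resolve to; the tensor sizes and the direct wiring are illustrative and bypass the project's from_cli layer:

import torch
import torch_geometric as pyg

x = torch.rand(6, 16)                                    # 6 nodes with 16 features each
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 5]])
batch = torch.zeros(6, dtype=torch.long)                 # all nodes belong to graph 0

conv = pyg.nn.SAGEConv(in_channels=16, out_channels=32)  # what the 'sage' key wraps
h = conv(x, edge_index)

pooled = pyg.nn.global_add_pool(h, batch)                # what the 'add_pool' key wraps
print(pooled.shape)                                      # torch.Size([1, 32])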