Implement compilation support and add new modifiers
yura-hb committed Mar 12, 2024
1 parent 5f019d7 commit 5c78a31
Showing 39 changed files with 743 additions and 136 deletions.
1 change: 0 additions & 1 deletion diploma_thesis/agents/base/agent.py
@@ -19,7 +19,6 @@ class TrainingSample:
     episode_id: int
     records: List[Record]
 
-
 @dataclass
 class Slice(TrainingSample):
     pass
16 changes: 15 additions & 1 deletion diploma_thesis/agents/base/model.py
@@ -2,11 +2,13 @@
 from abc import ABCMeta, abstractmethod
 from typing import TypeVar, Generic
 
+import torch
 from tensordict.prototype import tensorclass
 
 from agents.utils.policy import Policy, PolicyRecord
 from agents.utils import Phase, PhaseUpdatable
 from utils import Loggable
+from dataclasses import dataclass
 
 State = TypeVar('State')
 Input = TypeVar('Input')
@@ -29,10 +31,22 @@ def __call__(self, state: State, parameters: Input) -> Record:

 class DeepPolicyModel(Model[Input, State, Action, Result], PhaseUpdatable, metaclass=ABCMeta):
 
-    def __init__(self, policy: Policy[Input]):
+    @dataclass
+    class Configuration:
+        compile: bool = False
+
+        @staticmethod
+        def from_cli(parameters):
+            return DeepPolicyModel.Configuration(compile=parameters.get('compile', True))
+
+    def __init__(self, policy: Policy[Input], configuration: Configuration):
         super().__init__()
 
         self.policy = policy
+        self.configuration = configuration
+
+        if configuration.compile:
+            self.policy.compile()
 
     def update(self, phase: Phase):
         super().update(phase)
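The new Configuration gates an optional compile step on the policy at construction time. A minimal usage sketch, assuming only what the hunk above shows (note the asymmetric defaults: the dataclass field defaults to False, while from_cli falls back to True when the 'compile' key is missing):

    # Sketch: driving the new flag from a CLI-style parameter dict.
    parameters = {'compile': True}

    configuration = DeepPolicyModel.Configuration.from_cli(parameters)

    # DeepPolicyModel.__init__ then calls self.policy.compile() once,
    # because configuration.compile is truthy.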
2 changes: 2 additions & 0 deletions diploma_thesis/agents/base/rl_agent.py
@@ -1,3 +1,5 @@
+import torch
+
 from agents.utils import TrainingPhase
 from agents.utils.rl import RLTrainer
 from utils import filter
10 changes: 7 additions & 3 deletions diploma_thesis/agents/machine/model/deep_multi_rule.py
@@ -10,8 +10,10 @@

 class DeepMultiRule(DeepPolicyMachineModel):
 
-    def __init__(self, rules: List[SchedulingRule], policy: Policy[MachineInput]):
-        super().__init__(policy)
+    def __init__(self, rules: List[SchedulingRule],
+                 policy: Policy[MachineInput],
+                 configuration: DeepPolicyModel.Configuration):
+        super().__init__(policy, configuration)
 
         self.rules = rules

@@ -42,4 +44,6 @@ def from_cli(cls, parameters: Dict):

         policy = policy_from_cli(policy_parameters)
 
-        return cls(rules, policy)
+        configuration = DeepPolicyModel.Configuration.from_cli(parameters)
+
+        return cls(rules, policy, configuration)
4 changes: 3 additions & 1 deletion diploma_thesis/agents/machine/model/deep_rule.py
@@ -23,4 +23,6 @@ def from_cli(cls, parameters: Dict):
         policy_parameters = parameters['policy']
         policy = policy_from_cli(policy_parameters)
 
-        return cls(policy)
+        configuration = DeepPolicyModel.Configuration.from_cli(parameters)
+
+        return cls(policy, configuration)
31 changes: 18 additions & 13 deletions diploma_thesis/agents/utils/nn/neural_network.py
@@ -52,19 +52,7 @@ def input_dim(self):
         return self.input_dim
 
     def forward(self, state):
-        encoded_state = None
-        encoded_graph = None
-
-        if isinstance(state, TensorState) and self.state_encoder is not None:
-            encoded_state = self.state_encoder(torch.atleast_2d(state.state))
-
-        if isinstance(state, GraphState) and self.graph_encoder is not None:
-            data = state.graph.data
-
-            encoded_graph = self.graph_encoder(data)
-
-        hidden = self.merge(encoded_state, encoded_graph)
-        output = self.output(hidden)
+        output = self.__forward__(state)
 
         _is_configured = True

@@ -100,6 +88,23 @@ def __build__(self):
         self.merge = self.configuration.merge
         self.output = nn.Sequential(*self.configuration.output)
 
+    def __forward__(self, state):
+        encoded_state = None
+        encoded_graph = None
+
+        if isinstance(state, TensorState) and self.state_encoder is not None:
+            encoded_state = self.state_encoder(torch.atleast_2d(state.state))
+
+        if isinstance(state, GraphState) and self.graph_encoder is not None:
+            data = state.graph.data
+
+            encoded_graph = self.graph_encoder(data)
+
+        hidden = self.merge(encoded_state, encoded_graph)
+        output = self.output(hidden)
+
+        return output
+
     @staticmethod
     def from_cli(parameters: dict) -> 'NeuralNetwork':
         configuration = NeuralNetwork.Configuration.from_cli(parameters)
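forward keeps the one-time configuration bookkeeping, while the pure encode-merge-output path moves into __forward__. A plausible reason, consistent with the compile support added to DeepPolicyModel above, is that the stateless __forward__ is a much friendlier target for graph capture than a forward that also mutates configuration state. A hedged sketch of how a policy's compile() could exploit the split (torch.compile does accept arbitrary callables; the wiring below is an assumption, not the repository's code):

    import torch

    def compile_hot_path(network) -> None:
        # The instance attribute shadows the class method, so later calls
        # from forward() dispatch to the compiled callable.
        network.__forward__ = torch.compile(network.__forward__)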
2 changes: 2 additions & 0 deletions diploma_thesis/agents/utils/policy/flexible_action.py
@@ -1,6 +1,8 @@
 from enum import StrEnum
 from typing import Dict
 
+import torch
+
 from agents.utils import NeuralNetwork, Phase
 from agents.utils.action import ActionSelector, from_cli as action_selector_from_cli
 from .policy import *
4 changes: 2 additions & 2 deletions diploma_thesis/agents/utils/rl/reinforce.py
@@ -44,14 +44,14 @@ def __init__(self, configuration: Configuration, *args, **kwargs):
     def configure(self, model: Policy):
         super().configure(model)
 
-        layer = Linear(1, 'none')
+        layer = Linear(1, 'none', dropout=None)
 
         for critic in self.configuration.critics:
             critic.neural_network.append_output_layer(layer)
 
     def __train__(self, model: Policy):
         try:
-            batch = self.storage.sample(update_returns=False)
+            batch, index = self.storage.sample(update_returns=False)
         except NotReadyException:
             return

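Two call-site updates here: Linear now receives an explicit dropout argument, and the storage sample returns an index alongside the batch. The index is unused in the visible hunk; one common reason to return it (an assumption, not something this diff shows) is prioritized replay, where sampled indices are needed to update priorities once the loss is known:

    # Hedged sketch; compute_loss and update_priority are hypothetical names.
    batch, index = storage.sample(update_returns=False)

    loss = compute_loss(batch)
    storage.update_priority(index, loss.detach())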
2 changes: 0 additions & 2 deletions diploma_thesis/cli.py
@@ -7,8 +7,6 @@

 from workflow import Workflow, Simulation, Tournament, MultiSimulation
 
-# torch.set_num_threads(1)
-
 
 def make_workflow(configuration: Dict) -> Workflow:
     configuration = configuration['task']
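The commented-out global thread cap is deleted rather than re-enabled, leaving PyTorch's default intra-op parallelism in place, which is plausibly what compiled models want. If a single-threaded run is ever needed again, it is a standard one-liner:

    import torch

    torch.set_num_threads(1)        # cap intra-op parallelism
    print(torch.get_num_threads())  # -> 1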