Fix incorrect override of __call__ method of nn.Modules
yura-hb committed Mar 12, 2024
1 parent 5c78a31 commit 3b073a0
Showing 15 changed files with 58 additions and 80 deletions.
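The same fix repeats across the files below: several torch.nn.Module subclasses overrode __call__ directly, which bypasses nn.Module.__call__ and with it the hook machinery (forward pre-hooks, forward hooks, and utilities layered on top of them). Renaming each override to forward lets nn.Module.__call__ run the hooks and dispatch to the layer's code. A minimal sketch of the corrected pattern (illustrative, not repository code):

import torch
from torch import nn

class Scale(nn.Module):
    """Toy layer that multiplies its input by a constant factor."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    # Override forward(), not __call__: nn.Module.__call__ runs any
    # registered hooks and then dispatches to forward().
    def forward(self, batch: torch.FloatTensor) -> torch.FloatTensor:
        return batch * self.factor

layer = Scale(2.0)
layer.register_forward_hook(lambda module, args, output: print('hook ran'))
layer(torch.ones(3))  # prints 'hook ran', then returns tensor([2., 2., 2.])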
18 changes: 2 additions & 16 deletions diploma_thesis/agents/base/model.py
@@ -2,13 +2,11 @@
 from abc import ABCMeta, abstractmethod
 from typing import TypeVar, Generic
 
-import torch
 from tensordict.prototype import tensorclass
 
-from agents.utils.policy import Policy, PolicyRecord
 from agents.utils import Phase, PhaseUpdatable
+from agents.utils.policy import Policy, PolicyRecord
 from utils import Loggable
-from dataclasses import dataclass
 
 State = TypeVar('State')
 Input = TypeVar('Input')
@@ -31,22 +29,10 @@ def __call__(self, state: State, parameters: Input) -> Record:
 
 class DeepPolicyModel(Model[Input, State, Action, Result], PhaseUpdatable, metaclass=ABCMeta):
 
-    @dataclass
-    class Configuration:
-        compile: bool = False
-
-        @staticmethod
-        def from_cli(parameters):
-            return DeepPolicyModel.Configuration(compile=parameters.get('compile', True))
-
-    def __init__(self, policy: Policy[Input], configuration: Configuration):
+    def __init__(self, policy: Policy[Input]):
         super().__init__()
 
         self.policy = policy
-        self.configuration = configuration
-
-        if configuration.compile:
-            self.policy.compile()
 
     def update(self, phase: Phase):
         super().update(phase)
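With the Configuration dataclass gone, the constructor no longer compiles the policy; the switch reappears as a configuration mod (see the new compile.yml file further down). For reference, a hedged sketch of what the removed branch amounted to in plain PyTorch terms (assumes torch >= 2.0, where torch.compile is available; the stand-in module below is not the repository's Policy):

import torch
from torch import nn

policy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))  # stand-in for Policy

compiled = torch.compile(policy)   # returns an optimized wrapper module
out = compiled(torch.randn(2, 8))  # compilation is triggered on the first call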
10 changes: 3 additions & 7 deletions diploma_thesis/agents/machine/model/deep_multi_rule.py
@@ -10,10 +10,8 @@
 
 class DeepMultiRule(DeepPolicyMachineModel):
 
-    def __init__(self, rules: List[SchedulingRule],
-                 policy: Policy[MachineInput],
-                 configuration: DeepPolicyModel.Configuration):
-        super().__init__(policy, configuration)
+    def __init__(self, rules: List[SchedulingRule], policy: Policy[MachineInput]):
+        super().__init__(policy)
 
         self.rules = rules
 
@@ -44,6 +42,4 @@ def from_cli(cls, parameters: Dict):
 
         policy = policy_from_cli(policy_parameters)
 
-        configuration = DeepPolicyModel.Configuration.from_cli(parameters)
-
-        return cls(rules, policy, configuration)
+        return cls(rules, policy)
8 changes: 2 additions & 6 deletions diploma_thesis/agents/machine/model/deep_rule.py
@@ -1,11 +1,9 @@
-from typing import Dict
-
-import torch
-
+from typing import Dict, List
+
 from agents.utils.policy import from_cli as policy_from_cli
 from .model import *
 from .rule import SchedulingRule, ALL_SCHEDULING_RULES, IdleSchedulingRule
 
 
 class DeepRule(DeepPolicyMachineModel):
@@ -23,6 +21,4 @@ def from_cli(cls, parameters: Dict):
         policy_parameters = parameters['policy']
         policy = policy_from_cli(policy_parameters)
 
-        configuration = DeepPolicyModel.Configuration.from_cli(parameters)
-
-        return cls(policy, configuration)
+        return cls(policy)
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/nn/layers/activation.py
@@ -10,7 +10,7 @@ def __init__(self, kind: str):
         self.kind = kind
         self.activation = self.__make_activation__()
 
-    def __call__(self, batch: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, batch: torch.FloatTensor) -> torch.FloatTensor:
         if self.activation is None:
             return batch
 
6 changes: 3 additions & 3 deletions diploma_thesis/agents/utils/nn/layers/common.py
@@ -11,7 +11,7 @@ def __init__(self):
 
         self.layer = nn.Flatten()
 
-    def __call__(self, batch: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, batch: torch.FloatTensor) -> torch.FloatTensor:
         return self.layer(batch)
 
     @classmethod
@@ -26,7 +26,7 @@ def __init__(self, normalized_shape: int | List[int]):
 
         self.layer = nn.LayerNorm(normalized_shape=normalized_shape)
 
-    def __call__(self, batch: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, batch: torch.FloatTensor) -> torch.FloatTensor:
         return self.layer(batch)
 
     @classmethod
@@ -41,7 +41,7 @@ def __init__(self):
 
         self.layer = nn.LazyInstanceNorm1d()
 
-    def __call__(self, batch: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, batch: torch.FloatTensor) -> torch.FloatTensor:
         return self.layer(batch)
 
     @classmethod
30 changes: 14 additions & 16 deletions diploma_thesis/agents/utils/nn/layers/graph_layer.py
@@ -1,4 +1,3 @@
-
 from .layer import *
 
 from typing import Dict
@@ -15,29 +14,28 @@ def signature(self) -> str | None:
 
 class BaseWrapper(GraphLayer):
 
-        def __init__(self, configuration):
-            super().__init__()
+    def __init__(self, configuration):
+        super().__init__()
 
-            self.configuration = configuration
+        self.configuration = configuration
 
-            if 'signature' in configuration:
-                self._signature = configuration['signature']
+        if 'signature' in configuration:
+            self._signature = configuration['signature']
 
-                del configuration['signature']
-            else:
-                self._signature = None
+            del configuration['signature']
+        else:
+            self._signature = None
 
-        @property
-        def signature(self):
-            return self._signature or 'x -> x'
+    @property
+    def signature(self):
+        return self._signature or 'x -> x'
 
-        @classmethod
-        def from_cli(cls, parameters: Dict):
-            return cls(parameters)
+    @classmethod
+    def from_cli(cls, parameters: Dict):
+        return cls(parameters)
 
 
 def common_graph_layer(layer):
 
     class Wrapper(BaseWrapper):
 
         def __init__(self, configuration):
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/nn/layers/linear.py
@@ -29,7 +29,7 @@ def __init__(self,
 
         self.__build__()
 
-    def __call__(self, batch: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, batch: torch.FloatTensor) -> torch.FloatTensor:
         batch = self.linear(batch)
         batch = self.activation(batch)
 
@@ -13,7 +13,7 @@ def __init__(self, channels: int):
         self.channels = channels
         self.norm = nn.InstanceNorm1d(num_features=channels)
 
-    def __call__(self, batch):
+    def forward(self, batch):
         normalized = batch[:, :self.channels]
         normalized = self.norm(normalized)
 
2 changes: 1 addition & 1 deletion diploma_thesis/agents/utils/policy/flexible_action.py
@@ -44,7 +44,7 @@ def __get_values__(self, state):
     def __get_actions__(self, state):
         return self.action_model(state)
 
-    def __call__(self, state: State, parameters: Input) -> Record:
+    def forward(self, state: State, parameters: Input) -> Record:
         values, actions = self.predict(state)
         values, actions = values.squeeze(), actions.squeeze()
         action, policy = self.action_selector(actions)
4 changes: 0 additions & 4 deletions diploma_thesis/agents/utils/policy/policy.py
@@ -27,10 +27,6 @@ class Record:
 
 class Policy(Generic[Input], nn.Module, PhaseUpdatable, metaclass=ABCMeta):
 
-    @abstractmethod
-    def __call__(self, state: State, parameters: Input) -> Record:
-        pass
-
     @abstractmethod
     def predict(self, state: State) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         pass
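Dropping the abstract __call__ is the core of the fix at the type level: Policy subclasses (see flexible_action.py above) now implement forward, and nn.Module.__call__ is left intact to dispatch to it. One way to keep the contract explicit is to declare the abstract method on forward instead; the diff does not show this, so the following is a sketch under that assumption:

from abc import ABCMeta, abstractmethod
from torch import nn

class Policy(nn.Module, metaclass=ABCMeta):

    # The contract now lives on forward(); nn.Module.__call__ and its
    # hook handling remain untouched.
    @abstractmethod
    def forward(self, state, parameters):
        ...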
8 changes: 4 additions & 4 deletions diploma_thesis/configuration/experiments/jsp/dqn_path.yml
@@ -1,6 +1,6 @@
 # Evaluate the effectivenes of basic DQNs on the JSP environment
 
-template: &template 'reference/marl_direct'
+template: &template 'reference/marl_indirect'
 
 ###############################################################################################
 
@@ -236,7 +236,7 @@ long_single_source_run: &long_single_source_run
   parameters:
     mods:
       __inout_factory__:
-        - [ [ 'duration/10000.yml' ] ]
+        - [ [ 'duration/25000.yml' ] ]
         - [ [ 'size/jsp/10.yml' ] ]
         - [ [ 'utilization/80.yml' ] ]
         - [ [ 'initial_assignment/2.yml' ] ]
@@ -290,7 +290,7 @@ task:
   n_workers: 4
   debug: False
   store_run_statistics: False
-  output_dir: 'results/jsp/experiments/dqn_path/marl_direct'
+  output_dir: 'results/jsp/experiments/dqn_path/marl_indirect'
 
   tasks:
     # TD - n-step short single source
@@ -342,7 +342,7 @@ task:
       parameters:
        base_path: 'configuration/mods/run/run.yml'
        mods:
-          - 'n_workers/1.yml'
+          - 'n_workers/2.yml'
          - 'timeline/warmup.yml'
        nested:
          parameters:
@@ -0,0 +1,5 @@
+
+parameters:
+  model:
+    parameters:
+      compile: True
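This new mod reintroduces the compile switch removed from DeepPolicyModel.Configuration, now routed through the model's parameters block. The consuming code is not shown in this commit; a hypothetical sketch of how such a flag might be applied (names are illustrative):

import torch
from torch import nn

def apply_compile_flag(parameters: dict, policy: nn.Module) -> nn.Module:
    # `parameters` corresponds to the inner `model: parameters:` block above
    if parameters.get('compile', False):
        policy = torch.compile(policy)
    return policy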
2 changes: 1 addition & 1 deletion diploma_thesis/configuration/mods/run/run.yml
@@ -11,7 +11,7 @@ parameters:
 
   machine_train_schedule:
     pretrain_steps: 10
-    train_interval: 10
+    train_interval: 100
     max_training_steps: 100000000
 
   work_center_train_schedule:
8 changes: 4 additions & 4 deletions diploma_thesis/configuration/mods/simulation/simulation.yml
@@ -18,14 +18,14 @@ parameters:
     processing_times:
       kind: 'uniform'
       parameters:
-        uniform: [ 10, 50 ]
+        uniform: [ 1, 50 ]
         noise: [ 0, 10 ]
     permutation:
-      uneveness: 0
+      uneveness: 2
     due_time:
       kind: 'uniform'
      parameters:
-        uniform: [ 0.0, 2 ]
+        uniform: [ 0.5, 2 ]
     job_arrival_time_on_machine:
      kind: 'expected_utilization'
      parameters:
@@ -41,6 +41,6 @@
     repair_duration:
       kind: 'uniform'
       parameters:
-        uniform: [ 10, 200 ]
+        uniform: [ 10, 300 ]
 
 seed: 42
31 changes: 16 additions & 15 deletions diploma_thesis/configuration/simulation.yml
@@ -19,6 +19,7 @@ task:
           - 'util/rules/all_rules.yml'
           - 'util/optimizer/grad_norm.yml'
 #          - 'util/agent/multi_agent.yml'
+          - 'util/infrastructure/compile.yml'
 
       work_center_agent:
         kind: 'static'
@@ -94,21 +95,21 @@
           21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
           31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
           41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
-          51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
-          61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
-          71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
-          81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
-          91, 92, 93, 94, 95, 96, 97, 98, 99, 100,
-          101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
-          111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
-          121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
-          131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
-          141, 142, 143, 144, 145, 146, 148, 148, 149, 150,
-          151, 152, 153, 154, 155, 156, 157, 158, 159, 160,
-          161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
-          171, 172, 173, 174, 175, 176, 177, 178, 179, 180,
-          181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
-          191, 192, 193, 194, 195, 196, 197, 198, 199
+#          51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+#          61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
+#          71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+#          81, 82, 83, 84, 85, 86, 87, 88, 89, 90,
+#          91, 92, 93, 94, 95, 96, 97, 98, 99, 100,
+#          101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
+#          111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
+#          121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
+#          131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
+#          141, 142, 143, 144, 145, 146, 148, 148, 149, 150,
+#          151, 152, 153, 154, 155, 156, 157, 158, 159, 160,
+#          161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
+#          171, 172, 173, 174, 175, 176, 177, 178, 179, 180,
+#          181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+#          191, 192, 193, 194, 195, 196, 197, 198, 199
         ]
 #      - kind: 'mod'
 #        parameters:
