Implement MARL
yura-hb committed Feb 24, 2024
1 parent 113db8a commit de943a2
Showing 15 changed files with 171 additions and 14 deletions.
6 changes: 5 additions & 1 deletion diploma_thesis/agents/base/agent.py
@@ -8,6 +8,7 @@
from utils import Loggable
from .encoder import Encoder as StateEncoder, Input, State
from .model import Model, Action, Result
from environment import ShopFloor

Key = TypeVar('Key')

@@ -42,6 +43,9 @@ def update(self, phase: Phase):
if isinstance(module, PhaseUpdatable):
module.update(phase)

def setup(self, shop_floor: ShopFloor):
pass

@property
@abstractmethod
def is_trainable(self):
@@ -55,7 +59,7 @@ def train_step(self):
def store(self, key: Key, record: Record):
pass

def schedule(self, parameters: Input) -> Model.Record:
def schedule(self, key: Key, parameters: Input) -> Model.Record:
state = self.encode_state(parameters)

return self.model(state, parameters)
86 changes: 86 additions & 0 deletions diploma_thesis/agents/base/marl_agent.py
@@ -0,0 +1,86 @@
import copy
from typing import Dict

from agents.utils import TrainingPhase
from agents.utils.rl import RLTrainer
from utils import filter
from .agent import *
from .model import NNModel


class MARLAgent(Generic[Key], Agent[Key]):

def __init__(self, model: NNModel, state_encoder: StateEncoder, trainer: RLTrainer, is_model_distributed: bool):
super().__init__(model, state_encoder)

self.trainer: RLTrainer | Dict[Key, RLTrainer] = trainer
self.is_model_distributed = is_model_distributed
self.is_configured = False
self.keys = None

@property
def is_trainable(self):
return True

@abstractmethod
def iterate_keys(self, shop_floor: ShopFloor):
pass

def setup(self, shop_floor: ShopFloor):
if self.is_configured:
is_key_set_equal = set(self.keys) == set(self.iterate_keys(shop_floor))
is_evaluating_with_centralized_model = self.phase == EvaluationPhase() and not self.is_model_distributed

assert is_key_set_equal or is_evaluating_with_centralized_model, \
("Multi-Agent model should be configured for the same shop floor architecture "
"or have centralized action network")
return

self.is_configured = True
self.keys = list(self.iterate_keys(shop_floor))

base_model = self.model
base_trainer = self.trainer

self.model = dict() if self.is_model_distributed else self.model
self.trainer = dict()

for key in self.keys:
if self.is_model_distributed:
self.model[key] = copy.deepcopy(base_model)

self.trainer[key] = copy.deepcopy(base_trainer)

@filter(lambda self: self.phase == TrainingPhase())
def train_step(self):
for key in self.keys:
self.trainer[key].train_step(self.__model_for_key__(key))

@filter(lambda self, *args: self.phase == TrainingPhase())
def store(self, key: Key, record: Record):
self.trainer[key].store(record)

def loss_record(self):
result = [self.trainer[key].loss_record() for key in self.keys]

return result

def clear_memory(self):
for key in self.keys:
self.trainer[key].clear()

def schedule(self, key: Key, parameters):
state = self.encode_state(parameters)

result = self.__model_for_key__(key)(state, parameters)

if not self.trainer[key].is_configured:
self.trainer[key].configure(self.__model_for_key__(key))

return result

def __model_for_key__(self, key: Key):
if self.is_model_distributed:
return self.model[key]

return self.model
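
As a reading aid for the class above: a minimal, self-contained sketch of the per-key dispatch that setup and __model_for_key__ implement. StubModel, build_models and the key names are stand-ins invented for illustration, not classes from this repository; only the deep-copy-per-key versus shared-instance behaviour mirrors the diff.

import copy


class StubModel:
    """Stand-in for NNModel; only object identity matters for this illustration."""
    pass


def build_models(keys, base_model, is_model_distributed):
    # Distributed: every key trains its own deep-copied network.
    # Centralized: every key resolves to the single shared network.
    if is_model_distributed:
        return {key: copy.deepcopy(base_model) for key in keys}

    return {key: base_model for key in keys}


keys = ['machine_0', 'machine_1']

distributed = build_models(keys, StubModel(), is_model_distributed=True)
centralized = build_models(keys, StubModel(), is_model_distributed=False)

assert distributed['machine_0'] is not distributed['machine_1']
assert centralized['machine_0'] is centralized['machine_1']
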
4 changes: 2 additions & 2 deletions diploma_thesis/agents/base/rl_agent.py
@@ -30,8 +30,8 @@ def loss_record(self):
def clear_memory(self):
self.trainer.clear()

def schedule(self, parameters):
result = super().schedule(parameters)
def schedule(self, key, parameters):
result = super().schedule(key, parameters)

if not self.trainer.is_configured:
self.trainer.configure(self.model)
2 changes: 2 additions & 0 deletions diploma_thesis/agents/machine/__init__.py
@@ -5,11 +5,13 @@
from .utils import Input as MachineInput
from .static import StaticMachine
from .rl import RLMachine
from .marl import MARLMachine


key_to_class = {
"static": StaticMachine,
'rl': RLMachine,
'marl': MARLMachine
}


27 changes: 27 additions & 0 deletions diploma_thesis/agents/machine/marl.py
@@ -0,0 +1,27 @@
from typing import Dict

from agents.base.marl_agent import MARLAgent
from agents.utils.rl import from_cli as rl_trainer_from_cli
from environment import MachineKey, ShopFloor
from .model import NNMachineModel, from_cli as model_from_cli
from .state import from_cli as state_encoder_from_cli


class MARLMachine(MARLAgent[MachineKey]):

def iterate_keys(self, shop_floor: ShopFloor):
for work_center in shop_floor.work_centers:
for machine in work_center.machines:
yield machine.key

@staticmethod
def from_cli(parameters: Dict):
model = model_from_cli(parameters['model'])
encoder = state_encoder_from_cli(parameters['encoder'])
trainer = rl_trainer_from_cli(parameters['trainer'])

is_model_distributed = parameters.get('is_model_distributed', True)

assert isinstance(model, NNMachineModel), "Model must conform to NNMachineModel"

return MARLMachine(model, encoder, trainer, is_model_distributed)
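
For orientation, a hypothetical shape of the parameters dictionary that MARLMachine.from_cli reads. Only the four top-level keys come from the code above; the nested values are placeholders, since the sub-configurations accepted by model_from_cli, state_encoder_from_cli and rl_trainer_from_cli are defined elsewhere in the repository.

# Hypothetical example; nested values are placeholders, not documented options.
parameters = {
    'model': {...},                # forwarded to model_from_cli
    'encoder': {...},              # forwarded to state_encoder_from_cli
    'trainer': {...},              # forwarded to rl_trainer_from_cli
    'is_model_distributed': True,  # optional; defaults to True
}
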
2 changes: 1 addition & 1 deletion diploma_thesis/agents/machine/model/multi_rule_linear.py
@@ -2,7 +2,7 @@

from agents.utils import DeepRule, NNCLI, ActionSelector
from .model import *
from .rule import ALL_SCHEDULING_RULES, SchedulingRule
from .rule import ALL_SCHEDULING_RULES, SchedulingRule, IdleSchedulingRule


class MultiRuleLinear(NNMachineModel, DeepRule):
2 changes: 2 additions & 0 deletions diploma_thesis/agents/machine/model/rule/__init__.py
@@ -1,3 +1,5 @@
from typing import Dict

from .atc import ATCSchedulingRule
from .avpro import AVPROSchedulingRule
from .covert import COVERTSchedulingRule
@@ -11,7 +11,7 @@ def __init__(self, reduction_strategy: JobReductionStrategy = JobReductionStrate
self.reduction_strategy = reduction_strategy

def __call__(self, machine: 'Machine', now: float) -> Job | None:
value = self.criterion(machine, now)
value = self.criterion(machine=machine, now=now)
selector = self.selector
idx = selector(value)

7 changes: 6 additions & 1 deletion diploma_thesis/agents/utils/nn/nn_cli.py
@@ -131,6 +131,12 @@ def __make_layer__(self, input_dim, layer: Configuration.Layer):
raise ValueError(f"Unknown layer type {layer}")

def __make_linear_layer__(self, input_dim, output_dim, activation, dropout):
if isinstance(input_dim, torch.Size):
if len(input_dim) == 1:
input_dim = input_dim[0]
else:
raise ValueError(f"Input dim must be 1D tensor, got {input_dim}")

result = nn.Sequential(
nn.Linear(input_dim, output_dim)
)
@@ -164,4 +170,3 @@ def from_cli(parameters: dict) -> nn.Module:
configuration = NNCLI.Configuration.from_cli(parameters)

return NNCLI(configuration)
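
A standalone sketch of the guard added to __make_linear_layer__: a one-element torch.Size is unwrapped to a plain int before it reaches nn.Linear, while higher-dimensional shapes are rejected so the caller has to flatten explicitly. The make_linear helper below is illustrative only and not part of the repository.

import torch
from torch import nn


def make_linear(input_dim, output_dim):
    if isinstance(input_dim, torch.Size):
        if len(input_dim) == 1:
            input_dim = input_dim[0]
        else:
            raise ValueError(f"Input dim must be a one-dimensional torch.Size, got {input_dim}")

    return nn.Linear(input_dim, output_dim)


layer = make_linear(torch.Size([32]), 16)   # accepted, equivalent to make_linear(32, 16)
assert layer.in_features == 32

try:
    make_linear(torch.Size([4, 8]), 16)     # rejected: flattening is left to the caller
except ValueError as error:
    print(error)
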

4 changes: 3 additions & 1 deletion diploma_thesis/agents/workcenter/__init__.py
@@ -2,12 +2,14 @@

from utils import from_cli
from .rl import RLWorkCenter
from .marl import MARLWorkCenter
from .static import StaticWorkCenter
from .utils import Input as WorkCenterInput

key_to_class = {
"static": StaticWorkCenter,
'rl': RLWorkCenter
'rl': RLWorkCenter,
'marl': MARLWorkCenter
}

from_cli = partial(from_cli, key_to_class=key_to_class)
26 changes: 26 additions & 0 deletions diploma_thesis/agents/workcenter/marl.py
@@ -0,0 +1,26 @@
from typing import Dict

from agents.base.marl_agent import MARLAgent
from agents.utils.rl import from_cli as rl_trainer_from_cli
from environment import WorkCenterKey, ShopFloor
from .model import NNWorkCenterModel, from_cli as model_from_cli
from .state import from_cli as state_encoder_from_cli


class MARLWorkCenter(MARLAgent[WorkCenterKey]):

def iterate_keys(self, shop_floor: ShopFloor):
for work_center in shop_floor.work_centers:
yield work_center.key

@staticmethod
def from_cli(parameters: Dict):
model = model_from_cli(parameters['model'])
encoder = state_encoder_from_cli(parameters['encoder'])
trainer = rl_trainer_from_cli(parameters['trainer'])

is_model_distributed = parameters.get('is_model_distributed', True)

assert isinstance(model, NNWorkCenterModel), "Model must conform to NNWorkCenterModel"

return MARLWorkCenter(model, encoder, trainer, is_model_distributed)
@@ -10,7 +10,7 @@ class RoutingRule(metaclass=ABCMeta):
Selects a machine at random
"""
def __call__(self, job: Job, work_center: WorkCenter) -> Machine | None:
value = self.criterion(job, work_center)
value = self.criterion(job=job, work_center=work_center)
selector = self.selector
idx = selector(value)

2 changes: 1 addition & 1 deletion diploma_thesis/agents/workcenter/model/static.py
@@ -18,7 +18,7 @@ def __init__(self, rule: RoutingRule):

def __call__(self, state: State, parameters: WorkCenterModel.Input) -> WorkCenterModel.Record:
return WorkCenterModel.Record(
result=self.rule(parameters.job, parameters.work_center),
result=self.rule(job=parameters.job, work_center=parameters.work_center),
state=state,
action=None
)
2 changes: 1 addition & 1 deletion diploma_thesis/configuration/jsp.yml
@@ -1,7 +1,7 @@

task:
kind: 'multi_task'
n_workers: 8
n_workers: 10
debug: False

tasks:
11 changes: 7 additions & 4 deletions diploma_thesis/simulator/simulator.py
@@ -198,18 +198,18 @@ def did_prepare_work_center_record(
# Agent

def schedule(self, context: Context, machine: Machine) -> Job | None:
parameters = MachineInput(machine, context.moment)
parameters = MachineInput(machine=machine, now=context.moment)

result = self.machine.schedule(parameters)
result = self.machine.schedule(machine.key, parameters)

if self.machine.is_trainable:
self.tape_model.register_machine_reward_preparation(context=context, machine=machine, record=result)

return result.result

def route(self, context: Context, work_center: WorkCenter, job: Job) -> 'Machine | None':
parameters = WorkCenterInput(work_center, job)
result = self.work_center.schedule(parameters)
parameters = WorkCenterInput(work_center=work_center, job=job)
result = self.work_center.schedule(work_center.key, parameters)

if self.work_center.is_trainable:
self.tape_model.register_work_center_reward_preparation(context=context,
@@ -235,6 +235,9 @@ def consume(simulation: Simulation):

simulation.prepare(self, self.tape_model, environment)

self.machine.setup(simulation.shop_floor)
self.work_center.setup(simulation.shop_floor)

if is_training:
self.tape_model.register(simulation.shop_floor)

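
To summarise the simulator changes without the project's classes: schedule now receives the agent key (machine.key or work_center.key) alongside the encoded input, and setup must run once per simulation before the first scheduling call so that MARL agents can build their per-key models and trainers. A stand-in sketch of that contract, with invented names:

class StubAgent:
    """Stand-in agent; mirrors only the setup-before-schedule contract."""

    def __init__(self):
        self.keys = None

    def setup(self, shop_floor):
        # Called once per simulation; MARLAgent uses this hook to create per-key models.
        self.keys = list(shop_floor)

    def schedule(self, key, parameters):
        assert self.keys is not None and key in self.keys, 'setup() must run before schedule()'
        return f'decision for {key} given {parameters}'


agent = StubAgent()
agent.setup(shop_floor=['machine_0', 'machine_1'])
print(agent.schedule('machine_0', parameters={'now': 0.0}))
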
