Skip to content

Commit

Permalink
Implement DeepMARL direct state encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
yura-hb committed Feb 12, 2024
1 parent fdfbfa5 commit ca49722
Show file tree
Hide file tree
Showing 35 changed files with 317 additions and 52 deletions.
17 changes: 14 additions & 3 deletions diploma_thesis/agents/base/agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@

import logging
from abc import ABCMeta, abstractmethod
from typing import TypeVar, Generic

import torch

from agents.utils import Phase, EvaluationPhase
from agents.utils import Phase, EvaluationPhase, Loggable
from .encoder import Encoder as StateEncoder, Input, State
from .model import Model, Action, Result


class Agent(metaclass=ABCMeta):
class Agent(Loggable, metaclass=ABCMeta):

def __init__(self,
model: Model[Input, State, Action, Result],
Expand All @@ -18,6 +19,16 @@ def __init__(self,
self.model = model
self.memory = memory
self.phase = EvaluationPhase()
super().__init__()

def with_logger(self, logger: logging.Logger):
    """Attach `logger` to the agent and forward it to its loggable components.

    The memory, the model and the state encoder receive the same `logger`
    whenever they are `Loggable` instances; other components are left untouched.

    Returns:
        The agent itself, so the call can be chained.
    """
    super().with_logger(logger)

    components = (self.memory, self.model, self.state_encoder)

    for component in components:
        if not isinstance(component, Loggable):
            continue

        component.with_logger(logger)

    return self

def update(self, phase: Phase):
self.phase = phase
Expand Down
5 changes: 3 additions & 2 deletions diploma_thesis/agents/base/encoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import abstractmethod

from typing import TypeVar, Generic

from agents.utils import Loggable

Input = TypeVar('Input')
State = TypeVar('State')


class Encoder(Generic[Input, State]):
class Encoder(Loggable, Generic[Input, State]):

@abstractmethod
def encode(self, parameters: Input) -> State:
Expand Down
4 changes: 3 additions & 1 deletion diploma_thesis/agents/base/model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TypeVar, Generic
from agents.utils import Loggable

State = TypeVar('State')
Input = TypeVar('Input')
Action = TypeVar('Action')
Result = TypeVar('Result')


class Model(Generic[Input, State, Action, Result], metaclass=ABCMeta):
class Model(Loggable, Generic[Input, State, Action, Result], metaclass=ABCMeta):

@dataclass
class Record:
Expand Down
1 change: 1 addition & 0 deletions diploma_thesis/agents/machine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging

from .utils import Input as MachineInput
from .machine import Machine
Expand Down
2 changes: 0 additions & 2 deletions diploma_thesis/agents/machine/machine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@

from abc import ABCMeta

from agents.base.agent import Agent


class Machine(Agent, metaclass=ABCMeta):
    """Abstract machine agent; it adds nothing on top of `Agent`."""
1 change: 1 addition & 0 deletions diploma_thesis/agents/machine/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging

from .model import MachineModel
from .static import StaticModel as StaticMachineModel
Expand Down
2 changes: 2 additions & 0 deletions diploma_thesis/agents/machine/model/static.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging

from .model import MachineModel
from agents.machine.state import PlainEncoder
Expand All @@ -11,6 +12,7 @@ class StaticModel(MachineModel[PlainEncoder.State, None]):

def __init__(self, rule: SchedulingRule):
    """Remember the scheduling `rule`, then run the base class initialisation."""
    self.rule = rule

    super().__init__()

def __call__(self, state: State, parameters: MachineModel.Input) -> MachineModel.Record:
return MachineModel.Record(
Expand Down
12 changes: 9 additions & 3 deletions diploma_thesis/agents/machine/state/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import logging

from typing import Dict

from .encoder import StateEncoder
from .plain import PlainEncoder

from .deep_marl_direct import DEEPMARLDirectStateEncoder
from .deep_marl_indirect import DEEPMARLIndirectStateEncoder

key_to_class = {
"plain": PlainEncoder
"plain": PlainEncoder,
"deep_marl_direct": DEEPMARLDirectStateEncoder,
"deep_marl_indirect": DEEPMARLIndirectStateEncoder
}


def from_cli(parameters) -> StateEncoder:
def from_cli(parameters: Dict) -> StateEncoder:
    """Build the state encoder described by the CLI `parameters`.

    `parameters['kind']` selects the encoder class from `key_to_class`; the optional
    `parameters['parameters']` mapping (empty by default) is handed to that class'
    own `from_cli`.
    """
    encoder_cls = key_to_class[parameters['kind']]
    encoder_parameters = parameters.get('parameters', {})

    return encoder_cls.from_cli(encoder_parameters)
131 changes: 131 additions & 0 deletions diploma_thesis/agents/machine/state/deep_marl_direct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import logging
from dataclasses import dataclass
from typing import Dict

import torch

from environment import JobReductionStrategy, Job, Machine
from .encoder import StateEncoder


class DEEPMARLDirectStateEncoder(StateEncoder):
"""
Encoded state is a tensor of dimension (5, 5) where:
1. First 4 rows contain the following information:
1. Current operation processing time on machine
2. Next remaining processing time
3. Slack upon moment
4. Average waiting in next queue
5. Work center index
Depending on the number of jobs in queue, the state is represented in the following way:
1. If there are 0 jobs, then the state is a tensor of zeros
2. If there is 1 job, then the state is a tensor of shape (4, 5) where the first row repeated
3. If there are more than 1 job, then the information of job minimum values of first 4 criterias are stored
2. Arriving job info represents information of the job that is about to arrive at the machine
"""

@dataclass
class State:
state: torch.FloatTensor
job_idx: torch.LongTensor

def __init__(self, strategy: JobReductionStrategy = JobReductionStrategy.mean):
super().__init__()

self.reduction_strategy = strategy

def encode(self, parameters: StateEncoder.Input) -> State:
state, job_idx = self.__make_initial_state(parameters)
arriving_job_state, _ = self.__make_arriving_job_state__(parameters.machine, parameters.now)

state = torch.vstack([state, arriving_job_state])

return self.State(state, job_idx)

def __make_initial_state(self, parameters: StateEncoder.Input) -> torch.FloatTensor:
machine = parameters.machine
queue_size = machine.queue_size
state = None

match queue_size:
case 0:
return torch.zeros((4, 5)), torch.zeros(4)
case 1:
job = machine.queue[0]
state = self.__make_machine_state_for_single_job__(job, machine, parameters.now)
state = state.repeat(4, 1)
case _:
candidates = torch.vstack([
self.__make_machine_state_for_single_job__(job, machine, parameters.now)
for job in machine.queue
])
pool = candidates.clone()

state = []

for i in range(4):
is_no_candidates = pool.numel() == 0
store = candidates if is_no_candidates else pool

idx = torch.argmin(store[:, i])

state += [store[idx]]

if not is_no_candidates:
pool = pool[torch.arange(pool.size(0)) != idx]

state = torch.vstack(state)

return state[:, 1:], state[:, 0]

def __make_arriving_job_state__(self, machine: Machine, now: float) -> torch.FloatTensor:
arriving_jobs = [
job for job in machine.shop_floor.in_system_jobs
if job.next_work_center_idx == machine.work_center_idx and job.release_moment_on_machine is not None
]

average_waiting_time = machine.work_center.average_waiting_time

if len(arriving_jobs) == 0:
state = torch.FloatTensor([0, 0, 0, average_waiting_time, 0])

return state, None

job = min(arriving_jobs, key=lambda job: job.release_moment_on_machine)

wait_time = job.release_moment_on_machine - now

if wait_time < -1:
self.log_error(f"Arriving job release moment is in the past: {wait_time}")

wait_time = max(wait_time, 0)

state = [
job.next_operation_processing_time(self.reduction_strategy),
job.next_remaining_processing_time(self.reduction_strategy),
job.slack_upon_moment(now, self.reduction_strategy),
average_waiting_time,
wait_time
]

state = self.__to_list_of_tensors__(state)

return torch.hstack(state), job.id

def __make_machine_state_for_single_job__(self, job: Job, machine: Machine, now) -> torch.FloatTensor:
state = [
job.id,
job.current_operation_processing_time_on_machine,
job.next_remaining_processing_time(self.reduction_strategy),
job.slack_upon_moment(now, self.reduction_strategy),
machine.shop_floor.average_waiting_in_next_queue(job),
machine.work_center_idx
]

state = self.__to_list_of_tensors__(state)

return torch.hstack(state)

@staticmethod
def from_cli(parameters: Dict):
return DEEPMARLDirectStateEncoder()
18 changes: 18 additions & 0 deletions diploma_thesis/agents/machine/state/deep_marl_indirect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import Dict

from .encoder import StateEncoder


class DEEPMARLIndirectStateEncoder(StateEncoder):
    """Indirect DeepMARL state encoder; at the moment it encodes nothing and yields an empty state."""

    @dataclass
    class State:
        """Encoded state without any fields."""

    def encode(self, parameters: StateEncoder.Input) -> State:
        """Return a fresh empty `State`; `parameters` is not inspected."""
        encoded = self.State()

        return encoded

    @staticmethod
    def from_cli(parameters: Dict):
        """Create the encoder; `parameters` is not used."""
        return DEEPMARLIndirectStateEncoder()
11 changes: 9 additions & 2 deletions diploma_thesis/agents/machine/state/encoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@

import logging
from abc import ABCMeta
from typing import List, TypeVar

import torch

from agents.base import Encoder
from typing import TypeVar
from agents.machine import MachineInput

State = TypeVar('State')
Expand All @@ -10,3 +13,7 @@
class StateEncoder(Encoder[MachineInput, State], metaclass=ABCMeta):
    """Base class of machine state encoders, i.e. encoders which consume `MachineInput`."""

    Input = MachineInput

    @staticmethod
    def __to_list_of_tensors__(parameters: List) -> List[torch.FloatTensor]:
        """Return `parameters` as a list in which every non-tensor value is wrapped by `torch.tensor`."""

        def as_tensor(value):
            # Tensors are passed through untouched, anything else is converted
            if torch.is_tensor(value):
                return value

            return torch.tensor(value)

        return [as_tensor(value) for value in parameters]
4 changes: 2 additions & 2 deletions diploma_thesis/agents/machine/state/plain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

from .encoder import StateEncoder
from dataclasses import dataclass
from typing import Dict

from .encoder import StateEncoder


class PlainEncoder(StateEncoder):

Expand Down
5 changes: 2 additions & 3 deletions diploma_thesis/agents/machine/static.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@

from typing import Dict

from .machine import Machine
from .model import MachineModel, from_cli as model_from_cli
from .state import StateEncoder, from_cli as state_encoder_from_cli
from typing import Dict


class StaticMachine(Machine):
Expand All @@ -23,4 +22,4 @@ def from_cli(parameters: Dict):
model = model_from_cli(parameters['model'])
encoder = state_encoder_from_cli(parameters['encoder'])

return StaticMachine(model=model, state_encoder=encoder)
return StaticMachine(model=model, state_encoder=encoder)
3 changes: 2 additions & 1 deletion diploma_thesis/agents/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@

from .phase import WarmUpPhase, TrainingPhase, EvaluationPhase, Phase
from .phase import WarmUpPhase, TrainingPhase, EvaluationPhase, Phase
from .loggable import Loggable
31 changes: 31 additions & 0 deletions diploma_thesis/agents/utils/loggable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

import logging


class Loggable:
    """Mixin which gives an object an optional `logging.Logger`.

    Until `with_logger` is called the object has no logger and every message is
    printed to the standard output instead.
    """

    def __init__(self):
        self.logger = None

    def with_logger(self, logger: logging.Logger):
        """Use a child of `logger`, named after the concrete class, and return `self` for chaining."""
        self.logger = logger.getChild(self.__class__.__name__)

        return self

    def log(self, message, level=logging.INFO):
        """Emit `message` at `level`.

        Args:
            message: The message to emit.
            level: A numeric `logging` level. A level name such as 'info' is accepted
                as well; an unknown name falls back to `logging.INFO`.
        """
        # `logger` is None before `with_logger` is called and may be missing entirely
        # when a subclass never ran `Loggable.__init__`
        logger = getattr(self, 'logger', None)

        if logger is None:
            print(message)
            return

        if isinstance(level, str):
            # `Logger.log` accepts integer levels only, so resolve the name first
            resolved = logging.getLevelName(level.upper())
            level = resolved if isinstance(resolved, int) else logging.INFO

        logger.log(level, message)

    def log_debug(self, message):
        self.log(message, logging.DEBUG)

    def log_info(self, message):
        self.log(message, logging.INFO)

    def log_warning(self, message):
        self.log(message, logging.WARNING)

    def log_error(self, message):
        self.log(message, logging.ERROR)
1 change: 1 addition & 0 deletions diploma_thesis/agents/workcenter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging

from .utils import Input as WorkCenterInput
from .work_center import WorkCenter
Expand Down
2 changes: 2 additions & 0 deletions diploma_thesis/agents/workcenter/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from .model import WorkCenterModel
from .static import StaticModel as StaticWorkCenterModel
from .rule import RoutingRule
Expand Down
5 changes: 3 additions & 2 deletions diploma_thesis/agents/workcenter/model/static.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict

from .model import WorkCenterModel
from agents.machine.state import PlainEncoder
from .model import WorkCenterModel
from .rule import RoutingRule, ALL_ROUTING_RULES
from typing import Dict


class StaticModel(WorkCenterModel[PlainEncoder.State, None]):
Expand All @@ -11,6 +11,7 @@ class StaticModel(WorkCenterModel[PlainEncoder.State, None]):

def __init__(self, rule: RoutingRule):
    """Remember the routing `rule`, then run the base class initialisation."""
    self.rule = rule

    super().__init__()

def __call__(self, state: State, parameters: WorkCenterModel.Input) -> WorkCenterModel.Record:
return WorkCenterModel.Record(
Expand Down
Loading

0 comments on commit ca49722

Please sign in to comment.