generated from rochacbruno/python-project-template
-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement DeepMARL direct state encoding
- Loading branch information
Showing
35 changed files
with
317 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import logging | ||
|
||
from .utils import Input as MachineInput | ||
from .machine import Machine | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,7 @@ | ||
|
||
from abc import ABCMeta | ||
|
||
from agents.base.agent import Agent | ||
|
||
|
||
class Machine(Agent, metaclass=ABCMeta): | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import logging | ||
|
||
from .model import MachineModel | ||
from .static import StaticModel as StaticMachineModel | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,20 @@ | ||
import logging | ||
|
||
from typing import Dict | ||
|
||
from .encoder import StateEncoder | ||
from .plain import PlainEncoder | ||
|
||
from .deep_marl_direct import DEEPMARLDirectStateEncoder | ||
from .deep_marl_indirect import DEEPMARLIndirectStateEncoder | ||
|
||
key_to_class = { | ||
"plain": PlainEncoder | ||
"plain": PlainEncoder, | ||
"deep_marl_direct": DEEPMARLDirectStateEncoder, | ||
"deep_marl_indirect": DEEPMARLIndirectStateEncoder | ||
} | ||
|
||
|
||
def from_cli(parameters) -> StateEncoder: | ||
def from_cli(parameters: Dict) -> StateEncoder: | ||
cls = key_to_class[parameters['kind']] | ||
|
||
return cls.from_cli(parameters.get('parameters', {})) |
131 changes: 131 additions & 0 deletions
131
diploma_thesis/agents/machine/state/deep_marl_direct.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import logging | ||
from dataclasses import dataclass | ||
from typing import Dict | ||
|
||
import torch | ||
|
||
from environment import JobReductionStrategy, Job, Machine | ||
from .encoder import StateEncoder | ||
|
||
|
||
class DEEPMARLDirectStateEncoder(StateEncoder): | ||
""" | ||
Encoded state is a tensor of dimension (5, 5) where: | ||
1. First 4 rows contain the following information: | ||
1. Current operation processing time on machine | ||
2. Next remaining processing time | ||
3. Slack upon moment | ||
4. Average waiting in next queue | ||
5. Work center index | ||
Depending on the number of jobs in queue, the state is represented in the following way: | ||
1. If there are 0 jobs, then the state is a tensor of zeros | ||
2. If there is 1 job, then the state is a tensor of shape (4, 5) where the first row repeated | ||
3. If there are more than 1 job, then the information of job minimum values of first 4 criterias are stored | ||
2. Arriving job info represents information of the job that is about to arrive at the machine | ||
""" | ||
|
||
@dataclass | ||
class State: | ||
state: torch.FloatTensor | ||
job_idx: torch.LongTensor | ||
|
||
def __init__(self, strategy: JobReductionStrategy = JobReductionStrategy.mean): | ||
super().__init__() | ||
|
||
self.reduction_strategy = strategy | ||
|
||
def encode(self, parameters: StateEncoder.Input) -> State: | ||
state, job_idx = self.__make_initial_state(parameters) | ||
arriving_job_state, _ = self.__make_arriving_job_state__(parameters.machine, parameters.now) | ||
|
||
state = torch.vstack([state, arriving_job_state]) | ||
|
||
return self.State(state, job_idx) | ||
|
||
def __make_initial_state(self, parameters: StateEncoder.Input) -> torch.FloatTensor: | ||
machine = parameters.machine | ||
queue_size = machine.queue_size | ||
state = None | ||
|
||
match queue_size: | ||
case 0: | ||
return torch.zeros((4, 5)), torch.zeros(4) | ||
case 1: | ||
job = machine.queue[0] | ||
state = self.__make_machine_state_for_single_job__(job, machine, parameters.now) | ||
state = state.repeat(4, 1) | ||
case _: | ||
candidates = torch.vstack([ | ||
self.__make_machine_state_for_single_job__(job, machine, parameters.now) | ||
for job in machine.queue | ||
]) | ||
pool = candidates.clone() | ||
|
||
state = [] | ||
|
||
for i in range(4): | ||
is_no_candidates = pool.numel() == 0 | ||
store = candidates if is_no_candidates else pool | ||
|
||
idx = torch.argmin(store[:, i]) | ||
|
||
state += [store[idx]] | ||
|
||
if not is_no_candidates: | ||
pool = pool[torch.arange(pool.size(0)) != idx] | ||
|
||
state = torch.vstack(state) | ||
|
||
return state[:, 1:], state[:, 0] | ||
|
||
def __make_arriving_job_state__(self, machine: Machine, now: float) -> torch.FloatTensor: | ||
arriving_jobs = [ | ||
job for job in machine.shop_floor.in_system_jobs | ||
if job.next_work_center_idx == machine.work_center_idx and job.release_moment_on_machine is not None | ||
] | ||
|
||
average_waiting_time = machine.work_center.average_waiting_time | ||
|
||
if len(arriving_jobs) == 0: | ||
state = torch.FloatTensor([0, 0, 0, average_waiting_time, 0]) | ||
|
||
return state, None | ||
|
||
job = min(arriving_jobs, key=lambda job: job.release_moment_on_machine) | ||
|
||
wait_time = job.release_moment_on_machine - now | ||
|
||
if wait_time < -1: | ||
self.log_error(f"Arriving job release moment is in the past: {wait_time}") | ||
|
||
wait_time = max(wait_time, 0) | ||
|
||
state = [ | ||
job.next_operation_processing_time(self.reduction_strategy), | ||
job.next_remaining_processing_time(self.reduction_strategy), | ||
job.slack_upon_moment(now, self.reduction_strategy), | ||
average_waiting_time, | ||
wait_time | ||
] | ||
|
||
state = self.__to_list_of_tensors__(state) | ||
|
||
return torch.hstack(state), job.id | ||
|
||
def __make_machine_state_for_single_job__(self, job: Job, machine: Machine, now) -> torch.FloatTensor: | ||
state = [ | ||
job.id, | ||
job.current_operation_processing_time_on_machine, | ||
job.next_remaining_processing_time(self.reduction_strategy), | ||
job.slack_upon_moment(now, self.reduction_strategy), | ||
machine.shop_floor.average_waiting_in_next_queue(job), | ||
machine.work_center_idx | ||
] | ||
|
||
state = self.__to_list_of_tensors__(state) | ||
|
||
return torch.hstack(state) | ||
|
||
@staticmethod | ||
def from_cli(parameters: Dict): | ||
return DEEPMARLDirectStateEncoder() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from dataclasses import dataclass | ||
from typing import Dict | ||
|
||
from .encoder import StateEncoder | ||
|
||
|
||
class DEEPMARLIndirectStateEncoder(StateEncoder): | ||
|
||
@dataclass | ||
class State: | ||
pass | ||
|
||
def encode(self, parameters: StateEncoder.Input) -> State: | ||
return self.State() | ||
|
||
@staticmethod | ||
def from_cli(parameters: Dict): | ||
return DEEPMARLIndirectStateEncoder() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
|
||
from .phase import WarmUpPhase, TrainingPhase, EvaluationPhase, Phase | ||
from .phase import WarmUpPhase, TrainingPhase, EvaluationPhase, Phase | ||
from .loggable import Loggable |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
|
||
import logging | ||
|
||
|
||
class Loggable: | ||
|
||
def __init__(self): | ||
self.logger = None | ||
|
||
def with_logger(self, logger: logging.Logger): | ||
self.logger = logger.getChild(self.__class__.__name__) | ||
|
||
return self | ||
|
||
def log(self, message, level: str = 'info'): | ||
if hasattr(self, 'logger'): | ||
self.logger.log(level, message) | ||
else: | ||
print(message) | ||
|
||
def log_debug(self, message): | ||
self.log(message, logging.DEBUG) | ||
|
||
def log_info(self, message): | ||
self.log(message, logging.INFO) | ||
|
||
def log_warning(self, message): | ||
self.log(message, logging.WARNING) | ||
|
||
def log_error(self, message): | ||
self.log(message, logging.ERROR) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import logging | ||
|
||
from .utils import Input as WorkCenterInput | ||
from .work_center import WorkCenter | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.