Skip to content

Commit

Permalink
Implement fixes for multirun simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
yura-hb committed Feb 24, 2024
1 parent 1a36059 commit 6174dcf
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 41 deletions.
10 changes: 6 additions & 4 deletions diploma_thesis/configuration/jsp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ task:
# Action Set
- [
'model/rules/all.yml',
'model/rules/marl_as.yml'
'model/rules/marl.yml'
]
# Model
- [
'model/model/marl_as.yml'
'model/model/marl_as.yml',
'model/model/marl_mr.yml'
]

Expand All @@ -98,8 +98,10 @@ task:
run:
parameters:
mods:
# - ['test.yml']
- [ 'multi.yml' ]
# - [ 'multi.yml', 'concurrent.yml' ]
# - [ 'util_70.yml' ]
# - [ 'util_80.yml' ]
# - [ 'util_90.yml' ]
# - [ 'multi.yml' ]
- [ 'multi.yml', 'concurrent.yml' ]

87 changes: 54 additions & 33 deletions diploma_thesis/workflow/multi_simulation.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,83 @@
import gc
import traceback
import time
from typing import Dict

import tqdm
from joblib import Parallel, delayed

from .simulation import Simulation
from utils.multi_value_cli import multi_value_cli
from typing import Dict
from joblib import Parallel, delayed
from .simulation import Simulation


def __run__(s: Dict):
s = Simulation(s)

print(f'Simulation started {s.parameters["name"]}')

start = time.time()

try:
s.run()
except Exception as e:
print(f'Error in simulation {s.parameters["name"]}: {e}')
print(traceback.format_exc())

print(f'Simulation finished {s.parameters["name"]}. Elapsed time: {time.time() - start} seconds.')

del s

gc.collect()


class MultiSimulation:

def __init__(self, parameters: Dict):
self.parameters = parameters

@property
def workflow_id(self) -> str:
return ''

def run(self):
simulations = self.__fetch_tasks__()
parameters = self.__fetch_tasks__()
parameters = self.__add_debug_info__(parameters)
parameters = self.__fix_names__(parameters)

self.__add_debug_info__(simulations)
self.__fix_names__(simulations)
print(f'Running {len(parameters)} simulations')

n_workers = self.parameters.get('n_workers', -1)

def __run__(s: Simulation):
try:
s.run()
except Exception as e:
print(f'Error in simulation {s.parameters["name"]}: {e}')

return s

iter = Parallel(
n_jobs=n_workers,
backend='loky',
return_as='generator',
prefer='processes',
)(delayed(lambda s: __run__(s))(s) for s in simulations)

for s in tqdm.tqdm(iter, total=len(simulations)):
print(f'Simulation finished {s.parameters["name"]}')
Parallel(
n_jobs=n_workers
)(delayed(__run__)(s) for s in parameters)

def __fetch_tasks__(self):
result: [Simulation] = []
result: [Dict] = []

for task in self.parameters['tasks']:
match self.parameters['kind']:
case 'task':
result += [Simulation(task['parameters'])]
result += [task['parameters']]
case 'multi_task':
result += multi_value_cli(task['parameters'], lambda p: Simulation(p))
result += multi_value_cli(task['parameters'], lambda p: p)
case _:
raise ValueError(f"Unknown kind: {self.parameters['kind']}")

return result

def __add_debug_info__(self, simulations: [Simulation]):
def __add_debug_info__(self, simulations: [Dict]):
result = simulations

if self.parameters.get('debug', False):
for simulation in simulations:
simulation.parameters['debug'] = True
for index, _ in enumerate(result):
result[index]['debug'] = True

return result

def __fix_names__(self, simulations: [Simulation]):
for i, simulation in enumerate(simulations):
simulation.parameters['name'] = f"{simulation.parameters['name']}_{i}"
def __fix_names__(self, simulations: [Dict]):
result = simulations

for i, simulation in enumerate(result):
result[i]['name'] = f"{simulation['name']}_{i}"

return result
6 changes: 6 additions & 0 deletions diploma_thesis/workflow/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def __init__(self, parameters: Dict):

self.parameters = parameters

@property
def workflow_id(self) -> str:
return self.parameters.get('name', '')

@property
def log_stdout(self):
return self.parameters.get('log_stdout', False)
Expand Down Expand Up @@ -152,6 +156,8 @@ def __to_dataframe__(data):
if data.batch_size == torch.Size([]):
return pd.DataFrame(columns=['shop_floor_id'])

data = data.to_dict()

return pd.DataFrame(data)

machine_reward = __to_dataframe__(reward_cache.machines)
Expand Down
4 changes: 4 additions & 0 deletions diploma_thesis/workflow/tournament.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class Tournament(Workflow):
def __init__(self, parameters: Dict):
self.parameters = parameters

@property
def workflow_id(self) -> str:
return ''

def run(self):
candidates = self.__make_candidates__()
criteria = self.__make_criteria__()
Expand Down
19 changes: 15 additions & 4 deletions diploma_thesis/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ class Workflow(metaclass=ABCMeta):
def run(self):
pass

@property
@abstractmethod
def workflow_id(self) -> str:
pass

def __make_logger__(self, name: str, filename: str = None, log_stdout: bool = False) -> logging.Logger:
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
logger = self.__get_logger__(name)

formatter = logging.Formatter('%(asctime)s | | %(name)s | %(levelname)s | %(message)s')

Expand All @@ -39,8 +43,7 @@ def format(self, record):
record.time = str(time)
return super(_Formatter, self).format(record)

logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
logger = self.__get_logger__(name)

formatter = _Formatter('%(asctime)s | %(time)s | %(name)s | %(levelname)s | %(message)s')

Expand All @@ -67,3 +70,11 @@ def __add_handlers__(self, logger, formatter, filename: str, log_stdout: bool):
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)

def __get_logger__(self, name):
workflow_id = self.workflow_id

logger = logging.Logger(name + '_' + self.workflow_id if len(workflow_id) > 0 else name)
logger.setLevel(logging.INFO)

return logger

0 comments on commit 6174dcf

Please sign in to comment.