Skip to content

Commit

Permalink
Add read_json to be a part of master.yaml config (#1015)
Browse files Browse the repository at this point in the history
* Add read_json to be a part of master.yaml config

* Fix the mypy error & other failed tests

* Addressing reviews

* Add read_json to restore in shelve function
  • Loading branch information
EmanElsaban authored Jan 2, 2025
1 parent cd6da8c commit 254f27e
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 24 deletions.
3 changes: 3 additions & 0 deletions tests/config/config_parse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def make_tron_config(
jobs=None,
mesos_options=None,
k8s_options=None,
read_json=False,
):
return schema.TronConfig(
action_runner=action_runner or {},
Expand All @@ -350,6 +351,7 @@ def make_tron_config(
jobs=jobs or make_master_jobs(),
mesos_options=mesos_options or make_mesos_options(),
k8s_options=k8s_options or make_k8s_options(),
read_json=read_json,
)


Expand Down Expand Up @@ -498,6 +500,7 @@ def test_attributes(self):
assert test_config.nodes == expected.nodes
assert test_config.node_pools == expected.node_pools
assert test_config.k8s_options == expected.k8s_options
assert test_config.read_json == expected.read_json
for key in ["0", "1", "2", "_actions_dict", "4", "_mesos"]:
job_name = f"MASTER.test_job{key}"
assert job_name in test_config.jobs, f"{job_name} in test_config.jobs"
Expand Down
12 changes: 8 additions & 4 deletions tests/serialize/runstate/statemanager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def setup_manager(self):
self.store.build_key.side_effect = lambda t, i: f"{t}{i}"
self.buffer = StateSaveBuffer(1)
self.manager = PersistentStateManager(self.store, self.buffer)
self.read_json = False

def test__init__(self):
assert_equal(self.manager._impl, self.store)
Expand Down Expand Up @@ -98,7 +99,7 @@ def test_restore(self):

restored_state = self.manager.restore(job_names)
assert mock_restore_dicts.call_args_list == [
mock.call(runstate.JOB_STATE, job_names),
mock.call(runstate.JOB_STATE, job_names, self.read_json),
]
assert len(mock_restore_runs.call_args_list) == 2
assert restored_state == {
Expand All @@ -120,7 +121,9 @@ def test_restore_runs_for_job(self):
]
runs = self.manager._restore_runs_for_job("job_a", job_state)

assert mock_restore_dicts.call_args_list == [mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"])]
assert mock_restore_dicts.call_args_list == [
mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"], self.read_json)
]
assert runs == [{"job_name": "job_a", "run_num": 3}, {"job_name": "job_a", "run_num": 2}]

def test_restore_runs_for_job_one_missing(self):
Expand All @@ -134,7 +137,7 @@ def test_restore_runs_for_job_one_missing(self):
runs = self.manager._restore_runs_for_job("job_a", job_state)

assert mock_restore_dicts.call_args_list == [
mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"]),
mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"], self.read_json),
]
assert runs == [{"job_name": "job_a", "run_num": 3}]

Expand Down Expand Up @@ -218,6 +221,7 @@ def setup_watcher(self):
self.watcher = StateChangeWatcher()
self.state_manager = mock.create_autospec(PersistentStateManager)
self.watcher.state_manager = self.state_manager
self.read_json = False

def test_update_from_config_no_change(self):
self.watcher.config = state_config = mock.Mock()
Expand Down Expand Up @@ -263,7 +267,7 @@ def test_disabled(self):
def test_restore(self):
jobs = mock.Mock()
self.watcher.restore(jobs)
self.watcher.state_manager.restore.assert_called_with(jobs)
self.watcher.state_manager.restore.assert_called_with(jobs, self.read_json)

def test_handler_mesos_change(self):
self.watcher.handler(
Expand Down
2 changes: 2 additions & 0 deletions tron/config/config_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ class ValidateConfig(Validator):
"mesos_options": ConfigMesos(**ValidateMesos.defaults),
"k8s_options": ConfigKubernetes(**ValidateKubernetes.defaults),
"eventbus_enabled": None,
"read_json": False,
}
node_pools = build_dict_name_validator(valid_node_pool, allow_empty=True)
nodes = build_dict_name_validator(valid_node, allow_empty=True)
Expand All @@ -987,6 +988,7 @@ class ValidateConfig(Validator):
"mesos_options": valid_mesos_options,
"k8s_options": valid_kubernetes_options,
"eventbus_enabled": valid_bool,
"read_json": valid_bool,
}
optional = False

Expand Down
1 change: 1 addition & 0 deletions tron/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def from_dict(cls, data: Dict[str, Any]):
"mesos_options", # ConfigMesos
"k8s_options", # ConfigKubernetes
"eventbus_enabled", # bool or None
"read_json", # bool, default is False
],
)

Expand Down
7 changes: 6 additions & 1 deletion tron/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, working_dir, config_path, boot_time):
self.context = command_context.CommandContext()
self.state_watcher = statemanager.StateChangeWatcher()
self.boot_time = boot_time
self.read_json = False
current_time = time.strftime("%a, %d %b %Y %H:%M:%S", time.localtime(boot_time))
log.info(f"Initialized. Tron started on {current_time}!")

Expand Down Expand Up @@ -109,6 +110,7 @@ def apply_config(self, config_container, reconfigure=False, namespace_to_reconfi
(MesosClusterRepository.configure, "mesos_options"),
(KubernetesClusterRepository.configure, "k8s_options"),
(self.configure_eventbus, "eventbus_enabled"),
(self.set_read_json, "read_json"),
]
master_config = config_container.get_master()
apply_master_configuration(master_config_directives, master_config)
Expand Down Expand Up @@ -161,6 +163,9 @@ def update_state_watcher_config(self, state_config):
def set_context_base(self, command_context):
self.context.base = command_context

def set_read_json(self, read_json):
    """Store the read_json flag; it is later forwarded to state restoration."""
    self.read_json = read_json

def configure_eventbus(self, enabled):
if enabled:
if not EventBus.instance:
Expand All @@ -182,7 +187,7 @@ def restore_state(self, action_runner):
log.info("Restoring from DynamoDB")
with timer("restore"):
# restores the state of the jobs and their runs from DynamoDB
states = self.state_watcher.restore(self.jobs.get_names())
states = self.state_watcher.restore(self.jobs.get_names(), self.read_json)
log.info(
f"Tron will start restoring state for the jobs and will start scheduling them! Time elapsed since Tron started {time.time() - self.boot_time}"
)
Expand Down
8 changes: 1 addition & 7 deletions tron/serialize/runstate/dynamodb_state_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
from typing import TypeVar

import boto3 # type: ignore
import staticconf # type: ignore

import tron.prom_metrics as prom_metrics
from tron.config.static_config import get_config_watcher
from tron.config.static_config import NAMESPACE
from tron.core.job import Job
from tron.core.jobrun import JobRun
from tron.metrics import timer
Expand Down Expand Up @@ -64,17 +61,14 @@ def build_key(self, type, iden) -> str:
"""
return f"{type} {iden}"

def restore(self, keys) -> dict:
def restore(self, keys, read_json: bool = False) -> dict:
"""
Fetch all under the same partition key(s).
ret: <dict of key to states>
"""
# format of the keys always passed here is
# job_state job_name --> high level info about the job: enabled, run_nums
# job_run_state job_run_name --> high level info about the job run
config_watcher = get_config_watcher()
config_watcher.reload_if_changed()
read_json = staticconf.read("read_json.enable", namespace=NAMESPACE, default=False)
first_items = self._get_first_partitions(keys)
remaining_items = self._get_remaining_partitions(first_items, read_json)
vals = self._merge_items(first_items, remaining_items, read_json)
Expand Down
2 changes: 1 addition & 1 deletion tron/serialize/runstate/shelvestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def save(self, key_value_pairs):
self.shelve[shelve_key] = state_data
self.shelve.sync()

def restore(self, keys):
def restore(self, keys, read_json: bool = False):
items = zip(
keys,
(self.shelve.get(str(key.key)) for key in keys),
Expand Down
22 changes: 11 additions & 11 deletions tron/serialize/runstate/statemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ def __init__(self, persistence_impl, buffer):
self._buffer = buffer
self._impl = persistence_impl

def restore(self, job_names):
def restore(self, job_names, read_json: bool = False):
"""Return the most recent serialized state."""
log.debug("Restoring state.")

# First, restore the jobs themselves
jobs = self._restore_dicts(runstate.JOB_STATE, job_names)
jobs = self._restore_dicts(runstate.JOB_STATE, job_names, read_json)
# jobs should be a dictionary that contains job name and number of runs
# {'MASTER.k8s': {'run_nums':[0], 'enabled': True}, 'MASTER.cits_test_frequent_1': {'run_nums': [1,0], 'enabled': True}}

Expand All @@ -118,7 +118,7 @@ def restore(self, job_names):
# start the threads and mark each future with it's job name
# this is useful so that we can index the job name later to add the runs to the jobs dictionary
results = {
executor.submit(self._restore_runs_for_job, job_name, job_state): job_name
executor.submit(self._restore_runs_for_job, job_name, job_state, read_json): job_name
for job_name, job_state in jobs.items()
}
for result in concurrent.futures.as_completed(results):
Expand All @@ -133,13 +133,13 @@ def restore(self, job_names):
}
return state

def _restore_runs_for_job(self, job_name, job_state):
def _restore_runs_for_job(self, job_name, job_state, read_json: bool = False):
"""Restore the state for the runs of each job"""
run_nums = job_state["run_nums"]
keys = [jobrun.get_job_run_id(job_name, run_num) for run_num in run_nums]
job_runs_restored_states = self._restore_dicts(runstate.JOB_RUN_STATE, keys)
runs = copy.copy(job_runs_restored_states)
for run_id, state in runs.items():
job_runs_restored_states = self._restore_dicts(runstate.JOB_RUN_STATE, keys, read_json)
all_job_runs = copy.copy(job_runs_restored_states)
for run_id, state in all_job_runs.items():
if state == {}:
log.error(f"Failed to restore {run_id}, no state found for it!")
job_runs_restored_states.pop(run_id)
Expand All @@ -154,10 +154,10 @@ def _keys_for_items(self, item_type, names):
keys = (self._impl.build_key(item_type, name) for name in names)
return dict(zip(keys, names))

def _restore_dicts(self, item_type: str, items: List[str]) -> Dict[str, dict]:
def _restore_dicts(self, item_type: str, items: List[str], read_json: bool = False) -> Dict[str, dict]:
"""Return a dict mapping of the items name to its state data."""
key_to_item_map = self._keys_for_items(item_type, items)
key_to_state_map = self._impl.restore(key_to_item_map.keys())
key_to_state_map = self._impl.restore(key_to_item_map.keys(), read_json)
return {key_to_item_map[key]: state_data for key, state_data in key_to_state_map.items()}

def delete(self, type_enum, name):
Expand Down Expand Up @@ -288,5 +288,5 @@ def shutdown(self):
def disabled(self):
return self.state_manager.disabled()

def restore(self, jobs):
return self.state_manager.restore(jobs)
def restore(self, jobs, read_json: bool = False):
    """Delegate state restoration for the given jobs to the wrapped state manager."""
    restored = self.state_manager.restore(jobs, read_json)
    return restored

0 comments on commit 254f27e

Please sign in to comment.