diff --git a/mypy.ini b/mypy.ini index 531337257b..4889e768a9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -73,7 +73,8 @@ disallow_untyped_defs = True disallow_any_expr = True [mypy-parsl.executors.high_throughput.interchange.*] -check_untyped_defs = True +disallow_untyped_defs = True +warn_unreachable = True [mypy-parsl.executors.extreme_scale.*] ignore_errors = True diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 6c4ca961ec..bfe238c4fa 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +import multiprocessing import zmq import os import sys @@ -13,7 +14,7 @@ import threading import json -from typing import cast, Any, Dict, Set, Optional +from typing import cast, Any, Dict, NoReturn, Sequence, Set, Optional, Tuple from parsl.utils import setproctitle from parsl.version import VERSION as PARSL_VERSION @@ -36,23 +37,23 @@ class ManagerLost(Exception): ''' Task lost due to manager loss. Manager is considered lost when multiple heartbeats have been missed. ''' - def __init__(self, manager_id, hostname): + def __init__(self, manager_id: bytes, hostname: str) -> None: self.manager_id = manager_id self.tstamp = time.time() self.hostname = hostname - def __str__(self): + def __str__(self) -> str: return "Task failure due to loss of manager {} on host {}".format(self.manager_id.decode(), self.hostname) class VersionMismatch(Exception): ''' Manager and Interchange versions do not match ''' - def __init__(self, interchange_version, manager_version): + def __init__(self, interchange_version: str, manager_version: str): self.interchange_version = interchange_version self.manager_version = manager_version - def __str__(self): + def __str__(self) -> str: return "Manager version info {} does not match interchange version info {}, causing a critical failure".format( self.manager_version, self.interchange_version) @@ -67,17 +68,17 @@ class Interchange: 4. Service single and batch requests from workers """ def __init__(self, - client_address="127.0.0.1", + client_address: str = "127.0.0.1", interchange_address: Optional[str] = None, - client_ports=(50055, 50056, 50057), - worker_ports=None, - worker_port_range=(54000, 55000), - hub_address=None, - hub_port=None, - heartbeat_threshold=60, - logdir=".", - logging_level=logging.INFO, - poll_period=10, + client_ports: Tuple[int, int, int] = (50055, 50056, 50057), + worker_ports: Optional[Tuple[int, int]] = None, + worker_port_range: Tuple[int, int] = (54000, 55000), + hub_address: Optional[str] = None, + hub_port: Optional[int] = None, + heartbeat_threshold: int = 60, + logdir: str = ".", + logging_level: int = logging.INFO, + poll_period: int = 10, ) -> None: """ Parameters @@ -191,7 +192,7 @@ def __init__(self, logger.info("Platform info: {}".format(self.current_platform)) - def get_tasks(self, count): + def get_tasks(self, count: int) -> Sequence[dict]: """ Obtains a batch of tasks from the internal pending_task_queue Parameters @@ -216,7 +217,7 @@ def get_tasks(self, count): return tasks @wrap_with_logs(target="interchange") - def task_puller(self): + def task_puller(self) -> NoReturn: """Pull tasks from the incoming tasks zmq pipe onto the internal pending task queue """ @@ -237,7 +238,7 @@ def task_puller(self): task_counter += 1 logger.debug(f"Fetched {task_counter} tasks so far") - def _create_monitoring_channel(self): + def _create_monitoring_channel(self) -> Optional[zmq.Socket]: if self.hub_address and self.hub_port: logger.info("Connecting to monitoring") hub_channel = self.context.socket(zmq.DEALER) @@ -248,7 +249,7 @@ def _create_monitoring_channel(self): else: return None - def _send_monitoring_info(self, hub_channel, manager: ManagerRecord): + def _send_monitoring_info(self, hub_channel: Optional[zmq.Socket], manager: ManagerRecord) -> None: if hub_channel: logger.info("Sending message {} to hub".format(manager)) @@ -259,7 +260,7 @@ def _send_monitoring_info(self, hub_channel, manager: ManagerRecord): hub_channel.send_pyobj((MessageType.NODE_INFO, d)) @wrap_with_logs(target="interchange") - def _command_server(self): + def _command_server(self) -> NoReturn: """ Command server to run async command to the interchange """ logger.debug("Command Server Starting") @@ -326,7 +327,7 @@ def _command_server(self): continue @wrap_with_logs - def start(self): + def start(self) -> None: """ Start the interchange """ @@ -382,7 +383,7 @@ def start(self): logger.info("Processed {} tasks in {} seconds".format(self.count, delta)) logger.warning("Exiting") - def process_task_outgoing_incoming(self, interesting_managers, hub_channel, kill_event): + def process_task_outgoing_incoming(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket], kill_event: threading.Event) -> None: # Listen for requests for work if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN: logger.debug("starting task_outgoing section") @@ -448,7 +449,7 @@ def process_task_outgoing_incoming(self, interesting_managers, hub_channel, kill logger.error("Unexpected non-heartbeat message received from manager {}") logger.debug("leaving task_outgoing section") - def process_tasks_to_send(self, interesting_managers): + def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None: # If we had received any requests, check if there are tasks that could be passed logger.debug("Managers count (interesting/total): {interesting}/{total}".format( @@ -474,14 +475,14 @@ def process_tasks_to_send(self, interesting_managers): tids = [t['task_id'] for t in tasks] m['tasks'].extend(tids) m['idle_since'] = None - logger.debug("Sent tasks: {} to manager {}".format(tids, manager_id)) + logger.debug("Sent tasks: {} to manager {!r}".format(tids, manager_id)) # recompute real_capacity after sending tasks real_capacity = m['max_capacity'] - tasks_inflight if real_capacity > 0: - logger.debug("Manager {} has free capacity {}".format(manager_id, real_capacity)) + logger.debug("Manager {!r} has free capacity {}".format(manager_id, real_capacity)) # ... so keep it in the interesting_managers list else: - logger.debug("Manager {} is now saturated".format(manager_id)) + logger.debug("Manager {!r} is now saturated".format(manager_id)) interesting_managers.remove(manager_id) else: interesting_managers.remove(manager_id) @@ -490,7 +491,7 @@ def process_tasks_to_send(self, interesting_managers): else: logger.debug("either no interesting managers or no tasks, so skipping manager pass") - def process_results_incoming(self, interesting_managers, hub_channel): + def process_results_incoming(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None: # Receive any results and forward to client if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN: logger.debug("entering results_incoming section") @@ -508,6 +509,12 @@ def process_results_incoming(self, interesting_managers, hub_channel): # process this for task ID and forward to executor b_messages.append((p_message, r)) elif r['type'] == 'monitoring': + # the monitoring code makes the assumption that no + # monitoring messages will be received if monitoring + # is not configured, and that hub_channel will only + # be None when monitoring is not configurated. + assert hub_channel is not None + hub_channel.send_pyobj(r['payload']) elif r['type'] == 'heartbeat': logger.debug(f"Manager {manager_id!r} sent heartbeat via results connection") @@ -552,7 +559,7 @@ def process_results_incoming(self, interesting_managers, hub_channel): interesting_managers.add(manager_id) logger.debug("leaving results_incoming section") - def expire_bad_managers(self, interesting_managers, hub_channel): + def expire_bad_managers(self, interesting_managers: Set[bytes], hub_channel: Optional[zmq.Socket]) -> None: bad_managers = [(manager_id, m) for (manager_id, m) in self._ready_managers.items() if time.time() - m['last_heartbeat'] > self.heartbeat_threshold] for (manager_id, m) in bad_managers: @@ -576,7 +583,7 @@ def expire_bad_managers(self, interesting_managers, hub_channel): interesting_managers.remove(manager_id) -def start_file_logger(filename, level=logging.DEBUG, format_string=None): +def start_file_logger(filename: str, level: int = logging.DEBUG, format_string: Optional[str] = None) -> None: """Add a stream log handler. Parameters @@ -608,7 +615,7 @@ def start_file_logger(filename, level=logging.DEBUG, format_string=None): @wrap_with_logs(target="interchange") -def starter(comm_q, *args, **kwargs): +def starter(comm_q: multiprocessing.Queue, *args: Any, **kwargs: Any) -> None: """Start the interchange process The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__