diff --git a/ffn/inference/executor.py b/ffn/inference/executor.py index 997ebc6..c0461c1 100644 --- a/ffn/inference/executor.py +++ b/ffn/inference/executor.py @@ -36,6 +36,7 @@ import jax.numpy as jnp import numpy as np import tensorflow.compat.v1 as tf + from ..training import model as ffn_model from . import inference_utils from .inference_utils import timer_counter @@ -338,183 +339,6 @@ def _schedule_batch(self, client_ids: Sequence[int], fetches: Sequence[str]): pass -class TPUExecutor(ThreadingBatchExecutor): - """ThreadingBatchExecutor for TF models on TPUs.""" - - CLIENT_ID_PAD = -1 - CLIENT_ID_SHUTDOWN = -2 - - @property - def num_devices(self): - return self.num_tpus - - def _tpu_loop(self, loop_op): - try: - self.session.run(loop_op) - except Exception as e: - logging.exception(e) - raise e - logging.info('TPU loop done.') - - def _dispatcher(self, results_op): - """Reads data from TPU and dispatches it to clients.""" - try: - while True: - results = self.session.run(results_op) - ids, logits = results - if -2 in ids: - logging.info('Terminating dispatcher.') - return - - logits = np.reshape(logits, self.model.logits.shape_as_list()) - with self._interface.lock: - for i, client_id in enumerate(ids): - try: - self._interface.outputs[client_id].put({'logits': logits[i, ...]}) - except KeyError: - # This could happen if a client unregistered itself - # while inference was running. - pass - except Exception as e: - logging.exception(e) - raise e - - def _run_executor(self): - fs_loop, fs_dispatch = set(), set() - self._curr_infeed = 0 - - with futures.ThreadPoolExecutor(max_workers=self.num_tpus * 2) as e: - # Runs the TPU main loop in separate threads. - for loop_op in self.tpu_loop_ops: - fs_loop.add(e.submit(self._tpu_loop, loop_op)) - # Gets results from the TPU and distributes them to the clients. - for results_op in self.tpu_outfeed_results: - fs_dispatch.add(e.submit(self._dispatcher, results_op)) - - super(TPUExecutor, self)._run_executor() - logging.info('TPU executor done. Shutting down.') - - # TODO(mjanusz): Fix this awkward shut down procedure. - # Experiments show that the following alternatives do not work: - # - scheduling a single batch per TPU (no loops are actually terminated) - # - not scheduling more work on infeeds corresponding to loops that have - # terminated, assuming a 1:1 mapping between self.tpu_loop_ops and - # self.tpu_infeed_enquee_ops - while fs_loop: - for _ in range(self.num_tpus): - # Terminate the main loop on the TPU. - self._schedule_batch([TPUExecutor.CLIENT_ID_SHUTDOWN] * - self.batch_size, None) - logging.info('Scheduling termination request.') - - fs_done, fs_loop = futures.wait( - fs_loop, timeout=1, return_when=futures.FIRST_COMPLETED) - for f in fs_done: - f.result() - - # Check for exceptions. - for f in fs_dispatch: - f.result() - self.session.run(self.tpu_shutdown_system) - logging.info('TPU executor shutdown complete.') - - def _schedule_batch_on_feed(self, client_ids, fetches, feed_id): - del fetches # TODO(mjanusz): Support this. - self.session.run( - self.tpu_infeed_enqueue_ops[feed_id], - { - # Pad client IDs to full batch size. - self.client_ids: - np.array( - client_ids + [TPUExecutor.CLIENT_ID_PAD] * - (self.batch_size - len(client_ids)), - dtype=np.int32), - self.plc_input_seed: - self.input_seed, - self.plc_input_patches: - self.input_image - }) - - def _schedule_batch(self, client_ids, fetches): - self._schedule_batch_on_feed(client_ids, fetches, self._curr_infeed) - - # Distribute batches across available TPUs in a round-robin fashion. - self._curr_infeed += 1 - self._curr_infeed %= len(self.tpu_infeed_enqueue_ops) - - def _initialize_model(self): - self.tpu_initialize_system = tf.tpu.initialize_system() - self.tpu_shutdown_system = tf.tpu.shutdown_system() - - # This returns the global_tpu_id, which we don't use yet. - self.session.run(self.tpu_initialize_system) - self.client_ids = tf.placeholder(dtype=tf.int32, shape=(self.batch_size,)) - - # Define infeeds. - infeed_placeholder_attrs = ('input_seed', 'input_patches') - infeed_placeholders = [self.client_ids] + [ - getattr(self.model, attr) for attr in infeed_placeholder_attrs - ] - tpu_infeed_queue = tf.contrib.tpu.InfeedQueue( - tuple_types=[t.dtype for t in infeed_placeholders], - tuple_shapes=[t.shape for t in infeed_placeholders], - ) - num_tpus = sum( - dev.device_type == 'TPU' for dev in self.session.list_devices()) - tpu_infeed_queue.set_number_of_shards(num_tpus) - logging.info('Found %d TPU cores.', num_tpus) - - self.num_tpus = num_tpus - self.tpu_infeed_enqueue_ops = tpu_infeed_queue.generate_enqueue_ops( - [infeed_placeholders] * num_tpus, lambda x: x) - - # Save the placeholders that we will feed on the host. These will - # be replaced by infeed dequeue ops below. - self.plc_input_seed = self.model.input_seed - self.plc_input_patches = self.model.input_patches - - def loop_body(not_done): - """Defines the graph that executes in a loop on the TPU.""" - del not_done - inputs = tpu_infeed_queue.generate_dequeue_op() - client_ids = inputs[0] - inputs = inputs[1:] - for attr, iteration_input in zip(infeed_placeholder_attrs, inputs): - setattr(self.model, attr, iteration_input) - - # Define the graph for the FFN model. - self.model.define_tf_graph() - - # Flat shape (or at least removal of the channel dimension) is necessary - # for efficient outfeed as of Oct 2020). - self.flat_logits = tf.reshape(self.model.logits, [-1]) - tpu_outfeed_enqueue_op = tf.contrib.tpu.outfeed_enqueue_tuple( - [client_ids, self.flat_logits]) - - with tf.control_dependencies([tpu_outfeed_enqueue_op]): - return tf.greater(client_ids[0:1], TPUExecutor.CLIENT_ID_SHUTDOWN) - - def loop_condition(not_done): - return tf.identity(not_done[0], name='not_done_reduce') - - # Note: this executes loop_body. - self.tpu_loop_ops = tf.tpu.replicate( - lambda: tf.while_loop( # pylint:disable=g-long-lambda - cond=loop_condition, - body=loop_body, - loop_vars=[tf.constant(True, dtype=tf.bool, shape=(1,))], - parallel_iterations=1), - inputs=[[]] * num_tpus) - - self.tpu_outfeed_results = [] - for i in range(num_tpus): - with tf.device('/device:TPU:%d' % i): - self.tpu_outfeed_results.append( - tf.contrib.tpu.outfeed_dequeue_tuple( - dtypes=[tf.int32, self.model.logits.dtype], - shapes=[self.client_ids.shape, self.flat_logits.shape])) - - class JAXExecutor(ThreadingBatchExecutor): """ThreadingBatchExecutor for JAX models."""