Skip to content

Commit

Permalink
Remove TPUExecutor which relies on APIs that do not exist anymore.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712854338
  • Loading branch information
mjanusz authored and copybara-github committed Jan 7, 2025
1 parent 12d680e commit 99ed8e3
Showing 1 changed file with 1 addition and 177 deletions.
178 changes: 1 addition & 177 deletions ffn/inference/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import jax.numpy as jnp
import numpy as np
import tensorflow.compat.v1 as tf

from ..training import model as ffn_model
from . import inference_utils
from .inference_utils import timer_counter
Expand Down Expand Up @@ -338,183 +339,6 @@ def _schedule_batch(self, client_ids: Sequence[int], fetches: Sequence[str]):
pass


class TPUExecutor(ThreadingBatchExecutor):
"""ThreadingBatchExecutor for TF models on TPUs."""

CLIENT_ID_PAD = -1
CLIENT_ID_SHUTDOWN = -2

@property
def num_devices(self):
return self.num_tpus

def _tpu_loop(self, loop_op):
try:
self.session.run(loop_op)
except Exception as e:
logging.exception(e)
raise e
logging.info('TPU loop done.')

def _dispatcher(self, results_op):
"""Reads data from TPU and dispatches it to clients."""
try:
while True:
results = self.session.run(results_op)
ids, logits = results
if -2 in ids:
logging.info('Terminating dispatcher.')
return

logits = np.reshape(logits, self.model.logits.shape_as_list())
with self._interface.lock:
for i, client_id in enumerate(ids):
try:
self._interface.outputs[client_id].put({'logits': logits[i, ...]})
except KeyError:
# This could happen if a client unregistered itself
# while inference was running.
pass
except Exception as e:
logging.exception(e)
raise e

def _run_executor(self):
fs_loop, fs_dispatch = set(), set()
self._curr_infeed = 0

with futures.ThreadPoolExecutor(max_workers=self.num_tpus * 2) as e:
# Runs the TPU main loop in separate threads.
for loop_op in self.tpu_loop_ops:
fs_loop.add(e.submit(self._tpu_loop, loop_op))
# Gets results from the TPU and distributes them to the clients.
for results_op in self.tpu_outfeed_results:
fs_dispatch.add(e.submit(self._dispatcher, results_op))

super(TPUExecutor, self)._run_executor()
logging.info('TPU executor done. Shutting down.')

# TODO(mjanusz): Fix this awkward shut down procedure.
# Experiments show that the following alternatives do not work:
# - scheduling a single batch per TPU (no loops are actually terminated)
# - not scheduling more work on infeeds corresponding to loops that have
# terminated, assuming a 1:1 mapping between self.tpu_loop_ops and
# self.tpu_infeed_enquee_ops
while fs_loop:
for _ in range(self.num_tpus):
# Terminate the main loop on the TPU.
self._schedule_batch([TPUExecutor.CLIENT_ID_SHUTDOWN] *
self.batch_size, None)
logging.info('Scheduling termination request.')

fs_done, fs_loop = futures.wait(
fs_loop, timeout=1, return_when=futures.FIRST_COMPLETED)
for f in fs_done:
f.result()

# Check for exceptions.
for f in fs_dispatch:
f.result()
self.session.run(self.tpu_shutdown_system)
logging.info('TPU executor shutdown complete.')

def _schedule_batch_on_feed(self, client_ids, fetches, feed_id):
del fetches # TODO(mjanusz): Support this.
self.session.run(
self.tpu_infeed_enqueue_ops[feed_id],
{
# Pad client IDs to full batch size.
self.client_ids:
np.array(
client_ids + [TPUExecutor.CLIENT_ID_PAD] *
(self.batch_size - len(client_ids)),
dtype=np.int32),
self.plc_input_seed:
self.input_seed,
self.plc_input_patches:
self.input_image
})

def _schedule_batch(self, client_ids, fetches):
self._schedule_batch_on_feed(client_ids, fetches, self._curr_infeed)

# Distribute batches across available TPUs in a round-robin fashion.
self._curr_infeed += 1
self._curr_infeed %= len(self.tpu_infeed_enqueue_ops)

def _initialize_model(self):
self.tpu_initialize_system = tf.tpu.initialize_system()
self.tpu_shutdown_system = tf.tpu.shutdown_system()

# This returns the global_tpu_id, which we don't use yet.
self.session.run(self.tpu_initialize_system)
self.client_ids = tf.placeholder(dtype=tf.int32, shape=(self.batch_size,))

# Define infeeds.
infeed_placeholder_attrs = ('input_seed', 'input_patches')
infeed_placeholders = [self.client_ids] + [
getattr(self.model, attr) for attr in infeed_placeholder_attrs
]
tpu_infeed_queue = tf.contrib.tpu.InfeedQueue(
tuple_types=[t.dtype for t in infeed_placeholders],
tuple_shapes=[t.shape for t in infeed_placeholders],
)
num_tpus = sum(
dev.device_type == 'TPU' for dev in self.session.list_devices())
tpu_infeed_queue.set_number_of_shards(num_tpus)
logging.info('Found %d TPU cores.', num_tpus)

self.num_tpus = num_tpus
self.tpu_infeed_enqueue_ops = tpu_infeed_queue.generate_enqueue_ops(
[infeed_placeholders] * num_tpus, lambda x: x)

# Save the placeholders that we will feed on the host. These will
# be replaced by infeed dequeue ops below.
self.plc_input_seed = self.model.input_seed
self.plc_input_patches = self.model.input_patches

def loop_body(not_done):
"""Defines the graph that executes in a loop on the TPU."""
del not_done
inputs = tpu_infeed_queue.generate_dequeue_op()
client_ids = inputs[0]
inputs = inputs[1:]
for attr, iteration_input in zip(infeed_placeholder_attrs, inputs):
setattr(self.model, attr, iteration_input)

# Define the graph for the FFN model.
self.model.define_tf_graph()

# Flat shape (or at least removal of the channel dimension) is necessary
# for efficient outfeed as of Oct 2020).
self.flat_logits = tf.reshape(self.model.logits, [-1])
tpu_outfeed_enqueue_op = tf.contrib.tpu.outfeed_enqueue_tuple(
[client_ids, self.flat_logits])

with tf.control_dependencies([tpu_outfeed_enqueue_op]):
return tf.greater(client_ids[0:1], TPUExecutor.CLIENT_ID_SHUTDOWN)

def loop_condition(not_done):
return tf.identity(not_done[0], name='not_done_reduce')

# Note: this executes loop_body.
self.tpu_loop_ops = tf.tpu.replicate(
lambda: tf.while_loop( # pylint:disable=g-long-lambda
cond=loop_condition,
body=loop_body,
loop_vars=[tf.constant(True, dtype=tf.bool, shape=(1,))],
parallel_iterations=1),
inputs=[[]] * num_tpus)

self.tpu_outfeed_results = []
for i in range(num_tpus):
with tf.device('/device:TPU:%d' % i):
self.tpu_outfeed_results.append(
tf.contrib.tpu.outfeed_dequeue_tuple(
dtypes=[tf.int32, self.model.logits.dtype],
shapes=[self.client_ids.shape, self.flat_logits.shape]))


class JAXExecutor(ThreadingBatchExecutor):
"""ThreadingBatchExecutor for JAX models."""

Expand Down

0 comments on commit 99ed8e3

Please sign in to comment.