diff --git a/notebooks/jax_ffn_inference_liconn.ipynb b/notebooks/jax_ffn_inference_liconn.ipynb new file mode 100644 index 0000000..d69d5b2 --- /dev/null +++ b/notebooks/jax_ffn_inference_liconn.ipynb @@ -0,0 +1,248 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# JAX FFN inference on LICONN data\n" + ], + "metadata": { + "id": "Y-LMa0_zRyiE" + } + }, + { + "cell_type": "code", + "source": [ + "# Install the latest snapshot from the FFN repository.\n", + "!pip install git+https://github.com/google/ffn" + ], + "metadata": { + "id": "UO9ixXAN7Hw-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n", + "\n", + "# Ensure tensorstore does not attempt to use GCE credentials\n", + "os.environ['GCE_METADATA_ROOT'] = 'metadata.google.internal.invalid'\n", + "\n", + "import tensorflow as tf\n", + "tf.config.set_visible_devices([], 'GPU')" + ], + "metadata": { + "id": "P2BH-ACTDPgs" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from clu import checkpoint\n", + "from connectomics.jax.models import convstack\n", + "import jax\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tensorstore as ts\n", + "\n", + "from ffn.inference import inference\n", + "from ffn.inference import inference_utils\n", + "from ffn.inference import inference_pb2\n", + "from ffn.inference import executor\n", + "from ffn.training import model as ffn_model" + ], + "metadata": { + "id": "2j8v1nH_G9jh" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Check for GPU presence. If this fails, use \"Runtime > Change runtime type\".\n", + "assert jax.devices()[0].platform in ('gpu', 'tpu')" + ], + "metadata": { + "id": "hLbhWzo1HNjW" + }, + "execution_count": 16, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Use sample LICONN data (2x downsampled).\n", + "context = ts.Context({'cache_pool': {'total_bytes_limit': 1_000_000_000}})\n", + "img = ts.open({\n", + " 'driver': 'neuroglancer_precomputed',\n", + " 'kvstore': {'driver': 'gcs', 'bucket': 'liconn-public'},\n", + " 'path': 'ExPID82_1/image_230130b',\n", + " 'scale_index': 1},\n", + " read=True, context=context).result()[ts.d['channel'][0]]" + ], + "metadata": { + "id": "Odkxn5nyMNMA" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Load a 500^3 subvolume for local processing.\n", + "x0, y0, z0 = 1100, 1083, 111\n", + "raw = img[x0:x0+500, y0:y0+500, z0:z0+500].read().result()\n", + "raw = np.transpose(raw, [2, 1, 0])\n", + "raw = (raw.astype(np.float32) - 128.0) / 33. # normalize data for inference" + ], + "metadata": { + "id": "dosI88JzMiR9" + }, + "execution_count": 34, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "plt.matshow(raw[250, :, :], cmap=plt.cm.Greys_r)" + ], + "metadata": { + "id": "qa-JTB6UMjte" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Load sample model checkpoint.\n", + "!gsutil cp gs://liconn-public/models/ffn/axde_59110972/ckpt-2116* .\n", + "\n", + "ckpt = checkpoint.Checkpoint('')\n", + "state = ckpt.load_state(state=None, checkpoint='ckpt-2116')" + ], + "metadata": { + "id": "IlVzzzb0SYY2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Instantiate model for inference.\n", + "model = convstack.ResConvStack(convstack.ConvstackConfig(depth=20, padding='same', use_layernorm=True))\n", + "fov_size = 33, 33, 33\n", + "model_info = ffn_model.ModelInfo(deltas=(8, 8, 8), pred_mask_size=fov_size, input_seed_size=fov_size, input_image_size=fov_size)\n", + "\n", + "@jax.jit\n", + "def _apply_fn(data):\n", + " return model.apply({'params': state['params']}, data)\n", + "\n", + "iface = executor.ExecutorInterface()\n", + "counters = inference_utils.Counters()\n", + "exc = executor.JAXExecutor(iface, model_info, _apply_fn, counters, 1)\n", + "exc.start_server()" + ], + "metadata": { + "id": "wugXTgtyS1aB" + }, + "execution_count": 21, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "options = inference_pb2.InferenceOptions(init_activation=0.95, pad_value=0.5, move_threshold=0.6, segment_threshold=0.6)\n", + "cv = inference.Canvas(model_info, exc.get_client(counters), raw, options, voxel_size_zyx=(24, 18, 18))" + ], + "metadata": { + "id": "dF3VZma3HSRq" + }, + "execution_count": 39, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Trace a single neurite.\n", + "pos_xyz = (123, 171, 225)\n", + "cv.segment_at(pos_xyz[::-1], dynamic_image=inference.DynamicImage(), vis_update_every=10)" + ], + "metadata": { + "id": "K9Mmd-k01L0V" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install neuroglancer" + ], + "metadata": { + "id": "mGVmRSSY1PAO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Visualize results in neuroglancer.\n", + "import neuroglancer\n", + "from scipy.special import expit\n", + "from scipy.ndimage import label\n", + "seg = (label(cv.seed > 0)[0] == 1).astype(np.uint64)" + ], + "metadata": { + "id": "VkJx8BAzZApX" + }, + "execution_count": 48, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "dimensions = neuroglancer.CoordinateSpace(\n", + " names=['x', 'y', 'z'],\n", + " units='nm',\n", + " scales=[18, 18, 24],\n", + ")\n", + "viewer = neuroglancer.Viewer()\n", + "with viewer.txn() as s:\n", + " s.dimensions = dimensions\n", + " s.layers['raw'] = neuroglancer.ImageLayer(source=neuroglancer.LocalVolume(np.transpose((raw * 33 +128).astype(np.uint8), [2, 1, 0]), dimensions))\n", + " s.layers['seg'] = neuroglancer.SegmentationLayer(source=neuroglancer.LocalVolume(np.transpose(seg.astype(np.uint64), [2, 1, 0]), dimensions), segments=[1])\n", + "\n", + "viewer" + ], + "metadata": { + "id": "tIWzzhnQHVGv" + }, + "execution_count": null, + "outputs": [] + } + ] +}