-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a demo notebook showing JAX FFN inference on LICONN data.
PiperOrigin-RevId: 712831781
- Loading branch information
1 parent
03956b2
commit 45aee1d
Showing
1 changed file
with
248 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"gpuType": "T4" | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
}, | ||
"accelerator": "GPU" | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# JAX FFN inference on LICONN data\n" | ||
], | ||
"metadata": { | ||
"id": "Y-LMa0_zRyiE" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Install the latest snapshot from the FFN repository.\n", | ||
"!pip install git+https://github.com/google/ffn" | ||
], | ||
"metadata": { | ||
"id": "UO9ixXAN7Hw-" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"import os\n", | ||
"os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n", | ||
"\n", | ||
"# Ensure tensorstore does not attempt to use GCE credentials\n", | ||
"os.environ['GCE_METADATA_ROOT'] = 'metadata.google.internal.invalid'\n", | ||
"\n", | ||
"import tensorflow as tf\n", | ||
"tf.config.set_visible_devices([], 'GPU')" | ||
], | ||
"metadata": { | ||
"id": "P2BH-ACTDPgs" | ||
}, | ||
"execution_count": 2, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from clu import checkpoint\n", | ||
"from connectomics.jax.models import convstack\n", | ||
"import jax\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"import tensorstore as ts\n", | ||
"\n", | ||
"from ffn.inference import inference\n", | ||
"from ffn.inference import inference_utils\n", | ||
"from ffn.inference import inference_pb2\n", | ||
"from ffn.inference import executor\n", | ||
"from ffn.training import model as ffn_model" | ||
], | ||
"metadata": { | ||
"id": "2j8v1nH_G9jh" | ||
}, | ||
"execution_count": 3, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Check for GPU presence. If this fails, use \"Runtime > Change runtime type\".\n", | ||
"assert jax.devices()[0].platform in ('gpu', 'tpu')" | ||
], | ||
"metadata": { | ||
"id": "hLbhWzo1HNjW" | ||
}, | ||
"execution_count": 16, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Use sample LICONN data (2x downsampled).\n", | ||
"context = ts.Context({'cache_pool': {'total_bytes_limit': 1_000_000_000}})\n", | ||
"img = ts.open({\n", | ||
" 'driver': 'neuroglancer_precomputed',\n", | ||
" 'kvstore': {'driver': 'gcs', 'bucket': 'liconn-public'},\n", | ||
" 'path': 'ExPID82_1/image_230130b',\n", | ||
" 'scale_index': 1},\n", | ||
" read=True, context=context).result()[ts.d['channel'][0]]" | ||
], | ||
"metadata": { | ||
"id": "Odkxn5nyMNMA" | ||
}, | ||
"execution_count": 17, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Load a 500^3 subvolume for local processing.\n", | ||
"x0, y0, z0 = 1100, 1083, 111\n", | ||
"raw = img[x0:x0+500, y0:y0+500, z0:z0+500].read().result()\n", | ||
"raw = np.transpose(raw, [2, 1, 0])\n", | ||
"raw = (raw.astype(np.float32) - 128.0) / 33. # normalize data for inference" | ||
], | ||
"metadata": { | ||
"id": "dosI88JzMiR9" | ||
}, | ||
"execution_count": 34, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"plt.matshow(raw[250, :, :], cmap=plt.cm.Greys_r)" | ||
], | ||
"metadata": { | ||
"id": "qa-JTB6UMjte" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Load sample model checkpoint.\n", | ||
"!gsutil cp gs://liconn-public/models/ffn/axde_59110972/ckpt-2116* .\n", | ||
"\n", | ||
"ckpt = checkpoint.Checkpoint('')\n", | ||
"state = ckpt.load_state(state=None, checkpoint='ckpt-2116')" | ||
], | ||
"metadata": { | ||
"id": "IlVzzzb0SYY2" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Instantiate model for inference.\n", | ||
"model = convstack.ResConvStack(convstack.ConvstackConfig(depth=20, padding='same', use_layernorm=True))\n", | ||
"fov_size = 33, 33, 33\n", | ||
"model_info = ffn_model.ModelInfo(deltas=(8, 8, 8), pred_mask_size=fov_size, input_seed_size=fov_size, input_image_size=fov_size)\n", | ||
"\n", | ||
"@jax.jit\n", | ||
"def _apply_fn(data):\n", | ||
" return model.apply({'params': state['params']}, data)\n", | ||
"\n", | ||
"iface = executor.ExecutorInterface()\n", | ||
"counters = inference_utils.Counters()\n", | ||
"exc = executor.JAXExecutor(iface, model_info, _apply_fn, counters, 1)\n", | ||
"exc.start_server()" | ||
], | ||
"metadata": { | ||
"id": "wugXTgtyS1aB" | ||
}, | ||
"execution_count": 21, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"options = inference_pb2.InferenceOptions(init_activation=0.95, pad_value=0.5, move_threshold=0.6, segment_threshold=0.6)\n", | ||
"cv = inference.Canvas(model_info, exc.get_client(counters), raw, options, voxel_size_zyx=(24, 18, 18))" | ||
], | ||
"metadata": { | ||
"id": "dF3VZma3HSRq" | ||
}, | ||
"execution_count": 39, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Trace a single neurite.\n", | ||
"pos_xyz = (123, 171, 225)\n", | ||
"cv.segment_at(pos_xyz[::-1], dynamic_image=inference.DynamicImage(), vis_update_every=10)" | ||
], | ||
"metadata": { | ||
"id": "K9Mmd-k01L0V" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!pip install neuroglancer" | ||
], | ||
"metadata": { | ||
"id": "mGVmRSSY1PAO" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Visualize results in neuroglancer.\n", | ||
"import neuroglancer\n", | ||
"from scipy.special import expit\n", | ||
"from scipy.ndimage import label\n", | ||
"seg = (label(cv.seed > 0)[0] == 1).astype(np.uint64)" | ||
], | ||
"metadata": { | ||
"id": "VkJx8BAzZApX" | ||
}, | ||
"execution_count": 48, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"dimensions = neuroglancer.CoordinateSpace(\n", | ||
" names=['x', 'y', 'z'],\n", | ||
" units='nm',\n", | ||
" scales=[18, 18, 24],\n", | ||
")\n", | ||
"viewer = neuroglancer.Viewer()\n", | ||
"with viewer.txn() as s:\n", | ||
" s.dimensions = dimensions\n", | ||
" s.layers['raw'] = neuroglancer.ImageLayer(source=neuroglancer.LocalVolume(np.transpose((raw * 33 +128).astype(np.uint8), [2, 1, 0]), dimensions))\n", | ||
" s.layers['seg'] = neuroglancer.SegmentationLayer(source=neuroglancer.LocalVolume(np.transpose(seg.astype(np.uint64), [2, 1, 0]), dimensions), segments=[1])\n", | ||
"\n", | ||
"viewer" | ||
], | ||
"metadata": { | ||
"id": "tIWzzhnQHVGv" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |