Skip to content

Commit

Permalink
Add a demo notebook showing JAX FFN inference on LICONN data.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712831781
  • Loading branch information
mjanusz authored and copybara-github committed Jan 7, 2025
1 parent 03956b2 commit 45aee1d
Showing 1 changed file with 248 additions and 0 deletions.
248 changes: 248 additions & 0 deletions notebooks/jax_ffn_inference_liconn.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# JAX FFN inference on LICONN data\n"
],
"metadata": {
"id": "Y-LMa0_zRyiE"
}
},
{
"cell_type": "code",
"source": [
"# Install the latest snapshot from the FFN repository.\n",
"!pip install git+https://github.com/google/ffn"
],
"metadata": {
"id": "UO9ixXAN7Hw-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
"\n",
"# Ensure tensorstore does not attempt to use GCE credentials\n",
"os.environ['GCE_METADATA_ROOT'] = 'metadata.google.internal.invalid'\n",
"\n",
"import tensorflow as tf\n",
"tf.config.set_visible_devices([], 'GPU')"
],
"metadata": {
"id": "P2BH-ACTDPgs"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from clu import checkpoint\n",
"from connectomics.jax.models import convstack\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import tensorstore as ts\n",
"\n",
"from ffn.inference import inference\n",
"from ffn.inference import inference_utils\n",
"from ffn.inference import inference_pb2\n",
"from ffn.inference import executor\n",
"from ffn.training import model as ffn_model"
],
"metadata": {
"id": "2j8v1nH_G9jh"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Check for GPU presence. If this fails, use \"Runtime > Change runtime type\".\n",
"assert jax.devices()[0].platform in ('gpu', 'tpu')"
],
"metadata": {
"id": "hLbhWzo1HNjW"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Use sample LICONN data (2x downsampled).\n",
"context = ts.Context({'cache_pool': {'total_bytes_limit': 1_000_000_000}})\n",
"img = ts.open({\n",
" 'driver': 'neuroglancer_precomputed',\n",
" 'kvstore': {'driver': 'gcs', 'bucket': 'liconn-public'},\n",
" 'path': 'ExPID82_1/image_230130b',\n",
" 'scale_index': 1},\n",
" read=True, context=context).result()[ts.d['channel'][0]]"
],
"metadata": {
"id": "Odkxn5nyMNMA"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load a 500^3 subvolume for local processing.\n",
"x0, y0, z0 = 1100, 1083, 111\n",
"raw = img[x0:x0+500, y0:y0+500, z0:z0+500].read().result()\n",
"raw = np.transpose(raw, [2, 1, 0])\n",
"raw = (raw.astype(np.float32) - 128.0) / 33. # normalize data for inference"
],
"metadata": {
"id": "dosI88JzMiR9"
},
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"source": [
"plt.matshow(raw[250, :, :], cmap=plt.cm.Greys_r)"
],
"metadata": {
"id": "qa-JTB6UMjte"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load sample model checkpoint.\n",
"!gsutil cp gs://liconn-public/models/ffn/axde_59110972/ckpt-2116* .\n",
"\n",
"ckpt = checkpoint.Checkpoint('')\n",
"state = ckpt.load_state(state=None, checkpoint='ckpt-2116')"
],
"metadata": {
"id": "IlVzzzb0SYY2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Instantiate model for inference.\n",
"model = convstack.ResConvStack(convstack.ConvstackConfig(depth=20, padding='same', use_layernorm=True))\n",
"fov_size = 33, 33, 33\n",
"model_info = ffn_model.ModelInfo(deltas=(8, 8, 8), pred_mask_size=fov_size, input_seed_size=fov_size, input_image_size=fov_size)\n",
"\n",
"@jax.jit\n",
"def _apply_fn(data):\n",
" return model.apply({'params': state['params']}, data)\n",
"\n",
"iface = executor.ExecutorInterface()\n",
"counters = inference_utils.Counters()\n",
"exc = executor.JAXExecutor(iface, model_info, _apply_fn, counters, 1)\n",
"exc.start_server()"
],
"metadata": {
"id": "wugXTgtyS1aB"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"source": [
"options = inference_pb2.InferenceOptions(init_activation=0.95, pad_value=0.5, move_threshold=0.6, segment_threshold=0.6)\n",
"cv = inference.Canvas(model_info, exc.get_client(counters), raw, options, voxel_size_zyx=(24, 18, 18))"
],
"metadata": {
"id": "dF3VZma3HSRq"
},
"execution_count": 39,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Trace a single neurite.\n",
"pos_xyz = (123, 171, 225)\n",
"cv.segment_at(pos_xyz[::-1], dynamic_image=inference.DynamicImage(), vis_update_every=10)"
],
"metadata": {
"id": "K9Mmd-k01L0V"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install neuroglancer"
],
"metadata": {
"id": "mGVmRSSY1PAO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Visualize results in neuroglancer.\n",
"import neuroglancer\n",
"from scipy.special import expit\n",
"from scipy.ndimage import label\n",
"seg = (label(cv.seed > 0)[0] == 1).astype(np.uint64)"
],
"metadata": {
"id": "VkJx8BAzZApX"
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dimensions = neuroglancer.CoordinateSpace(\n",
" names=['x', 'y', 'z'],\n",
" units='nm',\n",
" scales=[18, 18, 24],\n",
")\n",
"viewer = neuroglancer.Viewer()\n",
"with viewer.txn() as s:\n",
" s.dimensions = dimensions\n",
" s.layers['raw'] = neuroglancer.ImageLayer(source=neuroglancer.LocalVolume(np.transpose((raw * 33 +128).astype(np.uint8), [2, 1, 0]), dimensions))\n",
" s.layers['seg'] = neuroglancer.SegmentationLayer(source=neuroglancer.LocalVolume(np.transpose(seg.astype(np.uint64), [2, 1, 0]), dimensions), segments=[1])\n",
"\n",
"viewer"
],
"metadata": {
"id": "tIWzzhnQHVGv"
},
"execution_count": null,
"outputs": []
}
]
}

0 comments on commit 45aee1d

Please sign in to comment.