From 4bfc8e83346f1b039578db7e2a4a897d8f9d364a Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Tue, 17 Dec 2024 03:55:49 -0800 Subject: [PATCH] Make 3d random matrix generation compatible with tf.vector_map. PiperOrigin-RevId: 707032926 --- ffn/training/augmentation.py | 440 +++++++++++++++++++++++------------ 1 file changed, 285 insertions(+), 155 deletions(-) diff --git a/ffn/training/augmentation.py b/ffn/training/augmentation.py index 7c8087d..77250be 100644 --- a/ffn/training/augmentation.py +++ b/ffn/training/augmentation.py @@ -23,6 +23,7 @@ from skimage.transform import AffineTransform from skimage.transform import warp import tensorflow.compat.v1 as tf +import tensorflow.experimental.numpy as tnp import tensorflow.google.compat.v1 as tf from tf import transformations @@ -30,7 +31,8 @@ def standard_rotation_matrix( - mask: tf.Tensor, voxel_size: tuple[float, float, float]) -> tf.Operation: + mask: tf.Tensor, voxel_size: tuple[float, float, float] +) -> tf.Operation: """Computes a rotation matrix to put an object into a standard orientation. In the standard orientation, the axis of highest variance is 'z', and the @@ -51,40 +53,87 @@ def _compute_rot_mtx(mask, voxel_size=voxel_size): return mtx.astype(np.float32) ret = tf.py_func( - func=_compute_rot_mtx, inp=[mask], Tout=tf.float32, name='std_rot_mtx') + func=_compute_rot_mtx, inp=[mask], Tout=tf.float32, name='std_rot_mtx' + ) ret.set_shape([3, 3]) return ret +def random_quaternion(uniform_variate: tf.Tensor) -> tf.Operation: + """Returns a uniform random unit quaternion. + + TF version of transformations.random_quaternion. + + Args: + uniform_variate: float tensor with variates from U[0, 1]; shape [3] + + Returns: + 4-element tensor representing the random quaternion + """ + r1 = tf.math.sqrt(1.0 - uniform_variate[0]) + r2 = tf.math.sqrt(uniform_variate[0]) + pi2 = np.pi * 2.0 + t1 = pi2 * uniform_variate[1] + t2 = pi2 * uniform_variate[2] + return tf.stack(( + tf.math.sin(t1) * r1, + tf.math.cos(t1) * r1, + tf.math.sin(t2) * r2, + tf.math.cos(t2) * r2, + )) + + +@tf.function def random_3d_rotation_matrix( - uniform_variate: Optional[tf.Tensor] = None) -> tf.Operation: + uniform_variate: tf.Tensor | None = None, +) -> tf.Operation: """Computes a random 3d rotation matrix. + TF version of transformations.random_rotation_matrix, modified to compute + a 3x3 matrix. + Args: uniform_variate: optional float tensor with variates from U[0, 1]; shape [3] Returns: 3x3 tensor representing the random rotation matrix """ - - def _random_3d_rot_mtx(var): - return transformations.random_rotation_matrix(var).astype( - np.float32)[:3, :3] - if uniform_variate is None: uniform_variate = tf.random.uniform([3], 0, 1) - ret = tf.py_func( - func=_random_3d_rot_mtx, - inp=[uniform_variate], - Tout=tf.float32, - name='rand_3d_rot_mtx') - ret.set_shape([3, 3]) - return ret + q = random_quaternion(uniform_variate) + nq = tnp.dot(q, q) + if nq < transformations._EPS: # pylint:disable=protected-access + return tf.eye(3) + + q = q * tf.math.sqrt(2.0 / nq) + q = tnp.outer(q, q) + + return tf.stack(( + tf.cast( + tf.stack( + (1.0 - q[1, 1] - q[2, 2], q[0, 1] - q[2, 3], q[0, 2] + q[1, 3]) + ), + tf.float32, + ), + tf.cast( + tf.stack( + (q[0, 1] + q[2, 3], 1.0 - q[0, 0] - q[2, 2], q[1, 2] - q[0, 3]) + ), + tf.float32, + ), + tf.cast( + tf.stack( + (q[0, 2] - q[1, 3], q[1, 2] + q[0, 3], 1.0 - q[0, 0] - q[1, 1]) + ), + tf.float32, + ), + )) def random_2d_rotation_matrix( - uniform_variate: Optional[tf.Tensor] = None) -> tf.Operation: + uniform_variate: Optional[tf.Tensor] = None, +) -> tf.Operation: """Computes a matrix for a random rotation around the 'z' axis. Args: @@ -97,8 +146,9 @@ def random_2d_rotation_matrix( def _random_2d_rot_mtx(var): angle = var * 2 * np.pi - return transformations.rotation_matrix(angle, - [0, 0, 1]).astype(np.float32)[:3, :3] + return transformations.rotation_matrix(angle, [0, 0, 1]).astype(np.float32)[ + :3, :3 + ] if uniform_variate is None: uniform_variate = tf.random.uniform([1], 0, 1) @@ -107,7 +157,8 @@ def _random_2d_rot_mtx(var): func=_random_2d_rot_mtx, inp=[uniform_variate], Tout=tf.float32, - name='rand_2d_rot_mtx') + name='rand_2d_rot_mtx', + ) ret.set_shape([3, 3]) return ret @@ -115,7 +166,8 @@ def _random_2d_rot_mtx(var): def input_size_for_rotated_output( desired_size: tuple[int, int, int], in_voxel_size: tuple[float, float, float], - out_voxel_size: Optional[tuple[float, float, float]] = None) -> list[int]: + out_voxel_size: Optional[tuple[float, float, float]] = None, +) -> list[int]: """Computes the input size necessary for a given output size. The input size is computed to be large enough so that if an arbitrary @@ -137,11 +189,13 @@ def input_size_for_rotated_output( return np.ceil(2.0 * phys_r / in_voxel_size).astype(int).tolist() -def apply_rotation(data: tf.Tensor, - rotation_matrix: tf.Tensor, - in_voxel_size: tuple[float, float, float], - out_voxel_size: Optional[tuple[float, float, float]] = None, - interpolation='nearest') -> tf.Operation: +def apply_rotation( + data: tf.Tensor, + rotation_matrix: tf.Tensor, + in_voxel_size: tuple[float, float, float], + out_voxel_size: Optional[tuple[float, float, float]] = None, + interpolation='nearest', +) -> tf.Operation: """Applies a rotation to a tensor of volumetric data. Args: @@ -178,7 +232,8 @@ def apply_rotation(data: tf.Tensor, tf.range(0, out_diam_vx[2]), tf.range(0, out_diam_vx[1]), tf.range(0, out_diam_vx[0]), - indexing='ij') + indexing='ij', + ) # Convert back to physical coordinates. Shift by half a voxel since the grid # coordinates are assumed to correspond to the center of the voxel. @@ -202,10 +257,8 @@ def apply_rotation(data: tf.Tensor, # M^T.new = old // left multiply and use orthogonality of M # new^T.M = old^T // transpose points = tf.stack( - [tf.reshape(hx, [-1]), - tf.reshape(hy, [-1]), - tf.reshape(hz, [-1])], - axis=1) + [tf.reshape(hx, [-1]), tf.reshape(hy, [-1]), tf.reshape(hz, [-1])], axis=1 + ) phys_coords = tf.matmul(points, rotation_matrix) # -0.5 because the origin of the physical coordinate system is a half a voxel @@ -222,7 +275,8 @@ def apply_rotation(data: tf.Tensor, data[0, ...], orig_coords, padding_constant=[], - interpolation=interpolation) + interpolation=interpolation, + ) rotated = tf.cast(rotated, data.dtype) return tf.reshape(rotated, [1] + rotated.shape.as_list()) @@ -232,8 +286,8 @@ def reflection(data, decision): Args: data: input tensor, shape: [..], z, y, x, c - decision: boolean tensor, shape 3, indicating on which spatial dimensions - to apply the reflection (x, y, z) + decision: boolean tensor, shape 3, indicating on which spatial dimensions to + apply the reflection (x, y, z) Returns: TF op to conditionally apply reflection. @@ -259,9 +313,7 @@ def xy_transpose(data, decision): rank = data.get_shape().ndims perm = list(range(rank)) perm[rank - 3], perm[rank - 2] = perm[rank - 2], perm[rank - 3] - return tf.cond(decision, - lambda: tf.transpose(data, perm), - lambda: data) + return tf.cond(decision, lambda: tf.transpose(data, perm), lambda: data) def permute_axes(x, permutation, permutable_axes): @@ -323,9 +375,7 @@ def random_contrast_brightness_adjustment( contrast_factor = tf.random.uniform([], min_contrast, max_contrast) adjust_tensor = tf.image.adjust_contrast(adjust_tensor, contrast_factor) if brightness_factor_range: - min_delta_factor, max_delta_factor = ( - brightness_factor_range - ) + min_delta_factor, max_delta_factor = brightness_factor_range delta_factor = tf.random.uniform([], min_delta_factor, max_delta_factor) adjust_tensor = tf.image.adjust_brightness( adjust_tensor, delta=delta_factor @@ -364,8 +414,14 @@ class PermuteAndReflect: `permutable_axes` are identity mapped. """ - def __init__(self, rank, permutable_axes, reflectable_axes, - permutation_seed=None, reflection_seed=None): + def __init__( + self, + rank, + permutable_axes, + reflectable_axes, + permutation_seed=None, + reflection_seed=None, + ): """Initializes the transformation nodes. Args: @@ -376,6 +432,7 @@ def __init__(self, rank, permutable_axes, reflectable_axes, permutation. reflection_seed: Optional integer. Seed value to use for sampling reflection decisions. + Raises: ValueError: if arguments are invalid. """ @@ -393,14 +450,18 @@ def __init__(self, rank, permutable_axes, reflectable_axes, self.reflectable_axes = np.array(reflectable_axes, dtype=np.int32) if self.reflectable_axes.size > 0: - self.reflect_decisions = tf.random_uniform([len(self.reflectable_axes)], - seed=reflection_seed) > 0.5 - self.reflected_axes = tf.boolean_mask(self.reflectable_axes, - self.reflect_decisions) + self.reflect_decisions = ( + tf.random_uniform([len(self.reflectable_axes)], seed=reflection_seed) + > 0.5 + ) + self.reflected_axes = tf.boolean_mask( + self.reflectable_axes, self.reflect_decisions + ) if self.permutable_axes.size > 0: - self.permutation = tf.random_shuffle(self.permutable_axes, - seed=permutation_seed) + self.permutation = tf.random_shuffle( + self.permutable_axes, seed=permutation_seed + ) # full_permutation must be a list rather than an np.array of int32 because # some elements are set to be tensors below. full_permutation = [np.int32(x) for x in range(rank)] @@ -427,8 +488,9 @@ def __call__(self, x): return x -def warp_transform_size_factor(deformation_stdev_ratio, rotation_max, scale_max, - shear_max): +def warp_transform_size_factor( + deformation_stdev_ratio, rotation_max, scale_max, shear_max +): """Estimates max patch size factor for affine transform and warping. Uses linear estimations of extra data needed for rotation and @@ -455,13 +517,13 @@ def warp_transform_size_factor(deformation_stdev_ratio, rotation_max, scale_max, rotation_factor = min(np.pi / 4, rotation_max) / (np.pi / 4) shear_factor = min(np.pi, shear_max) / np.pi return 1 + ( - deformation_stdev_ratio + rotation_factor + scale_max + shear_factor) + deformation_stdev_ratio + rotation_factor + scale_max + shear_factor + ) -def _elastic_warp_2d(patch, - num_control_points_ratio, - deformation_stdev_ratio, - mode='reflect'): +def _elastic_warp_2d( + patch, num_control_points_ratio, deformation_stdev_ratio, mode='reflect' +): """Applies 2D elastic deformation to all y,x slices of patch. The same deformation is applied separately at each pair of @@ -488,24 +550,22 @@ def _elastic_warp_2d(patch, deformation_stdev = deformation_stdev_ratio * np.min(patch.shape) deformations = np.random.normal(0, deformation_stdev, coords.shape) deformed_coords = coords + deformations - grid_y, grid_x = np.mgrid[0:patch.shape[1], 0:patch.shape[2]] + grid_y, grid_x = np.mgrid[0 : patch.shape[1], 0 : patch.shape[2]] grid = griddata( - coords, deformed_coords, (grid_y, grid_x), method='cubic', fill_value=0) + coords, deformed_coords, (grid_y, grid_x), method='cubic', fill_value=0 + ) warped_patch = np.zeros(patch.shape, dtype=patch.dtype) for b in range(patch.shape[0]): for c in range(patch.shape[3]): warped_patch[b, :, :, c] = warp( - patch[b, :, :, c], - np.array((grid[:, :, 0], grid[:, :, 1])), - mode=mode) + patch[b, :, :, c], np.array((grid[:, :, 0], grid[:, :, 1])), mode=mode + ) return warped_patch -def _affine_transform_2d(patch, - rotation_max, - scale_max, - shear_max, - mode='reflect'): +def _affine_transform_2d( + patch, rotation_max, scale_max, shear_max, mode='reflect' +): """Applies 2D affine transformation to all y,x slices of patch. The same transform is applied separately at each pair of @@ -558,12 +618,14 @@ def _apply_at_random_z_indices(patch, fn, max_indices_ratio): return patch, z_indices -def elastic_warp(patch, - max_indices_ratio, - num_control_points_ratio, - deformation_stdev_ratio, - skip_ratio=0, - mode='reflect'): +def elastic_warp( + patch, + max_indices_ratio, + num_control_points_ratio, + deformation_stdev_ratio, + skip_ratio=0, + mode='reflect', +): """Applies elastic deformation to selected z indices of patch. Args: @@ -590,18 +652,21 @@ def elastic_warp(patch, def warp_function(p): return _elastic_warp_2d( - p, num_control_points_ratio, deformation_stdev_ratio, mode=mode) + p, num_control_points_ratio, deformation_stdev_ratio, mode=mode + ) return _apply_at_random_z_indices(patch, warp_function, max_indices_ratio) -def affine_transform(patch, - max_indices_ratio, - rotation_max, - scale_max, - shear_max, - skip_ratio=0, - mode='reflect'): +def affine_transform( + patch, + max_indices_ratio, + rotation_max, + scale_max, + shear_max, + skip_ratio=0, + mode='reflect', +): """Applies affine transform to selected z indices of patch. Args: @@ -628,10 +693,12 @@ def affine_transform(patch, def transform_function(p): return _affine_transform_2d( - p, rotation_max, scale_max, shear_max, mode=mode) + p, rotation_max, scale_max, shear_max, mode=mode + ) - return _apply_at_random_z_indices(patch, transform_function, - max_indices_ratio) + return _apply_at_random_z_indices( + patch, transform_function, max_indices_ratio + ) def _center_crop(patch, zyx_cropped_shape): @@ -648,7 +715,7 @@ def _center_crop(patch, zyx_cropped_shape): assert np.all(diff >= 0) start = diff // 2 end = patch.shape[1:-1] - np.ceil(diff / 2.0).astype(int) - return patch[:, start[0]:end[0], start[1]:end[1], start[2]:end[2], :] + return patch[:, start[0] : end[0], start[1] : end[1], start[2] : end[2], :] def _edge_pad(patch, zyx_padded_shape, mode='edge'): @@ -669,15 +736,17 @@ def _edge_pad(patch, zyx_padded_shape, mode='edge'): return np.pad(patch, pad, mode) -def misalignment(patch, - labels, - mask, - patch_final_zyx, - labels_final_zyx, - mask_final_zyx, - max_offset, - slip_ratio, - skip_ratio=0): +def misalignment( + patch, + labels, + mask, + patch_final_zyx, + labels_final_zyx, + mask_final_zyx, + max_offset, + slip_ratio, + skip_ratio=0, +): """Performs slip and translation misalignment augmentations. Patch, labels, and mask inputs are first edge padded to the same size @@ -706,16 +775,20 @@ def misalignment(patch, """ patch, labels, mask = patch.copy(), labels.copy(), mask.copy() if np.random.rand() < skip_ratio: - return (_center_crop(patch, patch_final_zyx), - _center_crop(labels, labels_final_zyx), - _center_crop(mask, mask_final_zyx), -1) + return ( + _center_crop(patch, patch_final_zyx), + _center_crop(labels, labels_final_zyx), + _center_crop(mask, mask_final_zyx), + -1, + ) - zyx_max_shape = np.array([patch.shape, labels.shape, - mask.shape]).max(axis=0)[1:-1] + zyx_max_shape = np.array([patch.shape, labels.shape, mask.shape]).max(axis=0)[ + 1:-1 + ] padded_data = [ _edge_pad(patch, zyx_max_shape), _edge_pad(labels, zyx_max_shape), - _edge_pad(mask, zyx_max_shape) + _edge_pad(mask, zyx_max_shape), ] offset_y, offset_x = np.random.randint(-max_offset, max_offset + 1, 2) @@ -744,8 +817,8 @@ def _quadrant_replace(patch, z, replacement, quadrant_prob): Args: patch: input 5D numpy array, [b, z, y, x, c] patch is modified in place z: z index on which to replace x,y quadrants - replacement: 4D numpy array containing replacement values, [b, y, x, c] - same shape as patch[:, z, :, :, :] + replacement: 4D numpy array containing replacement values, [b, y, x, c] same + shape as patch[:, z, :, :, :] quadrant_prob: probability that values in each quadrant are replaced """ apply_quadrants = np.random.rand(4) < quadrant_prob @@ -761,13 +834,15 @@ def _quadrant_replace(patch, z, replacement, quadrant_prob): patch[:, z, y:, x:, :] = replacement[:, y:, x:, :] -def missing_section(patch, - max_indices_ratio, - skip_ratio=0, - fill_value=None, - max_fill_val=256, - full_prob=0.5, - quadrant_prob=0.5): +def missing_section( + patch, + max_indices_ratio, + skip_ratio=0, + fill_value=None, + max_fill_val=256, + full_prob=0.5, + quadrant_prob=0.5, +): """Performs missing section augmentation. All values in randomly selected x,y quadrants of randomly @@ -798,7 +873,8 @@ def missing_section(patch, num_indices = np.random.randint(1, max_indices + 1) z_indices = np.random.choice(patch.shape[1], num_indices, replace=False) fill_val = ( - fill_value if fill_value is not None else np.random.rand() * max_fill_val) + fill_value if fill_value is not None else np.random.rand() * max_fill_val + ) fill_array = np.full(patch[:, 0, :, :, :].shape, fill_val, patch.dtype) for z in z_indices: if np.random.rand() < full_prob: @@ -808,12 +884,14 @@ def missing_section(patch, return patch, z_indices -def out_of_focus_section(patch, - max_indices_ratio, - max_filter_stdev, - skip_ratio=0, - full_prob=0.5, - quadrant_prob=0.5): +def out_of_focus_section( + patch, + max_indices_ratio, + max_filter_stdev, + skip_ratio=0, + full_prob=0.5, + quadrant_prob=0.5, +): """Applies out-of-focus-section augmentation. A Gaussian blur is applied to all values in randomly selected x,y @@ -851,12 +929,14 @@ def out_of_focus_section(patch, return patch, z_indices -def grayscale_perturb(patch, - max_contrast_factor, - max_brightness_factor, - skip_ratio=0, - max_val=255, - full_prob=0.5): +def grayscale_perturb( + patch, + max_contrast_factor, + max_brightness_factor, + skip_ratio=0, + max_val=255, + full_prob=0.5, +): """Applies brightness/contrast adjustment and gamma correction. Grayscale perturbation factors are chosen once for the @@ -892,10 +972,10 @@ def grayscale_perturb(patch, def perturb_fn(patch): contrast_factor = 1 + (np.random.rand() - 0.5) * max_contrast_factor brightness_factor = (np.random.rand() - 0.5) * max_brightness_factor - power = 2.0**(np.random.rand() * 2 - 1) + power = 2.0 ** (np.random.rand() * 2 - 1) normalized = patch.astype(np.float32) / max_val adjusted = normalized * contrast_factor + brightness_factor - gamma = np.clip(adjusted, 0, 1)**power + gamma = np.clip(adjusted, 0, 1) ** power rescaled = (gamma * max_val).astype(patch.dtype) return rescaled @@ -908,14 +988,33 @@ def perturb_fn(patch): def apply_section_augmentations( - patch, labels, mask, patch_final_zyx, labels_final_zyx, mask_final_zyx, - elastic_warp_skip_ratio, affine_transform_skip_ratio, - misalignment_skip_ratio, missing_section_skip_ratio, outoffocus_skip_ratio, - grayscale_skip_ratio, max_warp_indices_ratio, num_control_points_ratio, - deformation_stdev_ratio, max_affine_transform_indices_ratio, rotation_max, - scale_max, shear_max, max_xy_offset, slip_vs_translate_ratio, - max_missing_section_indices_ratio, max_outoffocus_indices_ratio, - max_filter_stdev, max_contrast_factor, max_brightness_factor): + patch, + labels, + mask, + patch_final_zyx, + labels_final_zyx, + mask_final_zyx, + elastic_warp_skip_ratio, + affine_transform_skip_ratio, + misalignment_skip_ratio, + missing_section_skip_ratio, + outoffocus_skip_ratio, + grayscale_skip_ratio, + max_warp_indices_ratio, + num_control_points_ratio, + deformation_stdev_ratio, + max_affine_transform_indices_ratio, + rotation_max, + scale_max, + shear_max, + max_xy_offset, + slip_vs_translate_ratio, + max_missing_section_indices_ratio, + max_outoffocus_indices_ratio, + max_filter_stdev, + max_contrast_factor, + max_brightness_factor, +): """Performs ssEM training set augmentations. Augmentations performed by this function were designed @@ -968,54 +1067,85 @@ def apply_section_augmentations( """ def elastic_warp_fn(patch): - return elastic_warp(patch, max_warp_indices_ratio, num_control_points_ratio, - deformation_stdev_ratio, elastic_warp_skip_ratio) + return elastic_warp( + patch, + max_warp_indices_ratio, + num_control_points_ratio, + deformation_stdev_ratio, + elastic_warp_skip_ratio, + ) def affine_transform_fn(patch): - return affine_transform(patch, max_affine_transform_indices_ratio, - rotation_max, scale_max, shear_max, - affine_transform_skip_ratio) + return affine_transform( + patch, + max_affine_transform_indices_ratio, + rotation_max, + scale_max, + shear_max, + affine_transform_skip_ratio, + ) def misalignment_fn(patch, labels, mask): - return misalignment(patch, labels, mask, patch_final_zyx, labels_final_zyx, - mask_final_zyx, max_xy_offset, slip_vs_translate_ratio, - misalignment_skip_ratio) + return misalignment( + patch, + labels, + mask, + patch_final_zyx, + labels_final_zyx, + mask_final_zyx, + max_xy_offset, + slip_vs_translate_ratio, + misalignment_skip_ratio, + ) def missing_section_fn(patch): - return missing_section(patch, max_missing_section_indices_ratio, - missing_section_skip_ratio) + return missing_section( + patch, max_missing_section_indices_ratio, missing_section_skip_ratio + ) def outoffocus_section_fn(patch): - return out_of_focus_section(patch, max_outoffocus_indices_ratio, - max_filter_stdev, outoffocus_skip_ratio) + return out_of_focus_section( + patch, + max_outoffocus_indices_ratio, + max_filter_stdev, + outoffocus_skip_ratio, + ) def grayscale_perturb_fn(patch): - return grayscale_perturb(patch, max_contrast_factor, max_brightness_factor, - grayscale_skip_ratio) + return grayscale_perturb( + patch, max_contrast_factor, max_brightness_factor, grayscale_skip_ratio + ) patch_shape = [patch.shape[0]] + list(patch_final_zyx) + [patch.shape[-1]] labels_shape = [labels.shape[0]] + list(labels_final_zyx) + [labels.shape[-1]] mask_shape = [mask.shape[0]] + list(mask_final_zyx) + [mask.shape[-1]] with tf.name_scope('section_augmentations'): - patch, elastic_warp_summary = tf.py_func(elastic_warp_fn, [patch], - [patch.dtype, tf.int64]) + patch, elastic_warp_summary = tf.py_func( + elastic_warp_fn, [patch], [patch.dtype, tf.int64] + ) tf.summary.histogram('elastic_warp_z_indices', elastic_warp_summary) - patch, affine_transform_summary = tf.py_func(affine_transform_fn, [patch], - [patch.dtype, tf.int64]) + patch, affine_transform_summary = tf.py_func( + affine_transform_fn, [patch], [patch.dtype, tf.int64] + ) tf.summary.histogram('affine_transform_z_indices', affine_transform_summary) patch, labels, mask, misalignment_summary = tf.py_func( - misalignment_fn, [patch, labels, mask], - [patch.dtype, labels.dtype, mask.dtype, tf.int64]) + misalignment_fn, + [patch, labels, mask], + [patch.dtype, labels.dtype, mask.dtype, tf.int64], + ) tf.summary.scalar('misalignment_summary', misalignment_summary) - patch, missing_section_summary = tf.py_func(missing_section_fn, [patch], - [patch.dtype, tf.int64]) + patch, missing_section_summary = tf.py_func( + missing_section_fn, [patch], [patch.dtype, tf.int64] + ) tf.summary.histogram('missing_section_z_indices', missing_section_summary) - patch, outoffocus_summary = tf.py_func(outoffocus_section_fn, [patch], - [patch.dtype, tf.int64]) + patch, outoffocus_summary = tf.py_func( + outoffocus_section_fn, [patch], [patch.dtype, tf.int64] + ) tf.summary.histogram('out-of-focus_z_indices', outoffocus_summary) - patch, grayscale_summary = tf.py_func(grayscale_perturb_fn, [patch], - [patch.dtype, tf.int64]) + patch, grayscale_summary = tf.py_func( + grayscale_perturb_fn, [patch], [patch.dtype, tf.int64] + ) tf.summary.scalar('grayscale_applied', grayscale_summary) patch.set_shape(patch_shape)