diff --git a/example.ipynb b/example.ipynb index 2872dcc..ec764f3 100644 --- a/example.ipynb +++ b/example.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30a3c190-5045-4840-b4c2-6d263b7a3178", + "id": "0", "metadata": {}, "outputs": [], "source": [ @@ -22,7 +22,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8f85ebf1-19a9-4bd7-afca-e077d99f9018", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -107,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3fd6807-a5d6-408a-a8a1-fae9dcfa1d5b", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -123,7 +123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "98835a5d-081e-4bee-b2bc-595e01638033", + "id": "3", "metadata": {}, "outputs": [], "source": [ diff --git a/ncolor/label.py b/ncolor/label.py index 469acff..98f55b9 100644 --- a/ncolor/label.py +++ b/ncolor/label.py @@ -4,12 +4,12 @@ from numba import njit import scipy from .format_labels import format_labels, is_sequential -from skimage.segmentation import expand_labels as skimage_expand_labels +# from skimage.segmentation import expand_labels as skimage_expand_labels import edt from scipy.ndimage import distance_transform_edt -def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False): +def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False, greedy=False): # needs to be in standard label form # but also needs to be in int32 data type to work properly; the formatting automatically # puts it into the smallest datatype to save space @@ -26,7 +26,7 @@ def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False): lab = expand_labels(lab) # lab = np.pad(format_labels(lab),pad) lab = format_labels(np.pad(lab,pad),background=0) - lut = get_lut(lab,n,conn,max_depth,offset,return_n) + lut = get_lut(lab,n,conn,max_depth,offset,greedy) nc = lut[lab][unpad]*mask @@ -35,15 +35,22 @@ def label(lab, n=4, conn=2, max_depth=5, offset=0, expand=None, return_n=False): else: return nc -def get_lut(lab, n=4, conn=2, max_depth=5, offset=0, return_n=False): - lab = format_labels(lab).astype(np.int32) +def get_lut(lab, n=4, conn=2, max_depth=5, offset=0, greedy=False): + # lab = format_labels(lab).astype(np.int32) + lab = format_labels(lab).astype(np.int64) + idx = connect(lab, conn) idx = mapidx(idx) - colors = render_net(idx, n=n, rand=10, max_depth=max_depth, offset=offset) + if greedy: + colors = greedy_coloring(idx) + else: + colors = render_net(idx, n=n, rand=10, max_depth=max_depth, offset=offset) + lut = np.ones(lab.max()+1, dtype=np.uint8) for i in colors: lut[i] = colors[i] lut[0] = 0 return lut + def neighbors(shape, conn=1): dim = len(shape) @@ -72,6 +79,7 @@ def search(img, nbs): def connect(img, conn=1): buf = np.pad(img, 1, 'constant') nbs = neighbors(buf.shape, conn) + # rst = search(buf, nbs) rst = search(buf, nbs) if len(rst)<2: return rst @@ -86,17 +94,43 @@ def connect(img, conn=1): return rst[order][idx] # maybe replace this with fastremap +import fastremap def mapidx(idx): dic = {} - for i in np.unique(idx): dic[i] = [] + # for i in np.unique(idx): dic[i] = [] + for i in fastremap.unique(idx): dic[i] = [] # marginally faster for i,j in idx: dic[i].append(j) dic[j].append(i) return dic + +def mapidx(idx): + # Stack idx and its reversed version to account for both directions + idx_rev = idx[:, [1, 0]] + idx_all = np.vstack((idx, idx_rev)) + + # Sort idx_all by the first column (i) + order = np.argsort(idx_all[:, 0]) + idx_all_sorted = idx_all[order] + + i = idx_all_sorted[:, 0] + j = idx_all_sorted[:, 1] + + # Find unique 'i's and the indices where they occur + unique_i, indices = fastremap.unique(i, return_index=True) + + # Split 'j' into lists according to the indices + splits = np.split(j, indices[1:]) + + # Build the dictionary mapping each 'i' to its list of neighbors + dic = dict(zip(unique_i, splits)) + return dic # create a connection mapping def render_net(conmap, n=4, rand=12, depth=0, max_depth=5, offset=0): - thresh = 1e4 + # LARGE_INT = len(conmap)+1 # minimal to work, doesn't look as good? + LARGE_INT = len(conmap)*2 # get back to previous behavior + thresh = LARGE_INT if depth= 0) & (j < len_line) +# i_valid = i_repeat[valid_mask] +# j_valid = j[valid_mask] + +# # Get the labels at the valid indices +# line_i = line[i_valid] +# line_j = line[j_valid] + +# # Apply the conditions: +# # - Neighbor is non-zero +# # - Labels are different +# mask = (line_j != 0) & (line_i != line_j) + +# # Collect the valid pairs +# pairs = np.column_stack((line_i[mask], line_j[mask])) + +# return pairs + + +# def search2(img, conn=1): +# coords = np.array(np.nonzero(img)) # Convert to a NumPy array +# npix = coords.shape[1] # Number of non-zero pixels +# dim = img.ndim +# shape = img.shape + +# # Define neighbor offsets +# from scipy.ndimage import generate_binary_structure +# structure = generate_binary_structure(dim, conn) +# structure[tuple([1]*dim)] = 0 # Remove the center +# neighbor_offsets = np.array(np.nonzero(structure)) - 1 # Offsets relative to center +# n_neighbors = neighbor_offsets.shape[1] + +# # Compute neighbor coordinates +# # Expand coords to shape (dim, npix, 1) +# coords_expanded = coords[:, :, np.newaxis] # Shape: (dim, npix, 1) +# # Broadcast neighbor_offsets to (dim, 1, n_neighbors) and add +# neighbor_coords = coords_expanded + neighbor_offsets[:, np.newaxis, :] # Shape: (dim, npix, n_neighbors) + +# # Reshape to 2D arrays for easier indexing +# neighbor_coords = neighbor_coords.reshape(dim, -1) # Shape: (dim, npix * n_neighbors) +# center_coords = np.repeat(coords_expanded, n_neighbors, axis=2).reshape(dim, -1) # Shape: (dim, npix * n_neighbors) + +# # Handle out-of-bounds coordinates +# valid_mask = np.all((neighbor_coords >= 0) & (neighbor_coords < np.array(shape)[:, np.newaxis]), axis=0) + +# # Filter valid neighbor coordinates +# valid_neighbor_coords = neighbor_coords[:, valid_mask] +# valid_center_coords = center_coords[:, valid_mask] + +# # Map coordinates to flat indices +# neighbor_indices = np.ravel_multi_index(valid_neighbor_coords, shape) +# center_indices = np.ravel_multi_index(valid_center_coords, shape) + +# # Get labels at indices +# line = img.ravel() +# labels_center = line[center_indices] +# labels_neighbor = line[neighbor_indices] + +# # Filter valid pairs +# valid_pairs_mask = (labels_neighbor != 0) & (labels_neighbor != labels_center) + +# # Collect valid label pairs +# pairs = np.column_stack((labels_center[valid_pairs_mask], labels_neighbor[valid_pairs_mask])) + +# return pairs + + + # import fastremap + +# def connect(img, conn=1): +# buf = np.pad(img, 1, 'constant') +# rst = search2(buf, conn) +# if len(rst) < 2: +# return rst +# # Remove duplicates and sort the pairs +# rst = fastremap.unique(np.sort(rst, axis=1), axis=0) +# return rst + +# using fastremap is a lot slower? +# def connect(img, conn=1): +# buf = np.pad(img, 1, 'constant') +# nbs = neighbors(buf.shape, conn) +# rst = search(buf, nbs) +# if len(rst) < 2: +# return rst +# rst.sort(axis=1) +# print(rst.shape) +# # Use np.unique to find unique rows (label pairs) +# rst_unique = fastremap.unique(rst, axis=0) +# return rst_unique \ No newline at end of file