Skip to content

Commit

Permalink
forgot to push ThreadedDumper (and resolved); added dump_percent to g…
Browse files Browse the repository at this point in the history
…enerate commandline
  • Loading branch information
dhodcz2 committed Dec 25, 2023
1 parent 701ffb7 commit 8f7e5a5
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 15 deletions.
8 changes: 7 additions & 1 deletion src/tile2net/raster/generate/commandline.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@
),
arg(
'--source', '-s', default=None, type=str,
)
),
arg(
'--dump_percent',
type=int,
default=0,
help='The percentage of segmentation results to save. 100 means all, 0 means none.',
),
)

class Namespace(argh.ArghNamespace):
Expand Down
246 changes: 232 additions & 14 deletions src/tile2net/tileseg/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
Miscellanous Functions
"""

from __future__ import annotations

import tempfile
from typing import Optional

import cv2
import sys
import os
Expand All @@ -44,9 +50,15 @@

from tile2net.tileseg.config import cfg
from tile2net.namespace import Namespace
from concurrent.futures import Future, ThreadPoolExecutor

from runx.logx import logx

from geopandas import GeoDataFrame

if False:
from tile2net.raster.tile import Tile


def fast_hist(pred, gtruth, num_classes):
# mask indicates pixels we care about
Expand Down Expand Up @@ -230,7 +242,7 @@ def eval_metrics(iou_acc, args, net, optim, val_loss, epoch, mf_score=None):
return was_best


class ImageDumper():
class ImageDumper:
"""
Image dumping class
Expand All @@ -239,10 +251,6 @@ class ImageDumper():
writes the images out to disk.
"""

# def __init__(self, val_len, args: Namespace, tensorboard=True, write_webpage=True,
# webpage_fn='index.html', dump_all_images=False, dump_assets=False,
# dump_err_prob=False, dump_num=10, dump_for_auto_labelling=False,
# dump_for_submission=False):
def __init__(
self,
val_len,
Expand All @@ -257,13 +265,19 @@ def __init__(
dump_num=10,
):
"""
:val_len: num validation images
:tensorboard: push summary to tensorboard
:webpage: generate a summary html page
:webpage_fn: name of webpage file
:dump_all_images: dump all (validation) images, e.g. for video
:dump_num: number of images to dump if not dumping all
:dump_assets: dump attention maps
Parameters
----------
val_len: num validation images
args: command line arguments
tensorboard: push summary to tensorboard
write_webpage: generate a summary html page
webpage_fn: name of webpage file
dump_all_images: dump all (validation) images, e.g. for video
dump_num: number of images to dump if not dumping all
dump_assets: dump attention maps
dump_err_prob: dump error probability
dump_for_auto_labelling: dump images for auto-labelling
dump_for_submission: dump images for submission
"""
self.val_len = val_len
self.tensorboard = tensorboard
Expand Down Expand Up @@ -325,6 +339,7 @@ def save_prob_and_err_mask(self, dump_dict, img_name, idx, prediction):
return False, err_pil

dump_percent = 100 # first one always dumps if args.dump_percent != 0

def create_composite_image(self, input_image, prediction_pil, img_name):
if not self.args.dump_percent:
return
Expand Down Expand Up @@ -363,7 +378,6 @@ def get_dump_assets(self, dump_dict, img_name, idx, colorize_mask_fn, to_tensorb
mask_pil.save(mask_fn)
to_tensorboard.append(self.visualize(mask_pil))


def dump(self, dump_dict, val_idx, testing=None, grid=None):

# if not self.args.dump_percent:
Expand Down Expand Up @@ -463,7 +477,7 @@ def write_summaries(self, was_best):


def print_evaluate_results(hist, iu, epoch=0, iou_per_scale=None,
log_multiscale_tb = False, eps = 1e-8):
log_multiscale_tb=False, eps=1e-8):
"""
If single scale:
just print results for default scale
Expand Down Expand Up @@ -559,3 +573,207 @@ def fmt_scale(prefix, scale):
scale_str = str(float(scale))
scale_str.replace('.', '')
return f'{prefix}_{scale_str}x'


class ThreadedDumper(ImageDumper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.futures: list[Future] = []
self.threads = ThreadPoolExecutor()
os.makedirs(self.save_dir, exist_ok=True)

def dump(self, dump_dict, val_idx, testing=None, grid=None):

colorize_mask_fn = cfg.DATASET_INST.colorize_mask

for idx in range(len(dump_dict['input_images'])):

input_image = dump_dict['input_images'][idx]
gt_image = dump_dict['gt_images'][idx]
prediction = dump_dict['assets']['predictions'][idx]
img_name = dump_dict['img_names'][idx]

er_prob, err_pil = self.save_prob_and_err_mask(dump_dict, img_name, idx, prediction)

input_image = self.inv_normalize(input_image)
input_image = input_image.cpu()
input_image = standard_transforms.ToPILImage()(input_image)
input_image = input_image.convert("RGB")

gt_pil = colorize_mask_fn(gt_image.cpu().numpy())

if testing:
alpha = False
all_pix = (np.array(input_image).shape[0] * np.array(input_image).shape[1])
if np.array(input_image).shape[-1] == 3:
black = np.count_nonzero(np.all(np.array(input_image) == [0, 0, 0], axis=2))
elif np.array(input_image).shape[-1] == 4:
black = np.count_nonzero(np.all(np.array(input_image) == [0, 0, 0, 0], axis=-1))
alpha = True
ratio = black / all_pix

if ratio > 0.25 or alpha:
continue

else:
prediction_pil = colorize_mask_fn(prediction)
prediction_pil = prediction_pil.convert('RGB')
self.create_composite_image(input_image, prediction_pil, img_name)

if grid:
idd_ = img_name.split('_')[-1]
save_dir = os.path.join(cfg.RESULT_DIR, 'seg_results')
self.save_dir = save_dir
tile = grid.tiles[grid.pose_dict[int(idd_)]]
polygons = self.map_features(tile, np.array(prediction_pil), img_array=True)
if polygons is not None:
yield polygons
else:
# gt_fn = '{}_gt.png'.format(img_name)
gt_pil = colorize_mask_fn(gt_image.cpu().numpy())
# prediction_fn = '{}_prediction.png'.format(img_name)
prediction_pil = colorize_mask_fn(prediction)
prediction_pil = prediction_pil.convert('RGB')
self.create_composite_image(input_image, prediction_pil, img_name)

to_tensorboard = [
self.visualize(input_image.convert('RGB')),
self.visualize(gt_pil.convert('RGB')),
self.visualize(prediction_pil.convert('RGB')),
]
if er_prob and err_pil is not None:
to_tensorboard.append(self.visualize(err_pil.convert('RGB')))

self.get_dump_assets(dump_dict, img_name, idx, colorize_mask_fn, to_tensorboard)

self.imgs_to_tensorboard.append(to_tensorboard)

for future in self.futures:
future.result()

def create_composite_image(self, input_image, prediction_pil, img_name):
threads = self.threads
futures = self.futures
if not self.args.dump_percent:
return
self.dump_percent += self.args.dump_percent
if self.dump_percent < 100:
return
self.dump_percent -= 100
parent = os.path.dirname(self.save_dir)
os.makedirs(parent, exist_ok=True)
composited = Image.new('RGB', (input_image.width + input_image.width, input_image.height))
composited.paste(input_image, (0, 0))
composited.paste(prediction_pil, (prediction_pil.width, 0))
composited_fn = 'sidebside_{}.png'.format(img_name)
composited_fn = os.path.join(self.save_dir, composited_fn)
# print(f'saving {composited_fn}')
# composited.save(composited_fn)
future = threads.submit(composited.save, composited_fn)
futures.append(future)

def get_dump_assets(self, dump_dict, img_name, idx, colorize_mask_fn, to_tensorboard):
threads = self.threads
futures = self.futures
if self.dump_assets:
assets = dump_dict['assets']
for asset in assets:
mask = assets[asset][idx]
mask_fn = os.path.join(self.save_dir, f'{img_name}_{asset}.png')
if 'pred_' in asset:
pred_pil = colorize_mask_fn(mask)
future = threads.submit(pred_pil.save, mask_fn)
futures.append(future)
continue
if type(mask) == torch.Tensor:
mask = mask.squeeze().cpu().numpy()
else:
mask = mask.squeeze()
mask = (mask * 255)
mask = mask.astype(np.uint8)
mask_pil = Image.fromarray(mask)
mask_pil = mask_pil.convert('RGB')
future = threads.submit(mask_pil.save, mask_fn)
futures.append(future)
to_tensorboard.append(self.visualize(mask_pil))

def save_prob_and_err_mask(
self,
dump_dict,
img_name,
idx,
prediction,
):
threads = self.threads
futures = self.futures
err_pil = None
if 'err_mask' in dump_dict and 'prob_mask' in dump_dict['assets']:
prob_image = dump_dict['assets']['prob_mask'][idx]
err_mask = dump_dict['err_mask'][idx]
image = (prob_image.cpu().numpy() * 255).astype(np.uint8)
future = threads.submit(self.save_image, image, f'{img_name}_prob.png')
futures.append(future)
err_pil = Image.fromarray(prediction.astype(np.uint8)).convert('RGB')
# err_pil.save(os.path.join(self.save_dir, f'{img_name}_err_mask.png'))
path = os.path.join(self.save_dir, f'{img_name}_err_mask.png')
future = threads.submit(err_pil.save, path)
futures.append(future)
return True, err_pil
# return True, err_pil
return False, err_pil

def map_features(
self,
tile: Tile,
src_img: np.ndarray,
img_array=True,
) -> Optional[GeoDataFrame]:
"""Converts a raster mask to a GeoDataFrame of polygons
Parameters
----------
src_img : str
path to the image
img_array : array, optional
if the image is already read, pass the array to avoid reading it again
Returns
-------
geoms : GeoDataFrame
GeoDataFrame of polygons
"""
swcw = []
threads = self.threads
futures = self.futures

sidewalks = tile.mask2poly(
src_img,
class_name='sidewalk',
class_id=2,
img_array=img_array
)
if sidewalks is not False:
swcw.append(sidewalks)

crosswalks = tile.mask2poly(
src_img,
class_name='crosswalk',
class_id=0,
class_hole_size=15,
img_array=img_array
)
if crosswalks is not False:
swcw.append(crosswalks)

roads = tile.mask2poly(
src_img,
class_name='road',
class_id=1,
class_hole_size=30,
img_array=img_array
)
if roads is not False:
swcw.append(roads)
if len(swcw) > 0:
# noinspection PyTypeChecker
rswcw: GeoDataFrame = pd.concat(swcw)
rswcw.reset_index(drop=True, inplace=True)
return rswcw

0 comments on commit 8f7e5a5

Please sign in to comment.