-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfind_ms_params.py
118 lines (88 loc) · 2.88 KB
/
find_ms_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author: Fumiya ENDOU <[email protected]>
# Created by PyCharm at 2020/02/16
# This is a part of EQDmgAnalyzr
from multiprocessing import current_process
from os.path import join
from pprint import pprint
import numpy as np
from pymeanshift import segment
from tqdm import tqdm
from imgproc.utils import imread_with_error
from utils.pool import CustomPool
# Directory containing the input images (aerial_roi<n>.png).
ROOT_DIR_SRC = "./img/resource/aerial_image/CLAHE_with_WB_adjust"
# Directory containing the images each segmentation result is compared against;
# they share the input file names.
ROOT_DIR_ANS = join(ROOT_DIR_SRC, "smoothed")

# Sweep specification for the mean-shift spatial radius (SP) and range radius
# (SR), each a (start, span, step) tuple expanded by the sweep as
# np.arange(start, start + span, step) -- with these values, 1, 2, ..., 16.
# Both axes currently use the same grid.
SP_RANGE = SR_RANGE = (1, 16, 1)
def func_worker(img, spatial_radius=None, range_radius=None, min_density=None):
    """Segment *img* with ``pymeanshift.segment`` for one parameter combination.

    Designed to be submitted to a multiprocessing pool. Each worker shows a
    one-step tqdm bar on the terminal line given by its pool identity, so
    concurrent workers do not overwrite each other's status.

    Args:
        img: image forwarded unchanged to ``segment``.
        spatial_radius: forwarded to ``segment`` as ``spatial_radius``.
        range_radius: forwarded to ``segment`` as ``range_radius``.
        min_density: forwarded to ``segment`` as ``min_density``.

    All three parameters must be supplied in practice. The ``None`` defaults
    make the progress-bar label (which formats the radii with ``:2.1f``)
    raise ``TypeError``.

    Returns:
        ``(segmented, spatial_radius, range_radius)``, where ``segmented`` is
        element 0 of ``segment``'s return value. The radii are echoed back so
        the caller can pair each async result with its parameters.
    """
    # Pool workers have a non-empty _identity; the main process has an empty
    # tuple. Fall back to 0 so the function can also be called directly
    # (e.g. serially, for debugging) instead of raising IndexError.
    identity = current_process()._identity
    worker_id = identity[0] if identity else 0

    desc = f"Worker #{worker_id:3d} (sp={spatial_radius:2.1f}, sr={range_radius:2.1f})"
    # A single-step bar. The context manager guarantees the bar is closed
    # (and, with leave=False, erased) even if segment() raises.
    with tqdm(total=1, desc=desc, position=worker_id, leave=False) as bar:
        segmented = segment(
            img,
            spatial_radius=spatial_radius,
            range_radius=range_radius,
            min_density=min_density,
        )[0]
        bar.update()
    return segmented, spatial_radius, range_radius
def _axis_values(spec):
    """Expand a (start, span, step) sweep spec into the values to try."""
    start, span, step = spec
    return np.arange(start, start + span, step)


def _build_param_grid():
    """Return func_worker kwargs for every (spatial, range) radius pair.

    Ordered by spatial radius first, then range radius.
    """
    return [
        {"spatial_radius": sp, "range_radius": sr, "min_density": 0}
        for sp in _axis_values(SP_RANGE)
        for sr in _axis_values(SR_RANGE)
    ]


def _sum_abs_diff(a, b):
    """Return sum(|a - b|), computed in a signed dtype.

    Subtracting two uint8 images directly wraps modulo 256 (0 - 1 == 255),
    after which np.abs is a no-op and the sum is wrong. Widen first.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if np.issubdtype(a.dtype, np.integer) and np.issubdtype(b.dtype, np.integer):
        work = np.int64
    else:
        work = np.float64
    return np.sum(np.abs(a.astype(work) - b.astype(work)))


def _run_sweep(src, ms_params):
    """Run func_worker(src, **params) for each params on a pool.

    Returns the worker results in submission order.
    """
    progress_bar = tqdm(total=len(ms_params), position=0)

    def _update_progressbar(_result):
        progress_bar.update()

    try:
        cp = CustomPool()
        # Share tqdm's lock with workers so their per-worker bars don't interleave.
        pool = cp.Pool(n_process=6, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),))
        pending = [
            pool.apply_async(
                func_worker,
                args=(src, ),
                kwds=params,
                callback=_update_progressbar,
            )
            for params in ms_params
        ]
        pool.close()
        pool.join()
        cp.update()
        # .get() re-raises any exception that occurred inside a worker.
        return [p.get() for p in pending]
    finally:
        progress_bar.close()


def _write_csv(n, rows):
    """Write (spatial_radius, range_radius, diff) rows to tmp/find_ms_params_<n>.csv."""
    with open(f"tmp/find_ms_params_{n}.csv", "wt") as f:
        # NOTE(review): the "n_diffs" column holds a sum of absolute pixel
        # differences, not a count of differing pixels. The header is kept
        # as-is for downstream readers.
        f.write("spatial_radius, range_radius, n_diffs\n")
        f.writelines(", ".join(str(x) for x in row) + "\n" for row in rows)


def find_ms_params(n):
    """Grid-search mean-shift radii for ROI image *n* against its reference.

    The function:

    * Segments ``ROOT_DIR_SRC/aerial_roi<n>.png`` for every combination in
      SP_RANGE x SR_RANGE (min_density fixed at 0).
    * Scores each result by the sum of absolute differences from
      ``ROOT_DIR_ANS/aerial_roi<n>.png``.
    * Pretty-prints the scores and writes them to
      ``tmp/find_ms_params_<n>.csv``.

    Returns:
        A list of ``(spatial_radius, range_radius, abs_diff_sum)`` tuples,
        sorted by spatial radius.
    """
    file_name = f"aerial_roi{n}.png"
    src = imread_with_error(join(ROOT_DIR_SRC, file_name))
    ans = imread_with_error(join(ROOT_DIR_ANS, file_name))

    raw_results = _run_sweep(src, _build_param_grid())

    # Stable sort on the spatial radius only, so rows keep their
    # range-radius order within each spatial radius.
    # NOTE(review): to rank candidates by quality, sort on e[2] instead.
    results = sorted(
        [
            (sp, sr, _sum_abs_diff(segmented, ans))
            for segmented, sp, sr in raw_results
        ],
        key=lambda e: e[0],
    )
    pprint(results)
    _write_csv(n, results)
    return results
if __name__ == '__main__':
    # Run the radius sweep once per selected ROI image number.
    roi_numbers = (1, 2, 3, 5, 9)
    for roi_number in roi_numbers:
        find_ms_params(roi_number)