From b733575412eee0b129df074d908e47ac8c7e0054 Mon Sep 17 00:00:00 2001 From: lart Date: Tue, 23 Aug 2022 00:19:12 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9E=20fix(Smeasure=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E8=B4=A8=E5=BF=83=E8=AE=A1=E7=AE=97):=20=E5=8E=9F=E5=A7=8B?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=9C=A8=E8=BE=93=E5=85=A5=E8=BF=87=E5=A4=A7?= =?UTF-8?q?=E6=97=B6=E4=BC=9A=E6=BA=A2=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 原本直接基于np.sum()的实现在输入尺寸过大的时候会出现数值溢出。现在基于np.argwhere()的实现方式则避免了这一问题。 关于质心计算的更多细节可见文档:https://www.yuque.com/lart/blog/gpbigm --- examples/test_metrics.py | 54 ++++++++++++++++++++--------------- py_sod_metrics/sod_metrics.py | 12 ++++---- version.txt | 2 +- 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/examples/test_metrics.py b/examples/test_metrics.py index 57e2b8e..c14719e 100644 --- a/examples/test_metrics.py +++ b/examples/test_metrics.py @@ -4,11 +4,12 @@ # @GitHub : https://github.com/lartpang import os +import sys +from pprint import pprint import cv2 -from tqdm import tqdm -# pip install pysodmetrics +sys.path.append("..") from py_sod_metrics import MAE, Emeasure, Fmeasure, Smeasure, WeightedFmeasure FM = Fmeasure() @@ -21,7 +22,8 @@ mask_root = os.path.join(data_root, "masks") pred_root = os.path.join(data_root, "preds") mask_name_list = sorted(os.listdir(mask_root)) -for mask_name in tqdm(mask_name_list, total=len(mask_name_list)): +for i, mask_name in enumerate(mask_name_list): + print(f"[{i}] Processing {mask_name}...") mask_path = os.path.join(mask_root, mask_name) pred_path = os.path.join(pred_root, mask_name) mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) @@ -50,24 +52,30 @@ "maxFm": fm["curve"].max(), } -print(results) -# 'Smeasure': 0.9029763868504661, -# 'wFmeasure': 0.5579812753638986, -# 'MAE': 0.03705558476661653, -# 'adpEm': 0.9408760066970631, -# 'meanEm': 0.9566258293508715, -# 'maxEm': 0.966954482892271, -# 'adpFm': 0.5816750824038355, -# 'meanFm': 0.577051059518767, -# 'maxFm': 0.5886784581120638 +default_results = { + "v1_2_3": { + "Smeasure": 0.9029763868504661, + "wFmeasure": 0.5579812753638986, + "MAE": 0.03705558476661653, + "adpEm": 0.9408760066970631, + "meanEm": 0.9566258293508715, + "maxEm": 0.966954482892271, + "adpFm": 0.5816750824038355, + "meanFm": 0.577051059518767, + "maxFm": 0.5886784581120638, + }, + "v1_3_0": { + "Smeasure": 0.9029761578759272, + "wFmeasure": 0.5579812753638986, + "MAE": 0.03705558476661653, + "adpEm": 0.9408760066970617, + "meanEm": 0.9566258293508704, + "maxEm": 0.9669544828922699, + "adpFm": 0.5816750824038355, + "meanFm": 0.577051059518767, + "maxFm": 0.5886784581120638, + }, +} -# version 1.2.3 -# 'Smeasure': 0.9029763868504661, -# 'wFmeasure': 0.5579812753638986, -# 'MAE': 0.03705558476661653, -# 'adpEm': 0.9408760066970631, -# 'meanEm': 0.9566258293508715, -# 'maxEm': 0.966954482892271, -# 'adpFm': 0.5816750824038355, -# 'meanFm': 0.577051059518767, -# 'maxFm': 0.5886784581120638 +pprint(results) +pprint({k: default_value - results[k] for k, default_value in default_results["v1_3_0"].items()}) diff --git a/py_sod_metrics/sod_metrics.py b/py_sod_metrics/sod_metrics.py index 63a49cc..e9917d0 100644 --- a/py_sod_metrics/sod_metrics.py +++ b/py_sod_metrics/sod_metrics.py @@ -266,19 +266,17 @@ def centroid(self, matrix: np.ndarray) -> tuple: so there is no need to use the redundant addition operation when dividing the region later, because the sequence generated by ``1:X`` in matlab will contain ``X``. - :param matrix: a data array + :param matrix: a bool data array :return: the centroid coordinate """ h, w = matrix.shape - if matrix.sum() == 0: + area_object = np.count_nonzero(matrix) + if area_object == 0: x = np.round(w / 2) y = np.round(h / 2) else: - area_object = np.sum(matrix) - row_ids = np.arange(h) - col_ids = np.arange(w) - x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object) - y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object) + # More details can be found at: https://www.yuque.com/lart/blog/gpbigm + y, x = np.argwhere(matrix).mean(axis=0).round() return int(x) + 1, int(y) + 1 def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x: int, y: int) -> dict: diff --git a/version.txt b/version.txt index f0bb29e..3a3cd8c 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -1.3.0 +1.3.1