From e3596bd88c02f9f99f4dac2688741b6d886e2b57 Mon Sep 17 00:00:00 2001 From: HiroIshida Date: Sat, 12 Nov 2022 23:05:36 +0900 Subject: [PATCH] Refactor changes in refactor --- node_script/wrapper.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/node_script/wrapper.py b/node_script/wrapper.py index 67dc6d2..86b0561 100644 --- a/node_script/wrapper.py +++ b/node_script/wrapper.py @@ -17,7 +17,6 @@ from detic_ros.msg import SegmentationInfo - _cv_bridge = CvBridge() @@ -28,7 +27,7 @@ class InferenceRawResult: scores: List[float] visualization: Optional[VisImage] header: Header - class_names: List[str] + detected_class_names: List[str] def get_ros_segmentaion_image(self) -> Image: seg_img = _cv_bridge.cv2_to_imgmsg(self.segmentation_raw_image, encoding="32SC1") @@ -47,12 +46,13 @@ def get_ros_debug_segmentation_img(self) -> Image: human_friendly_scaling = 255 // self.segmentation_raw_image.max() new_data = (self.segmentation_raw_image * human_friendly_scaling).astype(np.uint8) debug_seg_img = _cv_bridge.cv2_to_imgmsg(new_data, encoding="mono8") - assert self.header is not None debug_seg_img.header = self.header return debug_seg_img def get_label_array(self) -> LabelArray: - labels = [Label(id=i + 1, name=self.class_names[i]) for i in self.class_indices] + labels = [Label(id=i + 1, name=name) + for i, name + in zip(self.class_indices, self.detected_class_names)] lab_arr = LabelArray(header=self.header, labels=labels) return lab_arr @@ -62,8 +62,7 @@ def get_score_array(self) -> VectorArray: def get_segmentation_info(self) -> SegmentationInfo: seg_img = self.get_ros_segmentaion_image() - detected_classes_names = [self.class_names[i] for i in self.class_indices] - seg_info = SegmentationInfo(detected_classes=detected_classes_names, + seg_info = SegmentationInfo(detected_classes=self.detected_class_names, scores=self.scores, segmentation=seg_img, header=self.header) @@ -131,13 +130,14 @@ def infer(self, msg: Image) -> InferenceRawResult: data[mask] = (i + 1) # Get class and score arrays - class_indexes = instances.pred_classes.tolist() + class_indices = instances.pred_classes.tolist() + detected_classes_names = [self.class_names[i] for i in class_indices] scores = instances.scores.tolist() result = InferenceRawResult( data, - class_indexes, + class_indices, scores, visualized_output, msg.header, - self.class_names) + detected_classes_names) return result