From 7d478be798df83c3c700c347220c95646ecec49e Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sat, 11 Jan 2025 09:05:13 -0700 Subject: [PATCH 1/7] Start working on bird processor --- .../real_time/bird_processor.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 frigate/data_processing/real_time/bird_processor.py diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py new file mode 100644 index 0000000000..f561192586 --- /dev/null +++ b/frigate/data_processing/real_time/bird_processor.py @@ -0,0 +1,90 @@ +"""Handle processing images to classify birds.""" + +import logging +import os + +import numpy as np + +from frigate.config import FrigateConfig +from frigate.const import MODEL_CACHE_DIR + +from .processor_api import ProcessorApi +from .types import PostProcessingMetrics + +try: + from tflite_runtime.interpreter import Interpreter +except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter + +logger = logging.getLogger(__name__) + + +class BirdProcessor(ProcessorApi): + def __init__(self, config: FrigateConfig, metrics: PostProcessingMetrics): + super().__init__(config, metrics) + self.interpreter: Interpreter = None + self.tensor_input_details: dict[str, any] = None + self.tensor_output_details: dict[str, any] = None + self.detected_birds: dict[str, float] = {} + + download_path = os.path.join(MODEL_CACHE_DIR, "bird") + self.model_files = { + "bird.tflite": "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite", + "birdmap.txt": "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt", + } + + if not all( + os.path.exists(os.path.join(download_path, n)) + for n in self.model_files.keys() + ): + # conditionally import ModelDownloader + from frigate.util.downloader import ModelDownloader + + self.downloader = ModelDownloader( + model_name="bird", + download_path=download_path, + file_names=self.model_files.keys(), + download_func=self.__download_models, + complete_func=self.__build_detector, + ) + self.downloader.ensure_model_files() + else: + self.__build_detector() + + def __download_models(self, path: str) -> None: + try: + file_name = os.path.basename(path) + + # conditionally import ModelDownloader + from frigate.util.downloader import ModelDownloader + + ModelDownloader.download_from_url(self.model_files[file_name], path) + except Exception as e: + logger.error(f"Failed to download {path}: {e}") + + def __build_detector(self) -> None: + self.interpreter = Interpreter( + model_path=os.path.join(MODEL_CACHE_DIR, "bird/bird.tflite"), + num_threads=2, + ) + self.interpreter.allocate_tensors() + self.tensor_input_details = self.interpreter.get_input_details() + self.tensor_output_details = self.interpreter.get_output_details() + + def process_frame(self, obj_data, frame): + if obj_data["label"] != "bird": + return + + self.interpreter.set_tensor(self.tensor_input_details[0]["index"], frame) + self.interpreter.invoke() + res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0] + non_zero_indices = res > 0 + class_ids = np.argpartition(-res, 20)[:20] + class_ids = class_ids[np.argsort(-res[class_ids])] + class_ids = class_ids[non_zero_indices[class_ids]] + scores = res[class_ids] + boxes = np.full((scores.shape[0], 4), -1, np.float32) + count = len(scores) + + def handle_request(self, request_data): + return None From 75c9a723bfb982875f79a535d293f93fbff33930 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sat, 11 Jan 2025 09:11:55 -0700 Subject: [PATCH 2/7] Initial setup for bird processing --- frigate/data_processing/real_time/bird_processor.py | 10 +++++----- frigate/embeddings/maintainer.py | 3 +++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py index f561192586..a6dfb8510a 100644 --- a/frigate/data_processing/real_time/bird_processor.py +++ b/frigate/data_processing/real_time/bird_processor.py @@ -8,8 +8,8 @@ from frigate.config import FrigateConfig from frigate.const import MODEL_CACHE_DIR -from .processor_api import ProcessorApi -from .types import PostProcessingMetrics +from ..types import DataProcessorMetrics +from .api import RealTimeProcessorApi try: from tflite_runtime.interpreter import Interpreter @@ -19,8 +19,8 @@ logger = logging.getLogger(__name__) -class BirdProcessor(ProcessorApi): - def __init__(self, config: FrigateConfig, metrics: PostProcessingMetrics): +class BirdProcessor(RealTimeProcessorApi): + def __init__(self, config: FrigateConfig, metrics: DataProcessorMetrics): super().__init__(config, metrics) self.interpreter: Interpreter = None self.tensor_input_details: dict[str, any] = None @@ -87,4 +87,4 @@ def process_frame(self, obj_data, frame): count = len(scores) def handle_request(self, request_data): - return None + return None \ No newline at end of file diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index a7e25469bb..671df4917c 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -30,6 +30,7 @@ UPDATE_EVENT_DESCRIPTION, ) from frigate.data_processing.real_time.api import RealTimeProcessorApi +from frigate.data_processing.real_time.bird_processor import BirdProcessor from frigate.data_processing.real_time.face_processor import FaceProcessor from frigate.data_processing.types import DataProcessorMetrics from frigate.embeddings.lpr.lpr import LicensePlateRecognition @@ -78,6 +79,8 @@ def __init__( if self.config.face_recognition.enabled: self.processors.append(FaceProcessor(self.config, metrics)) + self.processors.append(BirdProcessor(self.config, metrics)) + # create communication for updating event descriptions self.requestor = InterProcessRequestor() self.stop_event = stop_event From f86b2232df6e87ede91a58df700046fcbee1b130 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sat, 11 Jan 2025 16:11:04 -0700 Subject: [PATCH 3/7] Improvements to handling --- .../real_time/bird_processor.py | 42 ++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py index a6dfb8510a..c96fb1868b 100644 --- a/frigate/data_processing/real_time/bird_processor.py +++ b/frigate/data_processing/real_time/bird_processor.py @@ -3,10 +3,12 @@ import logging import os +import cv2 import numpy as np from frigate.config import FrigateConfig from frigate.const import MODEL_CACHE_DIR +from frigate.util.object import calculate_region from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi @@ -75,16 +77,36 @@ def process_frame(self, obj_data, frame): if obj_data["label"] != "bird": return - self.interpreter.set_tensor(self.tensor_input_details[0]["index"], frame) + x, y, x2, y2 = calculate_region( + frame.shape, + obj_data["box"][0], + obj_data["box"][1], + obj_data["box"][2], + obj_data["box"][3], + 224, + 1.4, + ) + + rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420) + input = rgb[ + y:y2, + x:x2, + ] + + logger.info(f"input shape is {input.shape}") + cv2.imwrite("/media/frigate/test_class.png", input) + + input = np.expand_dims(input, axis=0) + + self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.invoke() - res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0] - non_zero_indices = res > 0 - class_ids = np.argpartition(-res, 20)[:20] - class_ids = class_ids[np.argsort(-res[class_ids])] - class_ids = class_ids[non_zero_indices[class_ids]] - scores = res[class_ids] - boxes = np.full((scores.shape[0], 4), -1, np.float32) - count = len(scores) + res: np.ndarray = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0] + probs = res / res.sum(axis=0) + best_id = np.argmax(probs) + score = probs[best_id] def handle_request(self, request_data): - return None \ No newline at end of file + return None + + def expire_object(self, object_id): + pass From f87e82481d188399dba7a9b3ec84cb8d69144a95 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sat, 11 Jan 2025 16:15:41 -0700 Subject: [PATCH 4/7] Get classification working --- frigate/data_processing/real_time/bird_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py index c96fb1868b..cf2c5bfea6 100644 --- a/frigate/data_processing/real_time/bird_processor.py +++ b/frigate/data_processing/real_time/bird_processor.py @@ -93,7 +93,6 @@ def process_frame(self, obj_data, frame): x:x2, ] - logger.info(f"input shape is {input.shape}") cv2.imwrite("/media/frigate/test_class.png", input) input = np.expand_dims(input, axis=0) @@ -103,7 +102,8 @@ def process_frame(self, obj_data, frame): res: np.ndarray = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0] probs = res / res.sum(axis=0) best_id = np.argmax(probs) - score = probs[best_id] + score = round(probs[best_id], 2) + logger.info(f"the best scoring index is {best_id} {score}%") def handle_request(self, request_data): return None From 61870184df821d43c1299e9b3b42e74e623cd938 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 13 Jan 2025 07:24:59 -0700 Subject: [PATCH 5/7] Cleanup classification --- .../real_time/bird_processor.py | 44 +++++++++++++++++-- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py index cf2c5bfea6..aa9b119840 100644 --- a/frigate/data_processing/real_time/bird_processor.py +++ b/frigate/data_processing/real_time/bird_processor.py @@ -5,9 +5,10 @@ import cv2 import numpy as np +import requests from frigate.config import FrigateConfig -from frigate.const import MODEL_CACHE_DIR +from frigate.const import FRIGATE_LOCALHOST, MODEL_CACHE_DIR from frigate.util.object import calculate_region from ..types import DataProcessorMetrics @@ -28,6 +29,7 @@ def __init__(self, config: FrigateConfig, metrics: DataProcessorMetrics): self.tensor_input_details: dict[str, any] = None self.tensor_output_details: dict[str, any] = None self.detected_birds: dict[str, float] = {} + self.labelmap: dict[int, str] = {} download_path = os.path.join(MODEL_CACHE_DIR, "bird") self.model_files = { @@ -73,6 +75,17 @@ def __build_detector(self) -> None: self.tensor_input_details = self.interpreter.get_input_details() self.tensor_output_details = self.interpreter.get_output_details() + i = 0 + + with open(os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt")) as f: + line = f.readline() + while line: + start = line.find("(") + end = line.find(")") + self.labelmap[i] = line[start + 1 : end] + i += 1 + line = f.readline() + def process_frame(self, obj_data, frame): if obj_data["label"] != "bird": return @@ -84,7 +97,7 @@ def process_frame(self, obj_data, frame): obj_data["box"][2], obj_data["box"][3], 224, - 1.4, + 1.0, ) rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420) @@ -99,11 +112,34 @@ def process_frame(self, obj_data, frame): self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.invoke() - res: np.ndarray = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0] + res: np.ndarray = self.interpreter.get_tensor( + self.tensor_output_details[0]["index"] + )[0] probs = res / res.sum(axis=0) best_id = np.argmax(probs) + + if best_id == 964: + logger.debug("No bird classification was detected.") + return + score = round(probs[best_id], 2) - logger.info(f"the best scoring index is {best_id} {score}%") + previous_score = self.detected_birds.get(obj_data["id"], 0.0) + + if score <= previous_score: + logger.debug(f"Score {score} is worse than previous score {previous_score}") + return + + resp = requests.post( + f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label", + json={ + "camera": obj_data.get("camera"), + "subLabel": self.labelmap[best_id], + "subLabelScore": score, + }, + ) + + if resp.status_code == 200: + self.detected_birds[obj_data["id"]] = score def handle_request(self, request_data): return None From bbdb712b33a940977d7dea9a9f84ef8dcd26c24c Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 13 Jan 2025 07:29:43 -0700 Subject: [PATCH 6/7] Add classification config --- frigate/config/__init__.py | 2 +- .../{semantic_search.py => classification.py} | 16 ++++++++++++++++ frigate/config/config.py | 14 +++++++++----- .../data_processing/real_time/bird_processor.py | 8 +++++++- frigate/embeddings/lpr/lpr.py | 2 +- frigate/embeddings/maintainer.py | 3 ++- 6 files changed, 36 insertions(+), 9 deletions(-) rename frigate/config/{semantic_search.py => classification.py} (78%) diff --git a/frigate/config/__init__.py b/frigate/config/__init__.py index 1af2f08fe0..2f9ec0c566 100644 --- a/frigate/config/__init__.py +++ b/frigate/config/__init__.py @@ -9,7 +9,7 @@ from .mqtt import * # noqa: F403 from .notification import * # noqa: F403 from .proxy import * # noqa: F403 -from .semantic_search import * # noqa: F403 +from .classification import * # noqa: F403 from .telemetry import * # noqa: F403 from .tls import * # noqa: F403 from .ui import * # noqa: F403 diff --git a/frigate/config/semantic_search.py b/frigate/config/classification.py similarity index 78% rename from frigate/config/semantic_search.py rename to frigate/config/classification.py index 66b8c71701..4e806f9d93 100644 --- a/frigate/config/semantic_search.py +++ b/frigate/config/classification.py @@ -11,6 +11,22 @@ ] +class BirdClassificationConfig(FrigateBaseModel): + enabled: bool = Field(default=False, title="Enable bird classification.") + threshold: float = Field( + default=0.9, + title="Minimum classification score required to be considered a match.", + gt=0.0, + le=1.0, + ) + + +class ClassificationConfig(FrigateBaseModel): + bird: BirdClassificationConfig = Field( + default_factory=BirdClassificationConfig, title="Bird classification config." + ) + + class SemanticSearchConfig(FrigateBaseModel): enabled: bool = Field(default=False, title="Enable semantic search.") reindex: Optional[bool] = Field( diff --git a/frigate/config/config.py b/frigate/config/config.py index c4247e6f2f..f3b17c5fa9 100644 --- a/frigate/config/config.py +++ b/frigate/config/config.py @@ -51,17 +51,18 @@ from .camera.snapshots import SnapshotsConfig from .camera.timestamp import TimestampStyleConfig from .camera_group import CameraGroupConfig +from .classification import ( + ClassificationConfig, + FaceRecognitionConfig, + LicensePlateRecognitionConfig, + SemanticSearchConfig, +) from .database import DatabaseConfig from .env import EnvVars from .logger import LoggerConfig from .mqtt import MqttConfig from .notification import NotificationConfig from .proxy import ProxyConfig -from .semantic_search import ( - FaceRecognitionConfig, - LicensePlateRecognitionConfig, - SemanticSearchConfig, -) from .telemetry import TelemetryConfig from .tls import TlsConfig from .ui import UIConfig @@ -331,6 +332,9 @@ class FrigateConfig(FrigateBaseModel): default_factory=TelemetryConfig, title="Telemetry configuration." ) tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.") + classification: ClassificationConfig = Field( + default_factory=ClassificationConfig, title="Object classification config." + ) semantic_search: SemanticSearchConfig = Field( default_factory=SemanticSearchConfig, title="Semantic search configuration." ) diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py index aa9b119840..e432a186b9 100644 --- a/frigate/data_processing/real_time/bird_processor.py +++ b/frigate/data_processing/real_time/bird_processor.py @@ -123,6 +123,11 @@ def process_frame(self, obj_data, frame): return score = round(probs[best_id], 2) + + if score < self.config.classification.bird.threshold: + logger.debug(f"Score {score} is not above required threshold") + return + previous_score = self.detected_birds.get(obj_data["id"], 0.0) if score <= previous_score: @@ -145,4 +150,5 @@ def handle_request(self, request_data): return None def expire_object(self, object_id): - pass + if object_id in self.detected_birds: + self.detected_birds.pop(object_id) diff --git a/frigate/embeddings/lpr/lpr.py b/frigate/embeddings/lpr/lpr.py index 16eba99898..d7e513c737 100644 --- a/frigate/embeddings/lpr/lpr.py +++ b/frigate/embeddings/lpr/lpr.py @@ -8,7 +8,7 @@ from shapely.geometry import Polygon from frigate.comms.inter_process import InterProcessRequestor -from frigate.config.semantic_search import LicensePlateRecognitionConfig +from frigate.config.classification import LicensePlateRecognitionConfig from frigate.embeddings.embeddings import Embeddings logger = logging.getLogger(__name__) diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 671df4917c..aa0322fd7f 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -79,7 +79,8 @@ def __init__( if self.config.face_recognition.enabled: self.processors.append(FaceProcessor(self.config, metrics)) - self.processors.append(BirdProcessor(self.config, metrics)) + if self.config.classification.bird.enabled: + self.processors.append(BirdProcessor(self.config, metrics)) # create communication for updating event descriptions self.requestor = InterProcessRequestor() From d357c6c1fd6d3aebc77cd443bdfaca35f98ac677 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 13 Jan 2025 07:45:00 -0700 Subject: [PATCH 7/7] Update sort --- frigate/config/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frigate/config/__init__.py b/frigate/config/__init__.py index 2f9ec0c566..e90c336e51 100644 --- a/frigate/config/__init__.py +++ b/frigate/config/__init__.py @@ -3,13 +3,13 @@ from .auth import * # noqa: F403 from .camera import * # noqa: F403 from .camera_group import * # noqa: F403 +from .classification import * # noqa: F403 from .config import * # noqa: F403 from .database import * # noqa: F403 from .logger import * # noqa: F403 from .mqtt import * # noqa: F403 from .notification import * # noqa: F403 from .proxy import * # noqa: F403 -from .classification import * # noqa: F403 from .telemetry import * # noqa: F403 from .tls import * # noqa: F403 from .ui import * # noqa: F403