# Settings for the optional bird species classifier.
class BirdClassificationConfig(FrigateBaseModel):
    # Master switch for bird classification; disabled by default.
    enabled: bool = Field(default=False, title="Enable bird classification.")
    # Minimum score a classification needs to be treated as a match.
    # Validated to the half-open range (0.0, 1.0]; defaults to 0.9.
    threshold: float = Field(
        default=0.9,
        title="Minimum classification score required to be considered a match.",
        gt=0.0,
        le=1.0,
    )


# Container for object classification settings; `bird` is the only
# classifier configured in this model.
class ClassificationConfig(FrigateBaseModel):
    # default_factory gives every config instance its own sub-model.
    bird: BirdClassificationConfig = Field(
        default_factory=BirdClassificationConfig, title="Bird classification config."
    )
class BirdProcessor(RealTimeProcessorApi):
    """Real-time processor that assigns a species sub label to tracked birds.

    Each object labelled ``bird`` is cropped out of the frame and classified
    with a TFLite model. When the best score passes the configured threshold
    and beats the score already reported for that object, the species name is
    posted to the local API as the event's sub label.
    """

    # Edge length, in pixels, of the square crop requested for the model.
    MODEL_INPUT_SIZE = 224
    # Output index treated as "no bird classification"; presumably the
    # "background" entry of the label file -- confirm if the labels change.
    BACKGROUND_CLASS_ID = 964
    # Seconds to wait for the local API so that a stalled request cannot
    # block frame processing indefinitely.
    API_TIMEOUT = 10

    def __init__(self, config: FrigateConfig, metrics: DataProcessorMetrics):
        """Set up state and build the detector, downloading the model if needed.

        Args:
            config: Frigate configuration; ``classification.bird.threshold``
                is read when frames are processed.
            metrics: Shared data processor metrics, passed to the base class.
        """
        super().__init__(config, metrics)
        # All three stay None until __build_detector has run.
        self.interpreter: Interpreter = None
        self.tensor_input_details: list[dict] = None
        self.tensor_output_details: list[dict] = None
        # Best score already reported per tracked object id.
        self.detected_birds: dict[str, float] = {}
        # Model output index -> species name.
        self.labelmap: dict[int, str] = {}

        download_path = os.path.join(MODEL_CACHE_DIR, "bird")
        self.model_files = {
            "bird.tflite": "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite",
            "birdmap.txt": "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt",
        }

        if all(
            os.path.exists(os.path.join(download_path, name))
            for name in self.model_files
        ):
            self.__build_detector()
            return

        # conditionally import ModelDownloader
        from frigate.util.downloader import ModelDownloader

        self.downloader = ModelDownloader(
            model_name="bird",
            download_path=download_path,
            file_names=self.model_files.keys(),
            download_func=self.__download_models,
            complete_func=self.__build_detector,
        )
        self.downloader.ensure_model_files()

    def __download_models(self, path: str) -> None:
        """Download the model file whose basename matches ``path``.

        Failures are logged rather than raised (best effort).
        """
        try:
            file_name = os.path.basename(path)

            # conditionally import ModelDownloader
            from frigate.util.downloader import ModelDownloader

            ModelDownloader.download_from_url(self.model_files[file_name], path)
        except Exception as e:
            logger.error(f"Failed to download {path}: {e}")

    def __build_detector(self) -> None:
        """Load the TFLite model and the label map from the model cache."""
        interpreter = Interpreter(
            model_path=os.path.join(MODEL_CACHE_DIR, "bird/bird.tflite"),
            num_threads=2,
        )
        interpreter.allocate_tensors()
        self.tensor_input_details = interpreter.get_input_details()
        self.tensor_output_details = interpreter.get_output_details()

        labelmap: dict[int, str] = {}

        with open(
            os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt"), encoding="utf-8"
        ) as f:
            for index, line in enumerate(f):
                # Use the text inside the first pair of parentheses; fall
                # back to the whole stripped line when there is none, so a
                # final line without a newline is not truncated.
                start = line.find("(")
                end = line.find(")", start + 1)

                if 0 <= start < end:
                    labelmap[index] = line[start + 1 : end]
                else:
                    labelmap[index] = line.strip()

        self.labelmap = labelmap
        # Publish the interpreter last so process_frame never observes a
        # partially initialised detector.
        self.interpreter = interpreter

    def process_frame(self, obj_data, frame) -> None:
        """Classify a tracked bird and report its species as a sub label.

        Args:
            obj_data: Tracked object data; ``label``, ``box``, ``id`` and
                optionally ``camera`` are read.
            frame: Frame in YUV I420 layout (it is converted with
                ``cv2.COLOR_YUV2RGB_I420``).
        """
        if obj_data["label"] != "bird":
            return

        if self.interpreter is None:
            # The detector is only built once the model files are available.
            logger.debug("Bird classification model is not loaded yet.")
            return

        size = self.MODEL_INPUT_SIZE
        rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
        box = obj_data["box"]
        # Compute the region against the RGB image: the I420 buffer is 1.5x
        # taller than the real picture, so its shape gives wrong bounds.
        x, y, x2, y2 = calculate_region(
            rgb.shape,
            box[0],
            box[1],
            box[2],
            box[3],
            size,
            1.0,
        )
        crop = rgb[y:y2, x:x2]

        if crop.size == 0:
            logger.debug("Bird crop is empty, skipping classification.")
            return

        if crop.shape[:2] != (size, size):
            # The region can be larger than the model input (or clipped at
            # the frame edge); the tensor must be exactly size x size.
            crop = cv2.resize(crop, (size, size))

        tensor = np.expand_dims(crop, axis=0)

        self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor)
        self.interpreter.invoke()
        res: np.ndarray = self.interpreter.get_tensor(
            self.tensor_output_details[0]["index"]
        )[0]
        total = res.sum()

        if total == 0:
            # Avoid dividing by zero when the model outputs all zeros.
            logger.debug("No bird classification was detected.")
            return

        probs = res / total
        best_id = int(np.argmax(probs))

        if best_id == self.BACKGROUND_CLASS_ID:
            logger.debug("No bird classification was detected.")
            return

        score = round(float(probs[best_id]), 2)

        if score < self.config.classification.bird.threshold:
            logger.debug(f"Score {score} is not above required threshold")
            return

        previous_score = self.detected_birds.get(obj_data["id"], 0.0)

        if score <= previous_score:
            logger.debug(f"Score {score} is worse than previous score {previous_score}")
            return

        try:
            resp = requests.post(
                f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
                json={
                    "camera": obj_data.get("camera"),
                    "subLabel": self.labelmap[best_id],
                    "subLabelScore": score,
                },
                timeout=self.API_TIMEOUT,
            )
        except requests.RequestException as e:
            logger.error(f"Failed to set bird sub label for {obj_data['id']}: {e}")
            return

        # Only remember the score once the API has accepted it, so a failed
        # update is retried on a later frame.
        if resp.status_code == 200:
            self.detected_birds[obj_data["id"]] = score

    def handle_request(self, request_data):
        """This processor serves no requests; always returns None."""
        return None

    def expire_object(self, object_id):
        """Forget the stored score for an object that is no longer tracked."""
        self.detected_birds.pop(object_id, None)
InterProcessRequestor() self.stop_event = stop_event