Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bird classification #15966

Merged
merged 7 commits into from
Jan 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frigate/config/__init__.py
Original file line number Diff line number Diff line change
@@ -3,13 +3,13 @@
from .auth import * # noqa: F403
from .camera import * # noqa: F403
from .camera_group import * # noqa: F403
from .classification import * # noqa: F403
from .config import * # noqa: F403
from .database import * # noqa: F403
from .logger import * # noqa: F403
from .mqtt import * # noqa: F403
from .notification import * # noqa: F403
from .proxy import * # noqa: F403
from .semantic_search import * # noqa: F403
from .telemetry import * # noqa: F403
from .tls import * # noqa: F403
from .ui import * # noqa: F403
Original file line number Diff line number Diff line change
@@ -11,6 +11,22 @@
]


# Settings for the bird classifier; the feature stays off unless `enabled`
# is set to true in the config.
class BirdClassificationConfig(FrigateBaseModel):
    enabled: bool = Field(title="Enable bird classification.", default=False)
    # Accepted range is 0.0 < threshold <= 1.0 (enforced by the gt/le bounds).
    threshold: float = Field(
        title="Minimum classification score required to be considered a match.",
        default=0.9,
        le=1.0,
        gt=0.0,
    )


# Container for per-object-type classification settings; currently it only
# holds the bird classifier, which is created with its defaults when omitted.
class ClassificationConfig(FrigateBaseModel):
    bird: BirdClassificationConfig = Field(
        title="Bird classification config.",
        default_factory=BirdClassificationConfig,
    )


class SemanticSearchConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable semantic search.")
reindex: Optional[bool] = Field(
14 changes: 9 additions & 5 deletions frigate/config/config.py
Original file line number Diff line number Diff line change
@@ -51,17 +51,18 @@
from .camera.snapshots import SnapshotsConfig
from .camera.timestamp import TimestampStyleConfig
from .camera_group import CameraGroupConfig
from .classification import (
ClassificationConfig,
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .database import DatabaseConfig
from .env import EnvVars
from .logger import LoggerConfig
from .mqtt import MqttConfig
from .notification import NotificationConfig
from .proxy import ProxyConfig
from .semantic_search import (
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .telemetry import TelemetryConfig
from .tls import TlsConfig
from .ui import UIConfig
@@ -331,6 +332,9 @@ class FrigateConfig(FrigateBaseModel):
default_factory=TelemetryConfig, title="Telemetry configuration."
)
tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.")
classification: ClassificationConfig = Field(
default_factory=ClassificationConfig, title="Object classification config."
)
semantic_search: SemanticSearchConfig = Field(
default_factory=SemanticSearchConfig, title="Semantic search configuration."
)
154 changes: 154 additions & 0 deletions frigate/data_processing/real_time/bird_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Handle processing images to classify birds."""

import logging
import os

import cv2
import numpy as np
import requests

from frigate.config import FrigateConfig
from frigate.const import FRIGATE_LOCALHOST, MODEL_CACHE_DIR
from frigate.util.object import calculate_region

from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi

try:
from tflite_runtime.interpreter import Interpreter
except ModuleNotFoundError:
from tensorflow.lite.python.interpreter import Interpreter

logger = logging.getLogger(__name__)


class BirdProcessor(RealTimeProcessorApi):
    """Classify tracked ``bird`` objects and publish the result as a sub label.

    The processor loads a TFLite classification model plus its label file
    from ``MODEL_CACHE_DIR/bird`` (downloading them first when missing), runs
    the model on a crop around each bird, and posts the best label to the
    local events API whenever the score beats both the configured threshold
    and the best score already recorded for that object.
    """

    # Side length in pixels of the square image handed to the model. This is
    # the value requested from calculate_region and matches the "224" in the
    # model file name.
    MODEL_INPUT_SIZE = 224

    # Output index that is treated as "no bird classification was detected".
    NO_BIRD_CLASS_ID = 964

    # Seconds to wait for the local API so a stalled request cannot block
    # frame processing indefinitely.
    SUB_LABEL_REQUEST_TIMEOUT = 10

    def __init__(self, config: FrigateConfig, metrics: DataProcessorMetrics):
        """Set up state and build the detector, downloading files if needed.

        Args:
            config: Full Frigate config; ``classification.bird.threshold``
                is read when processing frames.
            metrics: Shared metrics object forwarded to the base class.
        """
        super().__init__(config, metrics)
        # The interpreter and tensor details stay None until
        # __build_detector has run.
        self.interpreter: Interpreter | None = None
        self.tensor_input_details: list[dict] | None = None
        self.tensor_output_details: list[dict] | None = None
        # Best score already published per object id.
        self.detected_birds: dict[str, float] = {}
        # Model output index -> label text.
        self.labelmap: dict[int, str] = {}

        download_path = os.path.join(MODEL_CACHE_DIR, "bird")
        self.model_files = {
            "bird.tflite": "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite",
            "birdmap.txt": "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt",
        }

        have_all_files = all(
            os.path.exists(os.path.join(download_path, name))
            for name in self.model_files
        )

        if have_all_files:
            self.__build_detector()
            return

        # conditionally import ModelDownloader
        from frigate.util.downloader import ModelDownloader

        self.downloader = ModelDownloader(
            model_name="bird",
            download_path=download_path,
            file_names=list(self.model_files),
            download_func=self.__download_models,
            complete_func=self.__build_detector,
        )
        self.downloader.ensure_model_files()

    def __download_models(self, path: str) -> None:
        """Download the single model file whose destination is ``path``.

        Failures are logged rather than raised.
        """
        try:
            file_name = os.path.basename(path)

            # conditionally import ModelDownloader
            from frigate.util.downloader import ModelDownloader

            ModelDownloader.download_from_url(self.model_files[file_name], path)
        except Exception as e:
            logger.error(f"Failed to download {path}: {e}")

    @staticmethod
    def _parse_label(line: str) -> str:
        """Return the text inside the first ``(...)`` of a label-file line.

        Lines without a well-formed parenthesised part fall back to the
        stripped line, so such a line keeps a usable label whether or not
        it ends with a newline.
        """
        start = line.find("(")
        end = line.find(")")

        if start == -1 or end == -1 or end < start:
            return line.strip()

        return line[start + 1 : end]

    def __build_detector(self) -> None:
        """Create the TFLite interpreter and load the label map from disk."""
        self.interpreter = Interpreter(
            model_path=os.path.join(MODEL_CACHE_DIR, "bird/bird.tflite"),
            num_threads=2,
        )
        self.interpreter.allocate_tensors()
        self.tensor_input_details = self.interpreter.get_input_details()
        self.tensor_output_details = self.interpreter.get_output_details()

        label_path = os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt")

        # The line number in the label file is the model output index.
        with open(label_path, encoding="utf-8") as f:
            for index, line in enumerate(f):
                self.labelmap[index] = self._parse_label(line)

    def process_frame(self, obj_data, frame):
        """Classify one tracked object and publish an improved sub label.

        Args:
            obj_data: Tracked-object dict; ``label``, ``box`` and ``id`` are
                required and ``camera`` is read if present.
            frame: Frame in a layout accepted by ``cv2.COLOR_YUV2RGB_I420``.

        Returns:
            None. Side effects are the HTTP request to the events API and
            the update of ``self.detected_birds`` on success.
        """
        if obj_data["label"] != "bird":
            return

        if self.interpreter is None:
            # The detector is only built once the model files exist, so
            # frames can arrive before there is anything to run them on.
            logger.debug("Bird classification model is not ready yet.")
            return

        # NOTE(review): frame.shape is the shape of the I420 buffer, not of
        # the converted RGB image - confirm this is what calculate_region
        # expects.
        x, y, x2, y2 = calculate_region(
            frame.shape,
            obj_data["box"][0],
            obj_data["box"][1],
            obj_data["box"][2],
            obj_data["box"][3],
            self.MODEL_INPUT_SIZE,
            1.0,
        )

        rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
        crop = rgb[y:y2, x:x2]

        if crop.size == 0:
            logger.debug("Bird crop is empty, skipping classification.")
            return

        # The interpreter rejects tensors of the wrong shape, so any region
        # that is not already the model's input size has to be scaled.
        expected = (self.MODEL_INPUT_SIZE, self.MODEL_INPUT_SIZE)
        if crop.shape[:2] != expected:
            crop = cv2.resize(crop, expected)

        tensor = np.expand_dims(crop, axis=0)

        self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor)
        self.interpreter.invoke()
        res: np.ndarray = self.interpreter.get_tensor(
            self.tensor_output_details[0]["index"]
        )[0]

        total = res.sum(axis=0)
        if total == 0:
            # Normalising an all-zero output would produce NaN scores.
            logger.debug("No bird classification was detected.")
            return

        probs = res / total
        best_id = int(np.argmax(probs))

        if best_id == self.NO_BIRD_CLASS_ID:
            logger.debug("No bird classification was detected.")
            return

        # Plain Python float so the value is JSON serialisable regardless
        # of the model's output dtype.
        score = round(float(probs[best_id]), 2)

        if score < self.config.classification.bird.threshold:
            logger.debug(f"Score {score} is not above required threshold")
            return

        previous_score = self.detected_birds.get(obj_data["id"], 0.0)

        if score <= previous_score:
            logger.debug(f"Score {score} is worse than previous score {previous_score}")
            return

        try:
            resp = requests.post(
                f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
                json={
                    "camera": obj_data.get("camera"),
                    "subLabel": self.labelmap[best_id],
                    "subLabelScore": score,
                },
                timeout=self.SUB_LABEL_REQUEST_TIMEOUT,
            )
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to set bird sub label for {obj_data['id']}: {e}")
            return

        # Only remember the score once the API accepted it, so a failed
        # update is retried on a later frame.
        if resp.status_code == 200:
            self.detected_birds[obj_data["id"]] = score

    def handle_request(self, request_data):
        """This processor does not serve requests; always returns None."""
        return None

    def expire_object(self, object_id):
        """Forget the recorded score for an object that is no longer tracked."""
        self.detected_birds.pop(object_id, None)
2 changes: 1 addition & 1 deletion frigate/embeddings/lpr/lpr.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
from shapely.geometry import Polygon

from frigate.comms.inter_process import InterProcessRequestor
from frigate.config.semantic_search import LicensePlateRecognitionConfig
from frigate.config.classification import LicensePlateRecognitionConfig
from frigate.embeddings.embeddings import Embeddings

logger = logging.getLogger(__name__)
4 changes: 4 additions & 0 deletions frigate/embeddings/maintainer.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
UPDATE_EVENT_DESCRIPTION,
)
from frigate.data_processing.real_time.api import RealTimeProcessorApi
from frigate.data_processing.real_time.bird_processor import BirdProcessor
from frigate.data_processing.real_time.face_processor import FaceProcessor
from frigate.data_processing.types import DataProcessorMetrics
from frigate.embeddings.lpr.lpr import LicensePlateRecognition
@@ -78,6 +79,9 @@ def __init__(
if self.config.face_recognition.enabled:
self.processors.append(FaceProcessor(self.config, metrics))

if self.config.classification.bird.enabled:
self.processors.append(BirdProcessor(self.config, metrics))

# create communication for updating event descriptions
self.requestor = InterProcessRequestor()
self.stop_event = stop_event