From d01b71a43cb3f965adc36b71ab54491c6c15c497 Mon Sep 17 00:00:00 2001 From: zhmiao Date: Wed, 21 Aug 2024 21:26:20 +0000 Subject: [PATCH] add megadetectorv5a to the model zoo --- PytorchWildlife/models/detection/yolov5/megadetector.py | 8 ++++++-- demo/image_demo.py | 2 +- demo/image_detection_demo.ipynb | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/PytorchWildlife/models/detection/yolov5/megadetector.py b/PytorchWildlife/models/detection/yolov5/megadetector.py index c355e8979..ec7077147 100644 --- a/PytorchWildlife/models/detection/yolov5/megadetector.py +++ b/PytorchWildlife/models/detection/yolov5/megadetector.py @@ -26,7 +26,7 @@ class MegaDetectorV5(YOLOV5Base): 2: "vehicle" } - def __init__(self, weights=None, device="cpu", pretrained=True): + def __init__(self, weights=None, device="cpu", pretrained=True, version="a"): """ Initializes the MegaDetectorV5 model with the option to load pretrained weights. @@ -34,10 +34,14 @@ def __init__(self, weights=None, device="cpu", pretrained=True): weights (str, optional): Path to the weights file. device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". pretrained (bool, optional): Whether to load the pretrained model. Default is True. + version (str, optional): Version of the MegaDetectorV5 model to load. Default is "a". """ if pretrained: - url = "https://zenodo.org/records/10023414/files/MegaDetector_v5b.0.0.pt?download=1" + if version == "a": + url = "https://zenodo.org/records/13357337/files/md_v5a.0.0.pt?download=1" + elif version == "b": + url = "https://zenodo.org/records/10023414/files/MegaDetector_v5b.0.0.pt?download=1" else: url = None diff --git a/demo/image_demo.py b/demo/image_demo.py index 58eb741ba..f703c0e54 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -26,7 +26,7 @@ #%% # Initializing the MegaDetectorV5 model for image detection -detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True) +detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, verison="a") #%% Single image detection # Specifying the path to the target image TODO: Allow argparsing diff --git a/demo/image_detection_demo.ipynb b/demo/image_detection_demo.ipynb index 4d7f5e2ff..ac3d77fa7 100644 --- a/demo/image_detection_demo.ipynb +++ b/demo/image_detection_demo.ipynb @@ -64,9 +64,15 @@ "source": [ "# Setting the device to use for computations ('cuda' indicates GPU)\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True)" + "detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, verison=\"a\")" ] }, + { + "cell_type": "markdown", + "id": "64c19af9", + "metadata": {}, + "source": [] + }, { "cell_type": "markdown", "id": "1e57dcca",