Skip to content

Commit

Permalink
Merge pull request #527 from microsoft/PreRelease
Browse files Browse the repository at this point in the history
add megadetectorv5a to the model zoo
  • Loading branch information
zhmiao authored Aug 21, 2024
2 parents 2b16ad1 + d01b71a commit 722808b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
8 changes: 6 additions & 2 deletions PytorchWildlife/models/detection/yolov5/megadetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,22 @@ class MegaDetectorV5(YOLOV5Base):
2: "vehicle"
}

def __init__(self, weights=None, device="cpu", pretrained=True):
def __init__(self, weights=None, device="cpu", pretrained=True, version="a"):
"""
Initializes the MegaDetectorV5 model with the option to load pretrained weights.
Args:
weights (str, optional): Path to the weights file.
device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu".
pretrained (bool, optional): Whether to load the pretrained model. Default is True.
version (str, optional): Version of the MegaDetectorV5 model to load. Default is "a".
"""

if pretrained:
url = "https://zenodo.org/records/10023414/files/MegaDetector_v5b.0.0.pt?download=1"
if version == "a":
url = "https://zenodo.org/records/13357337/files/md_v5a.0.0.pt?download=1"
elif version == "b":
url = "https://zenodo.org/records/10023414/files/MegaDetector_v5b.0.0.pt?download=1"
else:
url = None

Expand Down
2 changes: 1 addition & 1 deletion demo/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#%%
# Initializing the MegaDetectorV5 model for image detection
detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True)
detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, verison="a")

#%% Single image detection
# Specifying the path to the target image TODO: Allow argparsing
Expand Down
8 changes: 7 additions & 1 deletion demo/image_detection_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,15 @@
"source": [
"# Setting the device to use for computations ('cuda' indicates GPU)\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True)"
"detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, verison=\"a\")"
]
},
{
"cell_type": "markdown",
"id": "64c19af9",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"id": "1e57dcca",
Expand Down

0 comments on commit 722808b

Please sign in to comment.