Skip to content

Commit

Permalink
Merge pull request #465 from microsoft/PreRelease
Browse files Browse the repository at this point in the history
v1.0.2.8, adding indention and category exclusion in json output function
  • Loading branch information
zhmiao authored Apr 9, 2024
2 parents e219615 + 5de585c commit 13622b3
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 15 deletions.
33 changes: 23 additions & 10 deletions PytorchWildlife/utils/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def save_crop_images(results, output_dir):
)


def save_detection_json(results, output_dir, categories=None):
def save_detection_json(results, output_dir, categories=None, exclude_category_ids=[]):
"""
Save detection results to a JSON file.
Expand All @@ -92,19 +92,32 @@ def save_detection_json(results, output_dir, categories=None):
Path to save the output JSON file.
categories (list, optional):
List of categories for detected objects. Defaults to None.
exclude_category_ids (list, optional):
List of category IDs to exclude from the output. Defaults to []. Category IDs can be found in the definition of each models.
"""
json_results = {"annotations": [], "categories": categories}
with open(output_dir, "w") as f:
for r in results:
json_results["annotations"].append(
{
"img_id": r["img_id"],
"bbox": r["detections"].xyxy.astype(int).tolist(),
"category": r["detections"].class_id.tolist(),
"confidence": r["detections"].confidence.tolist(),
}
)
json.dump(json_results, f)

# Category filtering
img_id = r["img_id"]
category = r["detections"].class_id

bbox = r["detections"].xyxy.astype(int)[~np.isin(category, exclude_category_ids)]
confidence = r["detections"].confidence[~np.isin(category, exclude_category_ids)]
category = category[~np.isin(category, exclude_category_ids)]

if not all([x in exclude_category_ids for x in category]):
json_results["annotations"].append(
{
"img_id": img_id,
"bbox": bbox.tolist(),
"category": category.tolist(),
"confidence": confidence.tolist(),
}
)

json.dump(json_results, f, indent=4)


def save_detection_classification_json(
Expand Down
3 changes: 2 additions & 1 deletion demo/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,5 @@
#%% Output to JSON results
# Saving the detection results in JSON format
pw_utils.save_detection_json(results, os.path.join(".","batch_output.json"),
categories=detection_model.CLASS_NAMES)
categories=detection_model.CLASS_NAMES,
exclude_category_ids=[]) # Category IDs can be found in the definition of each model.
3 changes: 2 additions & 1 deletion demo/image_detection_colabdemo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,8 @@
"outputs": [],
"source": [
"pw_utils.save_detection_json(results, \"./batch_output.json\",\n",
" categories=detection_model.CLASS_NAMES)"
" categories=detection_model.CLASS_NAMES,\n",
" exclude_category_ids=[]) # Category IDs can be found in the definition of each model."
]
},
{
Expand Down
3 changes: 2 additions & 1 deletion demo/image_detection_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@
"outputs": [],
"source": [
"pw_utils.save_detection_json(results, os.path.join(\".\",\"batch_output.json\"),\n",
" categories=detection_model.CLASS_NAMES)"
" categories=detection_model.CLASS_NAMES,\n",
" exclude_category_ids=[]) # Category IDs can be found in the definition of each model."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = file.read()
setup(
name='PytorchWildlife',
version='1.0.2.5',
version='1.0.2.8',
packages=find_packages(),
url='https://github.com/microsoft/CameraTraps/',
license='MIT',
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0.2.7
1.0.2.8

0 comments on commit 13622b3

Please sign in to comment.