Skip to content

Commit

Permalink
Merge pull request #476 from microsoft/PreRelease
Browse files Browse the repository at this point in the history
Adding Sink overwrite argument for image saving
  • Loading branch information
zhmiao authored Apr 15, 2024
2 parents 006be83 + 302a4ed commit 997dfb6
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 17 deletions.
6 changes: 4 additions & 2 deletions PW_FT_classification/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PIL import Image
import numpy as np

def save_crop_images(results, output_dir, original_csv_path):
def save_crop_images(results, output_dir, original_csv_path, overwrite=False):
"""
Save cropped images based on the detection bounding boxes.
Expand All @@ -16,6 +16,8 @@ def save_crop_images(results, output_dir, original_csv_path):
Directory to save the cropped images.
original_csv_path (str):
Path to the original CSV file.
overwrite (bool):
Whether overwriting existing image folders. Default to False.
Return:
new_csv_path (str):
Path to the new CSV file.
Expand All @@ -29,7 +31,7 @@ def save_crop_images(results, output_dir, original_csv_path):
new_records = []

os.makedirs(output_dir, exist_ok=True)
with sv.ImageSink(target_dir_path=output_dir, overwrite=False) as sink:
with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
for entry in results:
# Process the data if the name of the file is in the dataframe
if os.path.basename(entry["img_id"]) in original_df['path'].values:
Expand Down
12 changes: 8 additions & 4 deletions PytorchWildlife/utils/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


# !!! Output paths need to be optimized !!!
def save_detection_images(results, output_dir):
def save_detection_images(results, output_dir, overwrite=False):
"""
Save detected images with bounding boxes and labels annotated.
Expand All @@ -30,11 +30,13 @@ def save_detection_images(results, output_dir):
Detection results containing image ID, detections, and labels.
output_dir (str):
Directory to save the annotated images.
overwrite (bool):
Whether overwriting existing image folders. Default to False.
"""
box_annotator = sv.BoxAnnotator(thickness=4, text_thickness=4, text_scale=2)
os.makedirs(output_dir, exist_ok=True)

with sv.ImageSink(target_dir_path=output_dir, overwrite=True) as sink:
with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
if isinstance(results, list):
for entry in results:
annotated_img = box_annotator.annotate(
Expand All @@ -57,7 +59,7 @@ def save_detection_images(results, output_dir):


# !!! Output paths need to be optimized !!!
def save_crop_images(results, output_dir):
def save_crop_images(results, output_dir, overwrite=False):
"""
Save cropped images based on the detection bounding boxes.
Expand All @@ -66,10 +68,12 @@ def save_crop_images(results, output_dir):
Detection results containing image ID and detections.
output_dir (str):
Directory to save the cropped images.
overwrite (bool):
Whether overwriting existing image folders. Default to False.
"""
assert isinstance(results, list)
os.makedirs(output_dir, exist_ok=True)
with sv.ImageSink(target_dir_path=output_dir, overwrite=True) as sink:
with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink:
for entry in results:
for i, (xyxy, _, _, cat, _) in enumerate(entry["detections"]):
cropped_img = sv.crop_image(
Expand Down
6 changes: 3 additions & 3 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
results = detection_model.single_image_detection(transform(img), img.shape, tgt_img_path)

# Saving the detection results
pw_utils.save_detection_images(results, os.path.join(".","demo_output"))
pw_utils.save_detection_images(results, os.path.join(".","demo_output"), overwrite=False)

#%% Batch detection
""" Batch-detection demo """
Expand All @@ -68,11 +68,11 @@

#%% Output to annotated images
# Saving the batch detection results as annotated images
pw_utils.save_detection_images(results, "batch_output")
pw_utils.save_detection_images(results, "batch_output", overwrite=False)

#%% Output to cropped images
# Saving the detected objects as cropped images
pw_utils.save_crop_images(results, "crop_output")
pw_utils.save_crop_images(results, "crop_output", overwrite=False)

#%% Output to JSON results
# Saving the detection results in JSON format
Expand Down
6 changes: 3 additions & 3 deletions demo/image_detection_colabdemo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@
"transform = pw_trans.MegaDetector_v5_Transform(target_size=detection_model.IMAGE_SIZE,\n",
" stride=detection_model.STRIDE)\n",
"results = detection_model.single_image_detection(transform(img), img.shape, temp_file_path)\n",
"pw_utils.save_detection_images(results, \"./demo_output\")"
"pw_utils.save_detection_images(results, \"./demo_output\", overwrite=False)"
]
},
{
Expand Down Expand Up @@ -997,7 +997,7 @@
},
"outputs": [],
"source": [
"pw_utils.save_detection_images(results, \"batch_output\")"
"pw_utils.save_detection_images(results, \"batch_output\", overwrite=False)"
]
},
{
Expand All @@ -1020,7 +1020,7 @@
},
"outputs": [],
"source": [
"pw_utils.save_crop_images(results, \"crop_output\")"
"pw_utils.save_crop_images(results, \"crop_output\", overwrite=False)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions demo/image_detection_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
"transform = pw_trans.MegaDetector_v5_Transform(target_size=detection_model.IMAGE_SIZE,\n",
" stride=detection_model.STRIDE)\n",
"results = detection_model.single_image_detection(transform(img), img.shape, tgt_img_path)\n",
"pw_utils.save_detection_images(results, os.path.join(\".\",\"demo_output\"))"
"pw_utils.save_detection_images(results, os.path.join(\".\",\"demo_output\"), overwrite=False)"
]
},
{
Expand Down Expand Up @@ -138,7 +138,7 @@
"metadata": {},
"outputs": [],
"source": [
"pw_utils.save_detection_images(results, \"batch_output\")"
"pw_utils.save_detection_images(results, \"batch_output\", overwrite=False)"
]
},
{
Expand All @@ -157,7 +157,7 @@
"metadata": {},
"outputs": [],
"source": [
"pw_utils.save_crop_images(results, \"crop_output\")"
"pw_utils.save_crop_images(results, \"crop_output\", overwrite=False)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = file.read()
setup(
name='PytorchWildlife',
version='1.0.2.11',
version='1.0.2.12',
packages=find_packages(),
url='https://github.com/microsoft/CameraTraps/',
license='MIT',
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0.2.11
1.0.2.12

0 comments on commit 997dfb6

Please sign in to comment.