need to find a way to make the SSH key integration tests separate
bw4sz committed Nov 26, 2024
1 parent 063044a commit 99f23de
Showing 6 changed files with 65 additions and 55 deletions.
50 changes: 27 additions & 23 deletions conf/config.yaml
@@ -1,16 +1,19 @@
defaults:
- server: serenity

comet:
project: BOEM
workspace: bw4sz

check_annotations: false
check_annotations: true
force_upload: false
label_studio:
project_name: "Bureau of Ocean Energy Management"
url: "https://labelstudio.naturecast.org/"
api_key: "${oc.env:LABEL_STUDIO_API_KEY}"
folder_name: "/pgsql/retrieverdash/everglades-label-studio/everglades-data"
images_to_annotate_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27
annotated_images_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
csv_dir: /blue/ewhite/b.weinstein/BOEM/annotations

predict:
patch_size: 450
@@ -27,28 +30,28 @@ propagate:

detection_model:
checkpoint: bird
checkpoint_dir:
checkpoint_dir: /blue/ewhite/b.weinstein/BOEM/detection/checkpoints
validation_csv_path:
train_csv_folder:
train_image_dir:
crop_image_dir:
train_csv_folder: /blue/ewhite/b.weinstein/BOEM/annotations/
train_image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
crop_image_dir: /blue/ewhite/b.weinstein/BOEM/detection/crops/
limit_empty_frac: 0
fast_dev_run: false
labels:
- "Bird"

classification_model:
checkpoint:
checkpoint_dir:
validation_csv_path:
train_csv_folder:
train_image_dir:
crop_image_dir:
checkpoint:
checkpoint_dir: /blue/ewhite/b.weinstein/BOEM/classification/checkpoints
validation_csv_path:
train_csv_folder: /blue/ewhite/b.weinstein/BOEM/annotations/
train_image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/
under_sample_ratio: 0
fast_dev_run: false
fast_dev_run: True

pipeline_evaluation:
detect_ground_truth_dir:
detect_ground_truth_dir:
classify_confident_ground_truth_dir:
classify_uncertain_ground_truth_dir:
# This is an average mAP threshold for now, but we may want to add a per-iou threshold in the future
@@ -62,23 +65,23 @@ reporting:
report_dir:

active_learning:
image_dir:
strategy: 'random'
n_images: 100
m: 10
image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
strategy: 'target-labels'
n_images: 10
patch_size: 256
patch_overlap: 0.5
min_score: 0.5
model_checkpoint:
target_labels:
- "Bird"

# Optional parameters:
evaluation: null
dask_client: null
pool_limit: null
evaluation:
dask_client:
pool_limit: 100

active_testing:
image_dir:
image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
strategy: 'random'
n_images: 100
m: 10
@@ -89,4 +92,5 @@ active_testing:
deepforest:
train:
fast_dev_run: True
epochs: 1
workers: 0
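
The values above are resolved by Hydra/OmegaConf at runtime; in particular, api_key uses the ${oc.env:LABEL_STUDIO_API_KEY} resolver, so the environment variable must exist before resolution (main.py below exports it from get_api_key()). A minimal sketch of loading and resolving this file outside of Hydra, assuming it is run from the repository root:

import os
from omegaconf import OmegaConf

# The env var must be set before resolution, otherwise ${oc.env:...} raises.
os.environ.setdefault("LABEL_STUDIO_API_KEY", "dummy-key-for-local-testing")

cfg = OmegaConf.load("conf/config.yaml")   # path is an assumption: repo root
OmegaConf.resolve(cfg)                     # expands ${oc.env:LABEL_STUDIO_API_KEY}
print(cfg.label_studio.csv_dir)            # /blue/ewhite/b.weinstein/BOEM/annotations

Note that the defaults list (- server: serenity) is only composed when the config is loaded through Hydra itself, as in main.py; OmegaConf.load reads just this single file.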
2 changes: 1 addition & 1 deletion conf/server/serenity.yaml
@@ -1,4 +1,4 @@
# Remote server for image hosting
user: 'ben'
host: 'serenity.ifas.ufl.edu'
key_filename: '/Users/benweinstein/.ssh/id_rsa.pub'
key_filename: '/home/b.weinstein/.ssh/id_rsa.pub'
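
These three fields are what create_sftp_client (referenced in src/label_studio.py and src/pipeline.py) presumably consumes to open the connection. A minimal sketch of such a helper using paramiko — the body is an assumption, not the project's actual code, and key_filename conventionally points at the private key rather than the .pub file:

import paramiko

def create_sftp_client(user, host, key_filename):
    # Assumed implementation: open an SSH connection, return an SFTP client.
    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh.connect(hostname=host, username=user, key_filename=key_filename)
    return ssh.open_sftp()

# sftp = create_sftp_client(user='ben', host='serenity.ifas.ufl.edu',
#                           key_filename='/home/b.weinstein/.ssh/id_rsa.pub')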
13 changes: 8 additions & 5 deletions main.py
@@ -1,18 +1,21 @@
import hydra
import os
from omegaconf import DictConfig
from src.pipeline import Pipeline
from src.label_studio import get_api_key

@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
"""Main entry point for the application"""
api_key = get_api_key()
os.environ["LABEL_STUDIO_API_KEY"] = api_key
if api_key is None:
print("Warning: No Label Studio API key found in .comet.config")
return None

# Initialize and run pipeline
pipeline = Pipeline(cfg=cfg)
results = pipeline.run(model_path=cfg.model.path)

# Log results
print(f"Images needing review: {len(results['needs_review'])}")
print(f"Images not needing review: {len(results['no_review_needed'])}")
results = pipeline.run()

if __name__ == "__main__":
main()
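
Because the entry point is decorated with @hydra.main, any config value can be overridden at launch, and the same composition can be reproduced in Python — handy for the integration tests mentioned in the commit message. A small sketch using Hydra's compose API, assuming Hydra >= 1.2 and running from the repository root:

from hydra import initialize, compose

# Compose conf/config.yaml with an override, without invoking main() itself.
with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="config", overrides=["check_annotations=false"])
print(cfg.active_learning.strategy)  # 'target-labels' unless overridden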
34 changes: 15 additions & 19 deletions src/label_studio.py
@@ -25,28 +25,24 @@ def upload_to_label_studio(images, sftp_client, label_studio_project, images_to_
upload_images(sftp_client=sftp_client, images=images, folder_name=folder_name)
import_image_tasks(label_studio_project=label_studio_project, image_names=images, local_image_dir=images_to_annotate_dir, predictions=preannotations)

def check_for_new_annotations(user, host, key_filename, label_studio_url, label_studio_project_name, train_csv_folder, images_to_annotate_dir, annotated_images_dir, folder_name):
def check_for_new_annotations(sftp_client, url, project_name, csv_dir, images_to_annotate_dir, annotated_images_dir, folder_name):
"""
Check for new annotations from Label Studio, move annotated images, and gather new images to annotate.
Args:
user (str): The username for the SFTP connection.
host (str): The host URL for the SFTP connection.
key_filename (str): The path to the SSH key file for the SFTP connection.
label_studio_url (str): The URL of the Label Studio server.
label_studio_project_name (str): The name of the Label Studio project.
train_csv_folder (str): The path to the folder containing training CSV files.
sftp_client (paramiko.SFTPClient): The SFTP client for downloading images.
url (str): The URL of the Label Studio server.
project_name (str): The name of the Label Studio project.
csv_dir (str): The path to the folder containing CSV files.
images_to_annotate_dir (str): The path to the directory of images to annotate.
annotated_images_dir (str): The path to the directory of annotated images.
folder_name (str): The name of the folder to upload images to.
filter_labels (list, optional): A list of labels to filter images by. Defaults to None.
Returns:
DataFrame: A DataFrame containing the gathered annotations.
"""
sftp_client = create_sftp_client(user=user, host=host, key_filename=key_filename)
label_studio_project = connect_to_label_studio(url=label_studio_url, project_name=label_studio_project_name)
new_annotations = download_completed_tasks(label_studio_project=label_studio_project, train_csv_folder=train_csv_folder)
label_studio_project = connect_to_label_studio(url=url, project_name=project_name)
new_annotations = download_completed_tasks(label_studio_project=label_studio_project, csv_dir=csv_dir)

# Move annotated images out of local pool
if new_annotations is not None:
@@ -63,7 +59,7 @@ def check_for_new_annotations(user, host, key_filename, label_studio_url, label_
return None

# Choose new images to annotate
label_studio_annotations = gather_data(train_csv_folder)
label_studio_annotations = gather_data(csv_dir)

return label_studio_annotations

@@ -266,7 +262,7 @@ def import_image_tasks(label_studio_project, image_names, local_image_dir, predi
tasks.append(upload_dict)
label_studio_project.import_tasks(tasks)

def download_completed_tasks(label_studio_project, train_csv_folder):
def download_completed_tasks(label_studio_project, csv_dir):
labeled_tasks = label_studio_project.get_labeled_tasks()
if not labeled_tasks:
print("No new annotations")
@@ -280,11 +276,11 @@ def download_completed_tasks(label_studio_project, train_csv_folder):
if len(label_json) == 0:
result = {
"image_path": image_path,
"xmin": None,
"ymin": None,
"xmax": None,
"ymax": None,
"label": None,
"xmin": 0,
"ymin": 0,
"xmax": 0,
"ymax": 0,
"label": 0,
"annotator":labeled_task["annotations"][0]["created_username"]
}
result = pd.DataFrame(result, index=[0])
@@ -302,7 +298,7 @@ def download_completed_tasks(label_studio_project, train_csv_folder):

# Save csv in dir with timestamp
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
train_path = os.path.join(train_csv_folder, "train_{}.csv".format(timestamp))
train_path = os.path.join(csv_dir, "train_{}.csv".format(timestamp))
annotations.to_csv(train_path, index=False)

return annotations
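
One behavioral change worth noting in download_completed_tasks: tasks with no boxes are now written as an all-zero row instead of a row of None values, presumably to match the zero-box convention used downstream (e.g. DeepForest-style empty frames). A quick illustration of what such a row looks like in the saved CSV, with values taken from the diff above and illustrative image_path/annotator:

import pandas as pd

# Empty-frame row as now written by download_completed_tasks (per the diff).
empty_frame = pd.DataFrame([{
    "image_path": "empty.jpg",   # illustrative file name
    "xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0,
    "label": 0,
    "annotator": "test_user",    # illustrative annotator
}])
print(empty_frame)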
13 changes: 10 additions & 3 deletions src/pipeline.py
@@ -36,16 +36,23 @@ def run(self):
# Check for new annotations if the check_annotations flag is set
if self.config.check_annotations:
new_annotations = label_studio.check_for_new_annotations(
**self.config.label_studio)
sftp_client=self.sftp_client,
url=self.config.label_studio.url,
csv_dir=self.config.label_studio.csv_dir,
project_name=self.config.label_studio.project_name,
folder_name=self.config.label_studio.folder_name,
images_to_annotate_dir=self.config.label_studio.images_to_annotate_dir,
annotated_images_dir=self.config.label_studio.annotated_images_dir,
)
if new_annotations is None:
print("No new annotations, exiting")
if self.config.force_upload:
image_paths = glob.glob(os.path.join(self.config.active_learning.image_dir, "*.jpg"))
image_paths = glob.glob(os.path.join(self.config.label_studio.images_to_annotate_dir, "*.jpg"))
image_paths = random.sample(image_paths, 10)
label_studio.upload_to_label_studio(images=image_paths,
sftp_client=self.sftp_client,
label_studio_project=self.label_studio_project,
images_to_annotate_dir=self.config.active_learning.image_dir,
images_to_annotate_dir=self.config.label_studio.images_to_annotate_dir,
folder_name=self.config.label_studio.folder_name,
preannotations=None
)
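
The call site now spells out each argument instead of splatting **self.config.label_studio, which would break once the config section carries keys the refactored signature no longer accepts (api_key, project_name vs. the old label_studio_project_name, and so on). A generic alternative — not what this commit does, and purely a sketch — would be to filter the config section down to the callee's declared parameters:

import inspect
from omegaconf import OmegaConf

def call_with_matching_keys(func, cfg_section, **extra):
    # Keep only the keys the callee actually declares, then splat them
    # alongside any explicitly supplied extras (e.g. sftp_client).
    accepted = set(inspect.signature(func).parameters)
    kwargs = {k: v for k, v in OmegaConf.to_container(cfg_section, resolve=False).items()
              if k in accepted}
    kwargs.update(extra)
    return func(**kwargs)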
8 changes: 4 additions & 4 deletions tests/conftest.py
@@ -39,10 +39,10 @@ def config(tmpdir_factory):
# Create sample bounding box annotations
train_data = {
'image_path': ['empty.jpg', 'birds.jpg', "birds.jpg"],
'xmin': [0, 200, 150],
'ymin': [0, 300, 250],
'xmax': [20, 300, 250],
'ymax': [20, 400, 350],
'xmin': [None, 200, 150],
'ymin': [None, 300, 250],
'xmax': [None, 300, 250],
'ymax': [None, 400, 350],
'label': ['Bird', 'Bird1', 'Bird2'],
'annotator': ['test_user', 'test_user', 'test_user']
}
