
Commit

add slurm config
bw4sz committed Nov 25, 2024
1 parent c05ee2c commit e1ba9a9
Showing 7 changed files with 43 additions and 12 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/active_learning.py
@@ -169,7 +169,7 @@ def update_sys_path():

def predict_and_divide(trained_detection_model, trained_classification_model, image_paths, patch_size, patch_overlap, confident_threshold):
    predictions = detection.predict(
-        model=trained_detection_model,
+        m=trained_detection_model,
        crop_model=trained_classification_model,
        image_paths=image_paths,
        patch_size=patch_size,
8 changes: 6 additions & 2 deletions src/detection.py
@@ -11,6 +11,7 @@
import dask.array as da
import pandas as pd
from deepforest import main, visualize
+from deepforest.utilities import read_file
from pytorch_lightning.loggers import CometLogger

# Local imports
@@ -161,6 +162,8 @@ def train(model, train_annotations, test_annotations, train_image_dir, comet_pro

    with comet_logger.experiment.context_manager("train_images"):
        non_empty_train_annotations = train_annotations[~(train_annotations.xmax==0)]
+        non_empty_train_annotations = read_file(non_empty_train_annotations, root_dir=train_image_dir)
+
        if non_empty_train_annotations.empty:
            pass
        else:
@@ -245,10 +248,11 @@ def get_latest_checkpoint(checkpoint_dir, annotations):
        else:
            warn("No checkpoints found in {}".format(checkpoint_dir))
            label_dict = {value: index for index, value in enumerate(annotations.label.unique())}
-            m = main.deepforest(config_file="Airplane/deepforest_config.yml", label_dict=label_dict)
+            m = main.deepforest(label_dict=label_dict)
    else:
        os.makedirs(checkpoint_dir)
-        m = main.deepforest(config_file="Airplane/deepforest_config.yml")
+        label_dict = {value: index for index, value in enumerate(annotations.label.unique())}
+        m = main.deepforest(label_dict=label_dict)

    return m

1 change: 1 addition & 0 deletions src/label_studio.py
@@ -251,6 +251,7 @@ def import_image_tasks(label_studio_project, image_names, local_image_dir, predi

    tasks = []
    for index, image_name in enumerate(image_names):
+        print(f"Importing {image_name} into Label Studio")
        data_dict = {'image': os.path.join("/data/local-files/?d=input/", os.path.basename(image_name))}
        if predictions:
            prediction = predictions[index]
24 changes: 16 additions & 8 deletions src/pipeline.py
@@ -96,12 +96,11 @@ def run(self):
            patch_overlap=self.config.active_testing.patch_overlap,
            min_score=self.config.active_testing.min_score)

-
        confident_predictions, uncertain_predictions = predict_and_divide(
            trained_detection_model, trained_classification_model,
            train_images_to_annotate, self.config.active_learning.patch_size,
            self.config.active_learning.patch_overlap,
-            self.config.active_learning.confident_threshold)
+            self.config.pipeline.confidence_threshold)

        reporter.confident_predictions = confident_predictions
        reporter.uncertain_predictions = uncertain_predictions
@@ -116,11 +115,20 @@

        # Align the predictions with the cropped images
        # Run the annotation pipeline
-        label_studio.upload_to_label_studio(self.sftp_client,
-                                            uncertain_predictions,
-                                            **self.config)
-        label_studio.upload_to_label_studio(self.sftp_client,
-                                            test_images_to_annotate,
-                                            **self.config)
+        if len(image_paths) > 0:
+            label_studio.upload_to_label_studio(images=image_paths,
+                                                sftp_client=self.sftp_client,
+                                                label_studio_project=self.label_studio_project,
+                                                images_to_annotate_dir=self.config.active_learning.image_dir,
+                                                folder_name=self.config.label_studio.folder_name,
+                                                preannotations=uncertain_predictions
+                                                )
+
+        label_studio.upload_to_label_studio(images=test_images_to_annotate,
+                                            sftp_client=self.sftp_client,
+                                            label_studio_project=self.label_studio_project,
+                                            images_to_annotate_dir=self.config.active_testing.image_dir,
+                                            folder_name=self.config.label_studio.folder_name,
+                                            preannotations=None)
        reporter.generate_report()

18 changes: 18 additions & 0 deletions submit.sh
@@ -0,0 +1,18 @@
#!/bin/bash
#SBATCH --job-name=BOEM # Job name
#SBATCH --mail-type=END # Mail events
#SBATCH --mail-user=[email protected] # Where to send mail
#SBATCH --account=ewhite
#SBATCH --nodes=1 # Number of MPI ranks
#SBATCH --cpus-per-task=1
#SBATCH --mem=150GB
#SBATCH --time=48:00:00 #Time limit hrs:min:sec
#SBATCH --output=/home/b.weinstein/logs/BOEM%j.out # Standard output and error log
#SBATCH --error=/home/b.weinstein/logs/BOEM%j.err
#SBATCH --partition=gpu
#SBATCH --gpus=1

source activate BOEM

cd ~/BOEM/
python main.py
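
For reference, a minimal sketch of how this new SLURM script would typically be submitted and monitored, assuming submit.sh sits in the repository root and the log directory from the --output/--error lines exists; the job ID in the tail command is a placeholder, not a real value:

sbatch submit.sh                                 # queue the pipeline job on the gpu partition
squeue -u $USER                                  # check the job's state in the queue
tail -f /home/b.weinstein/logs/BOEM<jobid>.out   # follow the log written by the --output pattern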
2 changes: 1 addition & 1 deletion tests/test_pipeline.py
@@ -63,7 +63,7 @@ def cleanup_label_studio(label_studio_client, request):
    # Setup: yield to allow tests to run
    yield

-
+@pytest.mark.integration
def test_pipeline_run(config, label_studio_client):
    """Test complete pipeline run"""
    pipeline = Pipeline(cfg=config)
