Skip to content

Commit

Permalink
target models
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Nov 26, 2024
1 parent 99f23de commit 68c1fcd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 27 deletions.
23 changes: 12 additions & 11 deletions conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ comet:
workspace: bw4sz

check_annotations: true
force_upload: false
# Force upload bypasses the pipeline, useful for debugging and starting a new project
force_upload: true
label_studio:
project_name: "Bureau of Ocean Energy Management"
url: "https://labelstudio.naturecast.org/"
Expand All @@ -16,7 +17,7 @@ label_studio:
csv_dir: /blue/ewhite/b.weinstein/BOEM/annotations

predict:
patch_size: 450
patch_size: 2000
patch_overlap: 0
min_score: 0.5

Expand Down Expand Up @@ -68,26 +69,26 @@ active_learning:
image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
strategy: 'target-labels'
n_images: 10
patch_size: 256
patch_overlap: 0.5
min_score: 0.5
patch_size: 2000
patch_overlap: 0
min_score: 0.2
model_checkpoint:
target_labels:
- "Bird"

# Optional parameters:
evaluation:
dask_client:
pool_limit: 100
pool_limit: 500

active_testing:
image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
strategy: 'random'
n_images: 100
m: 10
patch_size: 256
patch_overlap: 0.5
min_score: 0.5
n_images: 10
m:
patch_size: 2000
patch_overlap: 0
min_score: 0.2

deepforest:
train:
Expand Down
8 changes: 4 additions & 4 deletions src/active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def choose_train_images(evaluation, image_dir, strategy, n=10, patch_size=512, p
return chosen_images
elif strategy in ["most-detections","target-labels"]:
# Predict all images
if model_path is None:
raise ValueError("A model is required for the 'most-detections' or 'target-labels' strategy.")
if model_path is None and model is None:
raise ValueError("A model is required for the 'most-detections' or 'target-labels' strategy. Either pass a model or a model_path.")
if dask_client:
# load model on each client
def update_sys_path():
Expand All @@ -59,7 +59,7 @@ def update_sys_path():
blocks = dask_pool.to_delayed().ravel()
block_futures = []
for block in blocks:
block_future = dask_client.submit(detection.predict,image_paths=block.compute(), patch_size=patch_size, patch_overlap=patch_overlap, min_score=min_score, model_path=model_path)
block_future = dask_client.submit(detection.predict,image_paths=block.compute(), patch_size=patch_size, patch_overlap=patch_overlap, model_path=model_path)
block_futures.append(block_future)
# Get results
dask_results = []
Expand All @@ -68,7 +68,7 @@ def update_sys_path():
dask_results.append(pd.concat(block_result))
preannotations = pd.concat(dask_results)
else:
preannotations = detection.predict(model=model, image_paths=pool, patch_size=patch_size, patch_overlap=patch_overlap, min_score=min_score)
preannotations = detection.predict(m=model, image_paths=pool, patch_size=patch_size, patch_overlap=patch_overlap)
preannotations = pd.concat(preannotations)

if strategy == "most-detections":
Expand Down
28 changes: 16 additions & 12 deletions src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from src import label_studio
from src.classification import preprocess_and_train_classification
from src.data_processing import density_cropping
from src.detection import preprocess_and_train
from src.detection import preprocess_and_train, load
from src.pipeline_evaluation import PipelineEvaluation
from src.reporting import Reporting

Expand Down Expand Up @@ -47,16 +47,19 @@ def run(self):
if new_annotations is None:
print("No new annotations, exiting")
if self.config.force_upload:
image_paths = glob.glob(os.path.join(self.config.label_studio.images_to_annotate_dir, "*.jpg"))
image_paths = random.sample(image_paths, 10)
label_studio.upload_to_label_studio(images=image_paths,
sftp_client=self.sftp_client,
label_studio_project=self.label_studio_project,
images_to_annotate_dir=self.config.label_studio.images_to_annotate_dir,
folder_name=self.config.label_studio.folder_name,
preannotations=None
)
return None
detection_model = load(self.config.detection_model.checkpoint)
train_images_to_annotate = choose_train_images(
evaluation=None,
image_dir=self.config.active_learning.image_dir,
model=detection_model,
strategy=self.config.active_learning.strategy,
n=self.config.active_learning.n_images,
patch_size=self.config.active_learning.patch_size,
patch_overlap=self.config.active_learning.patch_overlap,
min_score=self.config.active_learning.min_score,
target_labels=self.config.active_learning.target_labels
)
return None
# Select images to upload

return None
Expand Down Expand Up @@ -106,7 +109,8 @@ def run(self):
n=self.config.active_learning.n_images,
patch_size=self.config.active_learning.patch_size,
patch_overlap=self.config.active_learning.patch_overlap,
min_score=self.config.active_learning.min_score
min_score=self.config.active_learning.min_score,
target_labels=self.config.active_learning.target_labels
)

test_images_to_annotate = choose_test_images(
Expand Down

0 comments on commit 68c1fcd

Please sign in to comment.