
Commit

profiled predict
bw4sz committed Jan 18, 2025
1 parent c52bcc3 commit 4985fc8
Showing 4 changed files with 50 additions and 24 deletions.
conf/config.yaml (5 changes: 3 additions & 2 deletions)
@@ -39,10 +39,11 @@ detection_model:
   train_csv_folder: /blue/ewhite/b.weinstein/BOEM/annotations/train/
   train_image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated
   crop_image_dir: /blue/ewhite/b.weinstein/BOEM/detection/crops/
-  limit_empty_frac: 0.2
+  limit_empty_frac: 0.01
   labels:
     - "Bird"
   trainer:
+    batch_size: 4
     train:
       fast_dev_run: False
       epochs: 10
@@ -59,7 +60,7 @@ classification_model:
   crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/
   under_sample_ratio: 0
   trainer:
-    fast_dev_run: False
+    fast_dev_run: True
     max_epochs: 1
     lr: 0.00001

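The dotted access used elsewhere in this commit (for example config.detection_model.limit_empty_frac) suggests an OmegaConf/Hydra-style config. A minimal sketch of inspecting the values touched above, assuming conf/config.yaml loads on its own:

from omegaconf import OmegaConf

config = OmegaConf.load("conf/config.yaml")
# Values changed or added in this commit
print(config.detection_model.limit_empty_frac)           # 0.01 (was 0.2)
print(config.detection_model.trainer.batch_size)         # 4 (newly added)
print(config.classification_model.trainer.fast_dev_run)  # True (was False)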
src/classification.py (46 changes: 37 additions & 9 deletions)
@@ -14,29 +14,47 @@ def create_train_test(annotations):
     return annotations.sample(frac=0.8, random_state=1), annotations.drop(
         annotations.sample(frac=0.8, random_state=1).index)
 
-def get_latest_checkpoint(checkpoint_dir, annotations, lr=0.0001):
+def get_latest_checkpoint(checkpoint_dir, annotations, lr=0.0001, num_classes=None):
     #Get model with latest checkpoint dir, if none exist make a new model
     if os.path.exists(checkpoint_dir):
         checkpoints = glob.glob(os.path.join(checkpoint_dir,"*.ckpt"))
         if len(checkpoints) > 0:
             checkpoints.sort()
             checkpoint = checkpoints[-1]
-            m = CropModel.load_from_checkpoint(checkpoint)
+            try:
+                m = CropModel.load_from_checkpoint(checkpoint)
+            except Exception as e:
+                warnings.warn("Could not load model from checkpoint, {}".format(e))
+                if num_classes:
+                    m = CropModel(num_classes=num_classes, lr=lr)
+                else:
+                    m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr)
         else:
             warnings.warn("No checkpoints found in {}".format(checkpoint_dir))
-            m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr)
+            if num_classes:
+                m = CropModel(num_classes=num_classes, lr=lr)
+            else:
+                m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr)
     else:
         os.makedirs(checkpoint_dir)
-        m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr)
+        if num_classes:
+            m = CropModel(num_classes=num_classes, lr=lr)
+        else:
+            m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr)
 
     return m
 
-def load(checkpoint=None, annotations=None, checkpoint_dir=None, lr=0.0001):
+def load(checkpoint=None, annotations=None, checkpoint_dir=None, lr=0.0001, num_classes=None):
     if checkpoint:
-        loaded_model = CropModel(checkpoint, num_classes=len(annotations["label"].unique()), lr=lr)
+        if num_classes:
+            loaded_model = CropModel(checkpoint, num_classes=num_classes, lr=lr)
+        else:
+            loaded_model = CropModel(checkpoint, num_classes=len(annotations["label"].unique()), lr=lr)
     elif checkpoint_dir:
         loaded_model = get_latest_checkpoint(
-            checkpoint_dir, annotations)
+            checkpoint_dir,
+            num_classes=num_classes,
+            annotations=annotations)
     else:
         raise ValueError("No checkpoint or checkpoint directory found.")

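A minimal usage sketch of the new num_classes override in load(); the checkpoint directory and the annotation table here are hypothetical stand-ins for what the pipeline builds with gather_data():

import pandas as pd

annotations = pd.DataFrame({"image_path": ["a.jpg", "b.jpg"], "label": ["Bird", "Mammal"]})

# Previous behavior (still the default): head size inferred from the annotation labels (2 classes here)
model = load(checkpoint_dir="/path/to/checkpoints", annotations=annotations)

# New behavior: the caller can force the head size, e.g. to match the detection model's classes
model = load(checkpoint_dir="/path/to/checkpoints", annotations=annotations, num_classes=1)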
@@ -80,7 +98,7 @@ def preprocess_images(model, annotations, root_dir, save_dir):
     labels = annotations["label"].values
     model.write_crops(boxes=boxes, root_dir=root_dir, images=images, labels=labels, savedir=save_dir)
 
-def preprocess_and_train_classification(config, validation_df=None):
+def preprocess_and_train_classification(config, validation_df=None, num_classes=None):
     """Preprocess data and train a crop model.
     Args:
@@ -92,14 +110,24 @@ def preprocess_and_train_classification(config, validation_df=None):
     # Get and split annotations
     annotations = gather_data(config.classification_model.train_csv_folder)
 
+    # Remove the empty frames
+    annotations = annotations[~(annotations.label.astype(str)== "0")]
+    annotations = annotations[annotations.label != "FalsePositive"]
+
     if validation_df is None:
         train_df, validation_df = create_train_test(annotations)
     else:
         train_df = annotations[~annotations["image_path"].
                                isin(validation_df["image_path"])]
 
     # Load existing model
-    loaded_model = load(checkpoint=config.classification_model.checkpoint, checkpoint_dir=config.classification_model.checkpoint_dir, annotations=annotations, lr=config.classification_model.trainer.lr)
+    loaded_model = load(
+        checkpoint=config.classification_model.checkpoint,
+        checkpoint_dir=config.classification_model.checkpoint_dir,
+        annotations=annotations,
+        lr=config.classification_model.trainer.lr,
+        num_classes=num_classes
+    )
 
     # Preprocess train and validation data
     preprocess_images(
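The two filters added above drop empty frames (label "0") and false-positive crops before the train/validation split. A small self-contained illustration of their effect on a toy table:

import pandas as pd

annotations = pd.DataFrame({
    "image_path": ["a.jpg", "b.jpg", "c.jpg"],
    "label": ["Bird", 0, "FalsePositive"],  # toy rows; 0 marks an empty frame
})

# Same two filters as in the diff
annotations = annotations[~(annotations.label.astype(str) == "0")]
annotations = annotations[annotations.label != "FalsePositive"]
print(annotations.label.tolist())  # ['Bird']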
src/detection.py (21 changes: 9 additions & 12 deletions)
@@ -219,6 +219,9 @@ def preprocess_and_train(config, model_type="detection"):
     # Get and split annotations
     train_df = gather_data(config.detection_model.train_csv_folder)
     validation = gather_data(config.label_studio.csv_dir_validation)
+
+    if config.detection_model.limit_empty_frac > 0:
+        validation = limit_empty_frames(validation, config.detection_model.limit_empty_frac)
 
     validation.loc[validation.label==0,"label"] = "Bird"

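limit_empty_frames() itself is defined elsewhere in detection.py and is not part of this diff; the sketch below is only a guess at its intent (cap how many fully empty frames, here taken as rows with xmin == 0 as in the hunk further down, reach the validation set), not the repository's implementation:

import pandas as pd

def limit_empty_frames_sketch(df: pd.DataFrame, limit_frac: float, seed: int = 0) -> pd.DataFrame:
    # Keep every annotated frame; subsample empty frames to roughly limit_frac of the annotated count
    empty = df[df.xmin == 0]
    non_empty = df[df.xmin != 0]
    n_keep = min(len(empty), int(limit_frac * max(len(non_empty), 1)))
    return pd.concat([non_empty, empty.sample(n=n_keep, random_state=seed)])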
@@ -246,15 +249,9 @@ def preprocess_and_train(config, model_type="detection"):
         allow_empty=True
     )
     validation_df.loc[validation_df.label==0,"label"] = "Bird"
-
-    # Limit empty frames
-    if config.detection_model.limit_empty_frac > 0:
-        train_df = limit_empty_frames(train_df, config.detection_model.limit_empty_frac)
-        if not validation_df.empty:
-            #validation_df = limit_empty_frames(validation_df, config.detection_model.limit_empty_frac)
-            # DeepForest evaluate doesn't work with empty frames yet, see https://github.com/weecology/DeepForest/pull/858
-            validation_df = validation_df[validation_df.xmin!=0]
-
+    non_empty = validation_df[(validation_df.xmin!=0)]
+    empty = validation_df[validation_df.xmin==0]
+    validation_df = pd.concat([empty.head(1), non_empty])
 
     # Train model
     # Load existing model
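The replacement keeps all annotated boxes plus at most one empty frame (the comment removed above notes that DeepForest's evaluate does not yet handle empty frames, see the linked PR). A toy illustration of the concat:

import pandas as pd

validation_df = pd.DataFrame({
    "image_path": ["a.jpg", "b.jpg", "c.jpg", "d.jpg"],
    "xmin": [0, 10, 0, 25],  # xmin == 0 marks an empty frame in this codebase
})

non_empty = validation_df[validation_df.xmin != 0]
empty = validation_df[validation_df.xmin == 0]
validation_df = pd.concat([empty.head(1), non_empty])
print(validation_df.image_path.tolist())  # ['a.jpg', 'b.jpg', 'd.jpg']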
@@ -297,7 +294,7 @@ def get_latest_checkpoint(checkpoint_dir, annotations):
 
     return m
 
-def _predict_list_(image_paths, patch_size, patch_overlap, model_path, m=None, crop_model=None, batch_size=64):
+def _predict_list_(image_paths, patch_size, patch_overlap, model_path, m=None, crop_model=None, batch_size=16):
     if model_path:
         m = load(model_path)
     else:
@@ -315,7 +312,7 @@ def _predict_list_(image_paths, patch_size, patch_overlap, model_path, m=None, c
 
     return predictions
 
-def predict(image_paths, patch_size, patch_overlap, m=None, model_path=None, dask_client=None, crop_model=None, batch_size=8):
+def predict(image_paths, patch_size, patch_overlap, m=None, model_path=None, dask_client=None, crop_model=None, batch_size=16):
     """Predict bounding boxes for images
     Args:
         m (main.deepforest): A trained deepforest model.
@@ -352,6 +349,6 @@ def update_sys_path():
             block_result = block_result.result()
             predictions.append(pd.concat(block_result))
     else:
-        predictions = _predict_list_(image_paths=image_paths, patch_size=patch_size, patch_overlap=patch_overlap, model_path=model_path, m=m, crop_model=crop_model, batch_size=batch_size)
+        predictions = _predict_list_(image_paths=image_paths, patch_size=patch_size, patch_overlap=patch_overlap, model_path=model_path, m=m, crop_model=None, batch_size=batch_size)
 
     return predictions
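A minimal sketch of calling this module's predict() after the change; the paths and tiling parameters are hypothetical, batch_size now defaults to 16, and in this commit the local (non-dask) branch passes crop_model=None to _predict_list_:

predictions = predict(
    image_paths=["/path/to/flight/image_0001.jpg"],
    patch_size=1000,      # hypothetical tile size
    patch_overlap=0.05,   # hypothetical overlap
    model_path="/path/to/detection_checkpoint.ckpt",
)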
src/pipeline.py (2 changes: 1 addition & 1 deletion)
@@ -86,7 +86,7 @@ def run(self):
             self.config)
 
         trained_classification_model = classification.preprocess_and_train_classification(
-            self.config)
+            self.config, num_classes=len(trained_detection_model.label_dict))
 
         detection_checkpoint_path = self.save_model(trained_detection_model,
                                                     self.config.detection_model.checkpoint_dir)
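The classifier's head size is now taken from the detection model's label_dict. A small sketch of that relationship, assuming a DeepForest-style model object as used elsewhere in the pipeline:

from deepforest import main

detection_model = main.deepforest()            # fresh model for illustration; the pipeline uses its trained model
print(detection_model.label_dict)              # e.g. {"Tree": 0} for the default release
num_classes = len(detection_model.label_dict)  # value passed to preprocess_and_train_classification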
