From 4985fc8f0fde08121c98230187882ba43044f206 Mon Sep 17 00:00:00 2001 From: bw4sz Date: Fri, 17 Jan 2025 22:20:52 -0500 Subject: [PATCH] profiled predict --- conf/config.yaml | 5 +++-- src/classification.py | 46 ++++++++++++++++++++++++++++++++++--------- src/detection.py | 21 +++++++++----------- src/pipeline.py | 2 +- 4 files changed, 50 insertions(+), 24 deletions(-) diff --git a/conf/config.yaml b/conf/config.yaml index 3d0712a..356f044 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -39,10 +39,11 @@ detection_model: train_csv_folder: /blue/ewhite/b.weinstein/BOEM/annotations/train/ train_image_dir: /blue/ewhite/b.weinstein/BOEM/sample_flight/JPG_2024_Jan27/annotated crop_image_dir: /blue/ewhite/b.weinstein/BOEM/detection/crops/ - limit_empty_frac: 0.2 + limit_empty_frac: 0.01 labels: - "Bird" trainer: + batch_size: 4 train: fast_dev_run: False epochs: 10 @@ -59,7 +60,7 @@ classification_model: crop_image_dir: /blue/ewhite/b.weinstein/BOEM/classification/crops/ under_sample_ratio: 0 trainer: - fast_dev_run: False + fast_dev_run: True max_epochs: 1 lr: 0.00001 diff --git a/src/classification.py b/src/classification.py index 29ba044..cd62580 100644 --- a/src/classification.py +++ b/src/classification.py @@ -14,29 +14,47 @@ def create_train_test(annotations): return annotations.sample(frac=0.8, random_state=1), annotations.drop( annotations.sample(frac=0.8, random_state=1).index) -def get_latest_checkpoint(checkpoint_dir, annotations, lr=0.0001): +def get_latest_checkpoint(checkpoint_dir, annotations, lr=0.0001, num_classes=None): #Get model with latest checkpoint dir, if none exist make a new model if os.path.exists(checkpoint_dir): checkpoints = glob.glob(os.path.join(checkpoint_dir,"*.ckpt")) if len(checkpoints) > 0: checkpoints.sort() checkpoint = checkpoints[-1] - m = CropModel.load_from_checkpoint(checkpoint) + try: + m = CropModel.load_from_checkpoint(checkpoint) + except Exception as e: + warnings.warn("Could not load model from checkpoint, {}".format(e)) + if num_classes: + m = CropModel(num_classes=num_classes, lr=lr) + else: + m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr) else: warnings.warn("No checkpoints found in {}".format(checkpoint_dir)) - m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr) + if num_classes: + m = CropModel(num_classes=num_classes, lr=lr) + else: + m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr) else: os.makedirs(checkpoint_dir) - m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr) + if num_classes: + m = CropModel(num_classes=num_classes, lr=lr) + else: + m = CropModel(num_classes=len(annotations["label"].unique()), lr=lr) return m -def load(checkpoint=None, annotations=None, checkpoint_dir=None, lr=0.0001): +def load(checkpoint=None, annotations=None, checkpoint_dir=None, lr=0.0001, num_classes=None): if checkpoint: - loaded_model = CropModel(checkpoint, num_classes=len(annotations["label"].unique()), lr=lr) + if num_classes: + loaded_model = CropModel(checkpoint, num_classes=num_classes, lr=lr) + else: + loaded_model = CropModel(checkpoint, num_classes=len(annotations["label"].unique()), lr=lr) elif checkpoint_dir: loaded_model = get_latest_checkpoint( - checkpoint_dir, annotations) + checkpoint_dir, + num_classes=num_classes, + annotations=annotations) else: raise ValueError("No checkpoint or checkpoint directory found.") @@ -80,7 +98,7 @@ def preprocess_images(model, annotations, root_dir, save_dir): labels = annotations["label"].values model.write_crops(boxes=boxes, root_dir=root_dir, images=images, labels=labels, savedir=save_dir) -def preprocess_and_train_classification(config, validation_df=None): +def preprocess_and_train_classification(config, validation_df=None, num_classes=None): """Preprocess data and train a crop model. Args: @@ -92,6 +110,10 @@ def preprocess_and_train_classification(config, validation_df=None): # Get and split annotations annotations = gather_data(config.classification_model.train_csv_folder) + # Remove the empty frames + annotations = annotations[~(annotations.label.astype(str)== "0")] + annotations = annotations[annotations.label != "FalsePositive"] + if validation_df is None: train_df, validation_df = create_train_test(annotations) else: @@ -99,7 +121,13 @@ def preprocess_and_train_classification(config, validation_df=None): isin(validation_df["image_path"])] # Load existing model - loaded_model = load(checkpoint=config.classification_model.checkpoint, checkpoint_dir=config.classification_model.checkpoint_dir, annotations=annotations, lr=config.classification_model.trainer.lr) + loaded_model = load( + checkpoint=config.classification_model.checkpoint, + checkpoint_dir=config.classification_model.checkpoint_dir, + annotations=annotations, + lr=config.classification_model.trainer.lr, + num_classes=num_classes + ) # Preprocess train and validation data preprocess_images( diff --git a/src/detection.py b/src/detection.py index 6a7f4eb..6e084ab 100644 --- a/src/detection.py +++ b/src/detection.py @@ -219,6 +219,9 @@ def preprocess_and_train(config, model_type="detection"): # Get and split annotations train_df = gather_data(config.detection_model.train_csv_folder) validation = gather_data(config.label_studio.csv_dir_validation) + + if config.detection_model.limit_empty_frac > 0: + validation = limit_empty_frames(validation, config.detection_model.limit_empty_frac) validation.loc[validation.label==0,"label"] = "Bird" @@ -246,15 +249,9 @@ def preprocess_and_train(config, model_type="detection"): allow_empty=True ) validation_df.loc[validation_df.label==0,"label"] = "Bird" - - # Limit empty frames - if config.detection_model.limit_empty_frac > 0: - train_df = limit_empty_frames(train_df, config.detection_model.limit_empty_frac) - if not validation_df.empty: - #validation_df = limit_empty_frames(validation_df, config.detection_model.limit_empty_frac) - # DeepForest evaluate doesn't work with empty frames yet, see https://github.com/weecology/DeepForest/pull/858 - validation_df = validation_df[validation_df.xmin!=0] - + non_empty = validation_df[(validation_df.xmin!=0)] + empty = validation_df[validation_df.xmin==0] + validation_df = pd.concat([empty.head(1), non_empty]) # Train model # Load existing model @@ -297,7 +294,7 @@ def get_latest_checkpoint(checkpoint_dir, annotations): return m -def _predict_list_(image_paths, patch_size, patch_overlap, model_path, m=None, crop_model=None, batch_size=64): +def _predict_list_(image_paths, patch_size, patch_overlap, model_path, m=None, crop_model=None, batch_size=16): if model_path: m = load(model_path) else: @@ -315,7 +312,7 @@ def _predict_list_(image_paths, patch_size, patch_overlap, model_path, m=None, c return predictions -def predict(image_paths, patch_size, patch_overlap, m=None, model_path=None, dask_client=None, crop_model=None, batch_size=8): +def predict(image_paths, patch_size, patch_overlap, m=None, model_path=None, dask_client=None, crop_model=None, batch_size=16): """Predict bounding boxes for images Args: m (main.deepforest): A trained deepforest model. @@ -352,6 +349,6 @@ def update_sys_path(): block_result = block_result.result() predictions.append(pd.concat(block_result)) else: - predictions = _predict_list_(image_paths=image_paths, patch_size=patch_size, patch_overlap=patch_overlap, model_path=model_path, m=m, crop_model=crop_model, batch_size=batch_size) + predictions = _predict_list_(image_paths=image_paths, patch_size=patch_size, patch_overlap=patch_overlap, model_path=model_path, m=m, crop_model=None, batch_size=batch_size) return predictions diff --git a/src/pipeline.py b/src/pipeline.py index ed4ca95..0424ba4 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -86,7 +86,7 @@ def run(self): self.config) trained_classification_model = classification.preprocess_and_train_classification( - self.config) + self.config, num_classes=len(trained_detection_model.label_dict)) detection_checkpoint_path = self.save_model(trained_detection_model, self.config.detection_model.checkpoint_dir)