Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One instance dets #6

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c375b93
Added support for config
JakPing May 9, 2022
c03d5ba
Added neptune integration
JakPing May 9, 2022
c56359c
artifact logging
JakPing May 10, 2022
a0f589c
artifact logging
JakPing May 10, 2022
dbc43f8
artifact logging
JakPing May 10, 2022
3fe9cd2
refactor
JakPing May 10, 2022
586d3ed
refactor
JakPing May 10, 2022
da1f121
add more metrics
JakPing May 12, 2022
9415f61
bug fix
JakPing May 12, 2022
9e7c5b8
fixes in training and neptune logging
May 19, 2022
d26a6ca
validation loss logging
Aditya-Bobade May 19, 2022
e53d2bc
flag for validation loss logging
Aditya-Bobade May 20, 2022
da181cf
average validation loss logging
Aditya-Bobade May 20, 2022
47eccab
remove average validation loss logging
Aditya-Bobade May 20, 2022
1c31df4
mosaic_prob !=1 bug fix
May 24, 2022
c8fcf89
adding copy paste augmentations
May 31, 2022
3343d07
refactor
May 31, 2022
4ab2e63
improving speed
Jun 1, 2022
1a8b0dc
added postprocessing
Jun 3, 2022
d7a1ad1
changed integration from neptune to mlflow
Jun 15, 2022
b2f55a0
added exp connection
Jun 23, 2022
49b7d4b
parametrized mlflow connection
Jun 30, 2022
3d37e1c
Merge branch 'mlflow' of https://github.com/b-yond-infinite-network/Y…
Jun 30, 2022
dd91165
added configurable exp name
Jun 30, 2022
58e19c0
changes for classes by config refactor
Jul 4, 2022
01279e1
repair bug
Jul 4, 2022
b931501
fixes regarding training
Jul 5, 2022
be92108
fix errors in true divide
Jul 5, 2022
e7e9a23
install specific library version
Aditya-Bobade Jul 5, 2022
7c346ad
fixing returning only one instance of class
Jul 6, 2022
2d8bce3
reverts changeds regardin postprocess (moved to end of pipeline)
Jul 19, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# TODO: Update with exact module version
numpy
torch>=1.7
torch==1.11.0
opencv_python
loguru
scikit-image
tqdm
torchvision
torchvision==0.12.0
Pillow
thop
ninja
Expand Down
71 changes: 60 additions & 11 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import argparse
import random
import warnings
import mlflow
from loguru import logger

import yaml
import torch
import torch.backends.cudnn as cudnn

Expand Down Expand Up @@ -41,6 +43,12 @@ def make_parser():
type=str,
help="plz input your experiment description file",
)
parser.add_argument(
"--config_filepath",
default=None,
type=str,
help="Filepath to config file",
)
parser.add_argument(
"--resume", default=False, action="store_true", help="resume training"
)
Expand Down Expand Up @@ -93,11 +101,25 @@ def make_parser():
default=None,
nargs=argparse.REMAINDER,
)
parser.add_argument(
"-ml",
"--mlflow_url",
type=str,
help="MLFlow instance url for logging metrics and files.",
default=None
)
parser.add_argument(
"-mlex",
"--mlflow_experiment_name",
type=str,
help="Experiment name to log metrics and files",
default=None
)
return parser


@logger.catch
def main(exp, args):
def main(exp, run, args):
if exp.seed is not None:
random.seed(exp.seed)
torch.manual_seed(exp.seed)
Expand All @@ -113,28 +135,55 @@ def main(exp, args):
configure_omp()
cudnn.benchmark = True

trainer = Trainer(exp, args)
trainer = Trainer(exp, run, args)
trainer.train()


if __name__ == "__main__":
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)
run = None
if args.mlflow_url is not None and args.mlflow_experiment_name is not None:
mlflow.set_tracking_uri(args.mlflow_url)
experiment = mlflow.get_experiment_by_name(args.mlflow_experiment_name)
run = mlflow.start_run(experiment_id=experiment.experiment_id)

exp.merge(args.opts)

if not args.experiment_name:
args.experiment_name = exp.exp_name

if args.config_filepath is not None:
with open(args.config_filepath, "r") as f:
config = yaml.safe_load(f)
num_gpu = get_num_devices() if args.devices is None else args.devices
assert num_gpu <= get_num_devices()

dist_url = "auto" if args.dist_url is None else args.dist_url
launch(
main,
num_gpu,
args.num_machines,
args.machine_rank,
backend=args.dist_backend,
dist_url=dist_url,
args=(exp, args),
)
if run is not None:
with run:
if args.config_filepath is not None:
mlflow.log_artifact(args.config_filepath, 'config_file')
exp.run = run
exp.add_params_from_config(config, use_mlflow=True)
launch(
main,
num_gpu,
args.num_machines,
args.machine_rank,
backend=args.dist_backend,
dist_url=dist_url,
args=(exp, run, args),
)
else:
if args.config_filepath is not None:
exp.add_params_from_config(config)
launch(
main,
num_gpu,
args.num_machines,
args.machine_rank,
backend=args.dist_backend,
dist_url=dist_url,
args=(exp, run, args),
)
59 changes: 54 additions & 5 deletions yolox/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datetime
import os
import time
import mlflow
from loguru import logger

import torch
Expand Down Expand Up @@ -33,12 +34,12 @@


class Trainer:
def __init__(self, exp, args):
def __init__(self, exp, run, args):
# init function only defines some basic attr, other attrs like model, optimizer are built in
# before_train methods.
self.exp = exp
self.args = args

self.run = run
# training related attr
self.max_epoch = exp.max_epoch
self.amp_training = args.fp16
Expand All @@ -55,6 +56,9 @@ def __init__(self, exp, args):
self.input_size = exp.input_size
self.best_ap = 0

# validation loss
self.calc_validation_loss = exp.calc_val_loss

# metric record
self.meter = MeterBuffer(window_size=exp.print_interval)
self.file_name = os.path.join(exp.output_dir, args.experiment_name)
Expand Down Expand Up @@ -114,6 +118,8 @@ def train_one_iter(self):
self.ema_model.update(self.model)

lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
if self.run is not None:
mlflow.log_metric("lr", lr)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr

Expand Down Expand Up @@ -151,10 +157,21 @@ def before_train(self):
no_aug=self.no_aug,
cache_img=self.args.cache,
)
if self.calc_validation_loss:
self.val_loader = self.exp.get_val_loader(
batch_size=self.args.batch_size,
is_distributed=self.is_distributed,
no_aug=False,
cache_img=self.args.cache,
)
logger.info("init prefetcher, this might take one minute or less...")
self.prefetcher = DataPrefetcher(self.train_loader)
if self.calc_validation_loss:
self.val_prefetcher = DataPrefetcher(self.val_loader)
# max_iter means iters per epoch
self.max_iter = len(self.train_loader)
if self.calc_validation_loss:
self.max_val_iter = len(self.val_loader)

self.lr_scheduler = self.exp.get_lr_scheduler(
self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
Expand Down Expand Up @@ -243,7 +260,9 @@ def after_iter(self):
loss_str = ", ".join(
["{}: {:.1f}".format(k, v.latest) for k, v in loss_meter.items()]
)

for loss_name, loss_value in loss_meter.items():
if self.run is not None:
mlflow.log_metric(f"loss/{loss_name}", loss_value.latest)
time_meter = self.meter.get_filtered_meter("time")
time_str = ", ".join(
["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
Expand Down Expand Up @@ -313,6 +332,10 @@ def resume_train(self, model):
return model

def evaluate_and_save_model(self):
# calculate loss
if self.calc_validation_loss:
self.calculate_eval_loss()

if self.use_model_ema:
evalmodel = self.ema_model.ema
else:
Expand All @@ -327,7 +350,8 @@ def evaluate_and_save_model(self):

update_best_ckpt = ap50_95 > self.best_ap
self.best_ap = max(self.best_ap, ap50_95)

if self.run is not None:
mlflow.log_metric(f"metrics/best_ap", self.best_ap)
if self.rank == 0:
if self.args.logger == "tensorboard":
self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
Expand Down Expand Up @@ -360,7 +384,32 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
update_best_ckpt,
self.file_name,
ckpt_name,
self.run,
)

if self.args.logger == "wandb":
self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)

def calculate_eval_loss(self):
for iter in range(self.max_val_iter):
inps, targets = self.val_prefetcher.next()
inps = inps.to(self.data_type)
targets = targets.to(self.data_type)
targets.requires_grad = False
inps, targets = self.exp.preprocess(inps, targets, self.input_size)
with torch.cuda.amp.autocast(enabled=self.amp_training):
outputs = self.model(inps, targets)
loss = {
"total_loss": outputs["total_loss"],
"iou_loss": outputs["iou_loss"],
"l1_loss": outputs["l1_loss"],
"conf_loss": outputs["conf_loss"],
"cls_loss": outputs["cls_loss"]
}
progress_str = "epoch: {}/{}, iter: {}/{},".format(
self.epoch + 1, self.max_epoch, iter + 1, self.max_val_iter
)
for loss_name, loss_value in loss.items():
progress_str += " {}: {:.1f},".format(loss_name, loss_value)
if self.run is not None:
mlflow.log_metric(f"loss/val/{loss_name}", loss_value)
logger.info("Validation:{}".format(progress_str))
26 changes: 26 additions & 0 deletions yolox/data/data_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,32 @@ def preproc(img, input_size, swap=(2, 0, 1)):
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
return padded_img, r

def copy_paste(img, paste_img, labels, paste_labels, prob=0.5, obj_proc=0.5):
if random.random() > prob:
img_h, img_w = img.shape[:2]
paste_labels = paste_labels.astype(int)
objects_to_paste = paste_labels[random.sample(
range(0, len(paste_labels) - 1), int(len(paste_labels) * obj_proc)
)]
if len(objects_to_paste) == 0:
return img, labels
new_labels = []
for obj in objects_to_paste:
cropped_obj = paste_img[obj[1]:obj[3], obj[0]:obj[2]]
if random.random() > 0.5:
cropped_obj = cropped_obj[:,::-1]
new_x_min = random.randint(0, img_w - (obj[2] - obj[0]))
new_y_min = random.randint(0, img_h - (obj[3] - obj[1]))
new_x_max = new_x_min + (obj[2] - obj[0])
new_y_max = new_y_min + (obj[3] - obj[1])
new_labels.append(np.array([
new_x_min, new_y_min, new_x_max, new_y_max, obj[4]
]))
img[new_y_min:new_y_max, new_x_min:new_x_max] = cropped_obj
labels = np.append(labels, new_labels, 0)
return img, labels



class TrainTransform:
def __init__(self, max_labels=50, flip_prob=0.5, hsv_prob=1.0):
Expand Down
25 changes: 23 additions & 2 deletions yolox/data/datasets/mosaicdetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import cv2
import numpy as np
import torch

from yolox.utils import adjust_box_anns, get_local_rank

from yolox.data.data_augment import copy_paste
from ..data_augment import random_affine
from .datasets_wrapper import Dataset

Expand Down Expand Up @@ -41,7 +42,8 @@ def __init__(
self, dataset, img_size, mosaic=True, preproc=None,
degrees=10.0, translate=0.1, mosaic_scale=(0.5, 1.5),
mixup_scale=(0.5, 1.5), shear=2.0, enable_mixup=True,
mosaic_prob=1.0, mixup_prob=1.0, *args
mosaic_prob=1.0, mixup_prob=1.0, copy_paste_prob=0.5,
copy_paste_obj_proc=0.5, *args
):
"""

Expand All @@ -64,6 +66,8 @@ def __init__(
self.degrees = degrees
self.translate = translate
self.scale = mosaic_scale
self.copy_paste_prob = copy_paste_prob
self.copy_paste_prob = copy_paste_obj_proc
self.shear = shear
self.mixup_scale = mixup_scale
self.enable_mosaic = mosaic
Expand Down Expand Up @@ -91,6 +95,14 @@ def __getitem__(self, idx):

for i_mosaic, index in enumerate(indices):
img, _labels, _, img_id = self._dataset.pull_item(index)

if self.copy_paste_prob is not None and self.copy_paste_prob!=0.0:
random_idx = index
while random_idx==index:
random_idx = random.randint(0, len(self._dataset.annotations)-1)
paste_img, paste_label, _, _ = self._dataset.pull_item(random_idx)
img, _labels = copy_paste(img, paste_img, _labels, paste_label, self.copy_paste_prob)

h0, w0 = img.shape[:2] # orig hw
scale = min(1. * input_h / h0, 1. * input_w / w0)
img = cv2.resize(
Expand Down Expand Up @@ -151,11 +163,20 @@ def __getitem__(self, idx):
# img_info and img_id are not used for training.
# They are also hard to be specified on a mosaic image.
# -----------------------------------------------------------------
img_id = torch.tensor(np.array(img_id), dtype=torch.long)
return mix_img, padded_labels, img_info, img_id

else:
self._dataset._input_dim = self.input_dim
img, label, img_info, img_id = self._dataset.pull_item(idx)

if self.copy_paste_prob is not None and self.copy_paste_prob!=0.0:
random_idx = idx
while random_idx==idx:
random_idx = random.randint(0, len(self._dataset.annotations)-1)
paste_img, paste_label, _, _ = self._dataset.pull_item(random_idx)
img, label = copy_paste(img, paste_img, label, paste_label, self.copy_paste_prob)

img, label = self.preproc(img, label, self.input_dim)
return img, label, img_info, img_id

Expand Down
1 change: 1 addition & 0 deletions yolox/evaluators/voc_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import xml.etree.ElementTree as ET

import numpy as np
np.seterr(invalid='ignore')


def parse_rec(filename):
Expand Down
Loading