Yolox fix #5

Open · wants to merge 15 commits into base: main
17 changes: 17 additions & 0 deletions tools/train.py
@@ -7,6 +7,7 @@
import warnings
from loguru import logger

import yaml
import torch
import torch.backends.cudnn as cudnn

@@ -41,6 +42,12 @@ def make_parser():
type=str,
help="plz input your experiment description file",
)
parser.add_argument(
"--config_filepath",
default=None,
type=str,
help="Filepath to config file",
)
parser.add_argument(
"--resume", default=False, action="store_true", help="resume training"
)
@@ -120,11 +127,21 @@ def main(exp, args):
if __name__ == "__main__":
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)

# TODO: Add Neptune logging for multi-device training. Logging currently works
# only with single-GPU training, not with multiprocessing.
exp.set_neptune_logging(True)

exp.merge(args.opts)

if not args.experiment_name:
args.experiment_name = exp.exp_name

if args.config_filepath is not None:
with open(args.config_filepath, "r") as f:
config = yaml.safe_load(f)
exp.add_params_from_config(config, use_neptune=True)
exp.neptune['config_file'].upload(args.config_filepath)
num_gpu = get_num_devices() if args.devices is None else args.devices
assert num_gpu <= get_num_devices()

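
The new --config_filepath flag points at a YAML file whose keys are copied onto the experiment by add_params_from_config (dataset_version is resolved against DATASETS_PATH) and logged to Neptune; the file itself is uploaded under config_file. A minimal sketch of such a config and a matching launch command; all file paths and values below are illustrative:

    # configs/experiment.yaml  (hypothetical path)
    dataset_version: "v1"          # becomes exp.dataset_dir = DATASETS_PATH / "v1"
    max_epoch: 100
    basic_lr_per_img: 0.00015625
    calc_val_loss: true            # enables the validation-loss pass added in trainer.py below

    python tools/train.py -f exps/default/yolox_s.py -d 1 -b 16 --config_filepath configs/experiment.yaml
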
51 changes: 48 additions & 3 deletions yolox/core/trainer.py
@@ -38,7 +38,7 @@ def __init__(self, exp, args):
# before_train methods.
self.exp = exp
self.args = args

self.neptune = self.exp.neptune
# training related attr
self.max_epoch = exp.max_epoch
self.amp_training = args.fp16
@@ -55,6 +55,9 @@ def __init__(self, exp, args):
self.input_size = exp.input_size
self.best_ap = 0

# validation loss
self.calc_validation_loss = exp.calc_val_loss

# metric record
self.meter = MeterBuffer(window_size=exp.print_interval)
self.file_name = os.path.join(exp.output_dir, args.experiment_name)
@@ -114,6 +117,7 @@ def train_one_iter(self):
self.ema_model.update(self.model)

lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
self.neptune['config/lr'].log(lr)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr

@@ -151,10 +155,21 @@ def before_train(self):
no_aug=self.no_aug,
cache_img=self.args.cache,
)
if self.calc_validation_loss:
self.val_loader = self.exp.get_val_loader(
batch_size=self.args.batch_size,
is_distributed=self.is_distributed,
no_aug=False,
cache_img=self.args.cache,
)
logger.info("init prefetcher, this might take one minute or less...")
self.prefetcher = DataPrefetcher(self.train_loader)
if self.calc_validation_loss:
self.val_prefetcher = DataPrefetcher(self.val_loader)
# max_iter means iters per epoch
self.max_iter = len(self.train_loader)
if self.calc_validation_loss:
self.max_val_iter = len(self.val_loader)

self.lr_scheduler = self.exp.get_lr_scheduler(
self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
@@ -243,7 +258,8 @@ def after_iter(self):
loss_str = ", ".join(
["{}: {:.1f}".format(k, v.latest) for k, v in loss_meter.items()]
)

for loss_name, loss_value in loss_meter.items():
self.neptune[f"loss/{loss_name}"].log(loss_value.latest)
time_meter = self.meter.get_filtered_meter("time")
time_str = ", ".join(
["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
@@ -313,6 +329,10 @@ def resume_train(self, model):
return model

def evaluate_and_save_model(self):
# calculate loss
if self.calc_validation_loss:
self.calculate_eval_loss()

if self.use_model_ema:
evalmodel = self.ema_model.ema
else:
@@ -327,7 +347,7 @@ def evaluate_and_save_model(self):

update_best_ckpt = ap50_95 > self.best_ap
self.best_ap = max(self.best_ap, ap50_95)

self.neptune['metrics/best_ap'].log(self.best_ap)
if self.rank == 0:
if self.args.logger == "tensorboard":
self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
@@ -360,7 +380,32 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
update_best_ckpt,
self.file_name,
ckpt_name,
self.neptune,
)

if self.args.logger == "wandb":
self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)

def calculate_eval_loss(self):
for iter in range(self.max_val_iter):
inps, targets = self.val_prefetcher.next()
inps = inps.to(self.data_type)
targets = targets.to(self.data_type)
targets.requires_grad = False
inps, targets = self.exp.preprocess(inps, targets, self.input_size)
with torch.cuda.amp.autocast(enabled=self.amp_training):
outputs = self.model(inps, targets)
loss = {
"total_loss": outputs["total_loss"],
"iou_loss": outputs["iou_loss"],
"l1_loss": outputs["l1_loss"],
"conf_loss": outputs["conf_loss"],
"cls_loss": outputs["cls_loss"]
}
progress_str = "epoch: {}/{}, iter: {}/{},".format(
self.epoch + 1, self.max_epoch, iter + 1, self.max_val_iter
)
for loss_name, loss_value in loss.items():
progress_str += " {}: {:.1f},".format(loss_name, loss_value)
self.neptune[f"loss/val/{loss_name}"].log(loss_value)
logger.info("Validation:{}".format(progress_str))
2 changes: 2 additions & 0 deletions yolox/data/datasets/mosaicdetection.py
@@ -6,6 +6,7 @@

import cv2
import numpy as np
import torch

from yolox.utils import adjust_box_anns, get_local_rank

@@ -151,6 +152,7 @@ def __getitem__(self, idx):
# img_info and img_id are not used for training.
# They are also hard to be specified on a mosaic image.
# -----------------------------------------------------------------
img_id = torch.tensor(np.array(img_id), dtype=torch.long)
return mix_img, padded_labels, img_info, img_id

else:
26 changes: 25 additions & 1 deletion yolox/exp/base_exp.py
@@ -8,11 +8,12 @@
from typing import Dict
from tabulate import tabulate

import neptune.new as neptune
import torch
from torch.nn import Module

from yolox.utils import LRScheduler

from paths import DATASETS_PATH

class BaseExp(metaclass=ABCMeta):
"""Basic class for any experiment."""
@@ -22,6 +23,7 @@ def __init__(self):
self.output_dir = "./YOLOX_outputs"
self.print_interval = 100
self.eval_interval = 10
self.neptune = None

@abstractmethod
def get_model(self) -> Module:
@@ -60,6 +62,17 @@ def __repr__(self):
]
return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

def set_neptune_logging(self, state):
if state:
self.neptune = neptune.init(
project="jakub.pingielski/b-yond",
api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI2NTlkYzZmZC1kZTY5LTQ2NjMtODFkZC04YmY4NTNmYTkwMTIifQ==",
)
else:
if self.neptune is not None:
self.neptune.stop()
self.neptune = None

def merge(self, cfg_list):
assert len(cfg_list) % 2 == 0
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
@@ -73,3 +86,14 @@ def merge(self, cfg_list):
except Exception:
v = ast.literal_eval(v)
setattr(self, k, v)

def add_params_from_config(self, config: dict, use_neptune: bool = True):
for key, value in config.items():
if key == "dataset_version":
setattr(self, "dataset_dir", DATASETS_PATH / value)
else:
setattr(self, key, value)
if use_neptune and self.neptune:
self.neptune[f"config/{key}"].log(value)


60 changes: 59 additions & 1 deletion yolox/exp/yolox_base.py
@@ -81,6 +81,8 @@ def __init__(self):
self.no_aug_epochs = 15
# apply EMA during training
self.ema = True
# calculate validation loss
self.calc_val_loss = False

# weight decay of optimizer
self.weight_decay = 5e-4
@@ -94,7 +96,7 @@ def __init__(self):
self.eval_interval = 10
# save history checkpoint or not.
# If set to False, yolox will only save latest and best ckpt.
self.save_history_ckpt = True
self.save_history_ckpt = False
# name of experiment
self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

@@ -201,6 +203,62 @@ def get_data_loader(

return train_loader

def get_val_loader(
self, batch_size, is_distributed, no_aug=False, cache_img=False, testdev=False
):
from yolox.data import (
COCODataset,
TrainTransform,
YoloBatchSampler,
DataLoader,
InfiniteSampler,
MosaicDetection,
worker_init_reset_seed,
)
from yolox.utils import (
wait_for_the_master,
get_local_rank,
)

local_rank = get_local_rank()

with wait_for_the_master(local_rank):
dataset = COCODataset(
data_dir=self.data_dir,
json_file=self.val_ann if not testdev else self.test_ann,
img_size=self.input_size,
preproc=TrainTransform(
max_labels=50,
flip_prob=0.0,
hsv_prob=0.0),
cache=cache_img,
)

self.dataset = dataset

if is_distributed:
batch_size = batch_size // dist.get_world_size()

sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)

batch_sampler = YoloBatchSampler(
sampler=sampler,
batch_size=batch_size,
drop_last=False,
mosaic=not no_aug,
)

dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
dataloader_kwargs["batch_sampler"] = batch_sampler

# Make sure each process has different random seed, especially for 'fork' method.
# Check https://github.com/pytorch/pytorch/issues/63311 for more details.
dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed

val_loader = DataLoader(self.dataset, **dataloader_kwargs)

return val_loader

def random_resize(self, data_loader, epoch, rank, is_distributed):
tensor = torch.LongTensor(2).cuda()

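
get_val_loader mirrors get_data_loader but loads the validation annotations through TrainTransform with augmentation disabled (flip_prob=0.0, hsv_prob=0.0), so targets are still produced for the loss. A rough usage sketch, assuming an exp whose data_dir and val_ann point at an existing COCO-format dataset; the batch size is illustrative:

    val_loader = exp.get_val_loader(batch_size=8, is_distributed=False, no_aug=False)
    print(len(val_loader), "validation iterations per pass")  # what the trainer stores as max_val_iter
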
5 changes: 4 additions & 1 deletion yolox/utils/checkpoint.py
@@ -4,6 +4,7 @@
import os
import shutil
from loguru import logger
import neptune.new as neptune

import torch

@@ -33,11 +34,13 @@ def load_ckpt(model, ckpt):
return model


def save_checkpoint(state, is_best, save_dir, model_name=""):
def save_checkpoint(state, is_best, save_dir, model_name="", neptune=None):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
filename = os.path.join(save_dir, model_name + "_ckpt.pth")
torch.save(state, filename)
if is_best:
best_filename = os.path.join(save_dir, "best_ckpt.pth")
shutil.copyfile(filename, best_filename)
if neptune:
neptune['best_checkpoint'].upload(best_filename)
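
save_checkpoint now accepts an optional Neptune run and uploads best_ckpt.pth whenever is_best is set. A self-contained sketch with a dummy state dict and an illustrative output directory:

    import torch
    from yolox.utils import save_checkpoint

    state = {"model": torch.nn.Linear(2, 2).state_dict(), "start_epoch": 1}  # dummy state for illustration
    save_checkpoint(
        state,
        is_best=True,                      # copies the file to best_ckpt.pth and uploads it if a run is given
        save_dir="./YOLOX_outputs/demo",   # hypothetical output directory
        model_name="latest",
        neptune=None,                      # pass exp.neptune here to upload best_ckpt.pth to the run
    )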