diff --git a/README.md b/README.md new file mode 100644 index 00000000..c8a30942 --- /dev/null +++ b/README.md @@ -0,0 +1,86 @@ +# Age Estimation PyTorch +PyTorch-based CNN implementation for estimating age from face images. +Currently only the APPA-REAL dataset is supported. +Similar Keras-based project can be found [here](https://github.com/yu4u/age-gender-estimation). + + + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Demo +Webcam is required. +See `python demo.py -h` for detailed options. + +```bash +python demo.py +``` + +Using `--img_dir` argument, images in that directory will be used as input: + +```bash +python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY] +``` + +Further using `--output_dir` argument, +resulting images will be saved in that directory (no resulting image window is displayed in this case): + +```bash +python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY] --output_dir [PATH/TO/OUTPUT_DIRECTORY] +``` + +## Train + +#### Download Dataset + +Download and extract the [APPA-REAL dataset](http://chalearnlap.cvc.uab.es/dataset/26/description/). + +> The APPA-REAL database contains 7,591 images with associated real and apparent age labels. The total number of apparent votes is around 250,000. On average we have around 38 votes per each image and this makes the average apparent age very stable (0.3 standard error of the mean). + +```bash +wget http://158.109.8.102/AppaRealAge/appa-real-release.zip +unzip appa-real-release.zip +``` + +#### Train Model +Train a model using the APPA-REAL dataset. +See `python train.py -h` for detailed options. + +```bash +python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log +``` + +Check training progress: + +```bash +tensorboard --logdir=tf_log +``` + + + +#### Training Options +You can change training parameters including model architecture using additional arguments like this: + +```bash +python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log MODEL.ARCH se_resnet50 TRAIN.OPT sgd TRAIN.LR 0.1 +``` + +All default parameters defined in [defaults.py](defaults.py) can be changed using this style. + + +#### Test Trained Model +Evaluate the trained model using the APPA-REAL test dataset. + +```bash +python test.py --data_dir [PATH/TO/appa-real-release] --resume [PATH/TO/BEST_MODEL.pth] +``` + +After evaluation, you can see something like this: + +```bash +100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:08<00:00, 1.28it/s] +test mae: 4.800 +``` diff --git a/dataset.py b/dataset.py new file mode 100644 index 00000000..61dc8339 --- /dev/null +++ b/dataset.py @@ -0,0 +1,97 @@ +import argparse +import better_exceptions +from pathlib import Path +import numpy as np +import pandas as pd +import torch +import cv2 +from torch.utils.data import Dataset +from imgaug import augmenters as iaa + + +class ImgAugTransform: + def __init__(self): + self.aug = iaa.Sequential([ + iaa.OneOf([ + iaa.Sometimes(0.25, iaa.AdditiveGaussianNoise(scale=0.1 * 255)), + iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 3.0))) + ]), + iaa.Affine( + rotate=(-20, 20), mode="edge", + scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, + translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)} + ), + iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True), + iaa.GammaContrast((0.3, 2)), + iaa.Fliplr(0.5), + ]) + + def __call__(self, img): + img = np.array(img) + img = self.aug.augment_image(img) + return img + + +class FaceDataset(Dataset): + def __init__(self, data_dir, data_type, img_size=224, augment=False, age_stddev=1.0): + assert(data_type in ("train", "valid", "test")) + csv_path = Path(data_dir).joinpath(f"gt_avg_{data_type}.csv") + img_dir = Path(data_dir).joinpath(data_type) + self.img_size = img_size + self.augment = augment + self.age_stddev = age_stddev + + if augment: + self.transform = ImgAugTransform() + else: + self.transform = lambda i: i + + self.x = [] + self.y = [] + self.std = [] + df = pd.read_csv(str(csv_path)) + ignore_path = Path(__file__).resolve().parent.joinpath("ignore_list.csv") + ignore_img_names = list(pd.read_csv(str(ignore_path))["img_name"].values) + + for _, row in df.iterrows(): + img_name = row["file_name"] + + if img_name in ignore_img_names: + continue + + img_path = img_dir.joinpath(img_name + "_face.jpg") + assert(img_path.is_file()) + self.x.append(str(img_path)) + self.y.append(row["apparent_age_avg"]) + self.std.append(row["apparent_age_std"]) + + def __len__(self): + return len(self.y) + + def __getitem__(self, idx): + img_path = self.x[idx] + age = self.y[idx] + + if self.augment: + age += np.random.randn() * self.std[idx] * self.age_stddev + + img = cv2.imread(str(img_path), 1) + img = cv2.resize(img, (self.img_size, self.img_size)) + img = self.transform(img).astype(np.float32) + return torch.from_numpy(np.transpose(img, (2, 0, 1))), np.clip(round(age), 0, 100) + + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--data_dir", type=str, required=True) + args = parser.parse_args() + dataset = FaceDataset(args.data_dir, "train") + print("train dataset len: {}".format(len(dataset))) + dataset = FaceDataset(args.data_dir, "valid") + print("valid dataset len: {}".format(len(dataset))) + dataset = FaceDataset(args.data_dir, "test") + print("test dataset len: {}".format(len(dataset))) + + +if __name__ == '__main__': + main() diff --git a/defaults.py b/defaults.py new file mode 100644 index 00000000..538ba8e8 --- /dev/null +++ b/defaults.py @@ -0,0 +1,26 @@ +from yacs.config import CfgNode as CN + +_C = CN() + +# Model +_C.MODEL = CN() +_C.MODEL.ARCH = "se_resnext50_32x4d" # check python train.py -h for available models +_C.MODEL.IMG_SIZE = 224 + +# Train +_C.TRAIN = CN() +_C.TRAIN.OPT = "adam" # adam or sgd +_C.TRAIN.WORKERS = 8 +_C.TRAIN.LR = 0.001 +_C.TRAIN.LR_DECAY_STEP = 20 +_C.TRAIN.LR_DECAY_RATE = 0.2 +_C.TRAIN.MOMENTUM = 0.9 +_C.TRAIN.WEIGHT_DECAY = 0.0 +_C.TRAIN.BATCH_SIZE = 128 +_C.TRAIN.EPOCHS = 80 +_C.TRAIN.AGE_STDDEV = 1.0 + +# Test +_C.TEST = CN() +_C.TEST.WORKERS = 8 +_C.TEST.BATCH_SIZE = 128 diff --git a/demo.py b/demo.py new file mode 100644 index 00000000..e6b16ff0 --- /dev/null +++ b/demo.py @@ -0,0 +1,163 @@ +import argparse +import better_exceptions +from pathlib import Path +from contextlib import contextmanager +import numpy as np +import cv2 +import dlib +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torch.nn.functional as F +from model import get_model +from defaults import _C as cfg + + +def get_args(): + parser = argparse.ArgumentParser(description="Age estimation demo", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--resume", type=str, required=True, + help="Model weight to be tested") + parser.add_argument("--margin", type=float, default=0.4, + help="Margin around detected face for age-gender estimation") + parser.add_argument("--img_dir", type=str, default=None, + help="Target image directory; if set, images in image_dir are used instead of webcam") + parser.add_argument("--output_dir", type=str, default=None, + help="Output directory to which resulting images will be stored if set") + parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, + help="Modify config options using the command-line") + args = parser.parse_args() + return args + + +def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX, + font_scale=0.8, thickness=1): + size = cv2.getTextSize(label, font, font_scale, thickness)[0] + x, y = point + cv2.rectangle(image, (x, y - size[1]), (x + size[0], y), (255, 0, 0), cv2.FILLED) + cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness, lineType=cv2.LINE_AA) + + +@contextmanager +def video_capture(*args, **kwargs): + cap = cv2.VideoCapture(*args, **kwargs) + try: + yield cap + finally: + cap.release() + + +def yield_images(): + with video_capture(0) as cap: + cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) + + while True: + ret, img = cap.read() + + if not ret: + raise RuntimeError("Failed to capture image") + + yield img, None + + +def yield_images_from_dir(img_dir): + img_dir = Path(img_dir) + + for img_path in img_dir.glob("*.*"): + img = cv2.imread(str(img_path), 1) + + if img is not None: + h, w, _ = img.shape + r = 640 / max(w, h) + yield cv2.resize(img, (int(w * r), int(h * r))), img_path.name + + +def main(): + args = get_args() + + if args.opts: + cfg.merge_from_list(args.opts) + + cfg.freeze() + + if args.output_dir is not None: + if args.img_dir is None: + raise ValueError("=> --img_dir argument is required if --output_dir is used") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # create model + print("=> creating model '{}'".format(cfg.MODEL.ARCH)) + model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + + # load checkpoint + resume_path = args.resume + + if Path(resume_path).is_file(): + print("=> loading checkpoint '{}'".format(resume_path)) + checkpoint = torch.load(resume_path, map_location="cpu") + model.load_state_dict(checkpoint['state_dict']) + print("=> loaded checkpoint '{}'".format(resume_path)) + else: + raise ValueError("=> no checkpoint found at '{}'".format(resume_path)) + + if device == "cuda": + cudnn.benchmark = True + + model.eval() + margin = args.margin + img_dir = args.img_dir + detector = dlib.get_frontal_face_detector() + img_size = cfg.MODEL.IMG_SIZE + image_generator = yield_images_from_dir(img_dir) if img_dir else yield_images() + + with torch.no_grad(): + for img, name in image_generator: + input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_h, img_w, _ = np.shape(input_img) + + # detect faces using dlib detector + detected = detector(input_img, 1) + faces = np.empty((len(detected), img_size, img_size, 3)) + + if len(detected) > 0: + for i, d in enumerate(detected): + x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height() + xw1 = max(int(x1 - margin * w), 0) + yw1 = max(int(y1 - margin * h), 0) + xw2 = min(int(x2 + margin * w), img_w - 1) + yw2 = min(int(y2 + margin * h), img_h - 1) + cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2) + cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2) + faces[i] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1], (img_size, img_size)) + + # predict ages + inputs = torch.from_numpy(np.transpose(faces.astype(np.float32), (0, 3, 1, 2))).to(device) + outputs = F.softmax(model(inputs), dim=-1).cpu().numpy() + ages = np.arange(0, 101) + predicted_ages = (outputs * ages).sum(axis=-1) + + # draw results + for i, d in enumerate(detected): + label = "{}".format(int(predicted_ages[i])) + draw_label(img, (d.left(), d.top()), label) + + if args.output_dir is not None: + output_path = output_dir.joinpath(name) + cv2.imwrite(str(output_path), img) + else: + cv2.imshow("result", img) + key = cv2.waitKey(-1) if img_dir else cv2.waitKey(30) + + if key == 27: # ESC + break + + +if __name__ == '__main__': + main() diff --git a/ignore_list.csv b/ignore_list.csv new file mode 100644 index 00000000..6958f0f5 --- /dev/null +++ b/ignore_list.csv @@ -0,0 +1,119 @@ +img_name +000025.jpg +000049.jpg +000067.jpg +000085.jpg +000095.jpg +000100.jpg +000127.jpg +000145.jpg +000191.jpg +000215.jpg +000320.jpg +000373.jpg +000392.jpg +000407.jpg +000488.jpg +000503.jpg +000506.jpg +000510.jpg +000536.jpg +000605.jpg +000625.jpg +000639.jpg +000707.jpg +000708.jpg +000712.jpg +000813.jpg +000837.jpg +000848.jpg +000856.jpg +000891.jpg +000892.jpg +001022.jpg +001044.jpg +001095.jpg +001098.jpg +001122.jpg +001125.jpg +001137.jpg +001156.jpg +001227.jpg +001251.jpg +001267.jpg +001282.jpg +001328.jpg +001349.jpg +001380.jpg +001427.jpg +001460.jpg +001475.jpg +001697.jpg +001744.jpg +001864.jpg +001957.jpg +001968.jpg +001973.jpg +002029.jpg +002063.jpg +002109.jpg +002112.jpg +002115.jpg +002123.jpg +002162.jpg +002175.jpg +002179.jpg +002221.jpg +002250.jpg +002303.jpg +002359.jpg +002360.jpg +002412.jpg +002417.jpg +002435.jpg +002460.jpg +002466.jpg +002472.jpg +002488.jpg +002535.jpg +002543.jpg +002565.jpg +002615.jpg +002630.jpg +002633.jpg +002661.jpg +002733.jpg +002756.jpg +002860.jpg +002883.jpg +002887.jpg +002890.jpg +002948.jpg +002995.jpg +003018.jpg +003130.jpg +003164.jpg +003233.jpg +003258.jpg +003271.jpg +003329.jpg +003351.jpg +003357.jpg +003371.jpg +003415.jpg +003427.jpg +003441.jpg +003447.jpg +003458.jpg +003570.jpg +003625.jpg +003669.jpg +003711.jpg +003747.jpg +003749.jpg +003758.jpg +003763.jpg +003772.jpg +003805.jpg +003814.jpg +003903.jpg diff --git a/misc/example.png b/misc/example.png new file mode 100644 index 00000000..e0bf0f87 Binary files /dev/null and b/misc/example.png differ diff --git a/misc/tfboard.png b/misc/tfboard.png new file mode 100644 index 00000000..3836beb2 Binary files /dev/null and b/misc/tfboard.png differ diff --git a/model.py b/model.py new file mode 100644 index 00000000..95a581a9 --- /dev/null +++ b/model.py @@ -0,0 +1,19 @@ +import torch.nn as nn +import pretrainedmodels +import pretrainedmodels.utils + + +def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"): + model = pretrainedmodels.__dict__[model_name](pretrained=pretrained) + dim_feats = model.last_linear.in_features + model.last_linear = nn.Linear(dim_feats, num_classes) + return model + + +def main(): + model = get_model() + print(model) + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..797b25fe --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +better-exceptions==0.2.2 +dlib==19.17.0 +future==0.17.1 +imgaug==0.2.9 +numpy==1.16.4 +opencv-python==4.1.0.25 +pandas==0.24.2 +pretrainedmodels==0.7.4 +tensorboard==1.14.0 +torch==1.1.0 +torchvision==0.3.0 +tqdm==4.32.2 +yacs==0.1.6 diff --git a/test.py b/test.py new file mode 100644 index 00000000..553af829 --- /dev/null +++ b/test.py @@ -0,0 +1,71 @@ +import argparse +import better_exceptions +from pathlib import Path +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +from torch.utils.data import DataLoader +import pretrainedmodels +import pretrainedmodels.utils +from model import get_model +from dataset import FaceDataset +from defaults import _C as cfg +from train import validate + + +def get_args(): + model_names = sorted(name for name in pretrainedmodels.__dict__ + if not name.startswith("__") + and name.islower() + and callable(pretrainedmodels.__dict__[name])) + parser = argparse.ArgumentParser(description=f"available models: {model_names}", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--data_dir", type=str, required=True, help="Data root directory") + parser.add_argument("--resume", type=str, required=True, help="Model weight to be tested") + parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, + help="Modify config options using the command-line") + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + if args.opts: + cfg.merge_from_list(args.opts) + + cfg.freeze() + + # create model + print("=> creating model '{}'".format(cfg.MODEL.ARCH)) + model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + + # load checkpoint + resume_path = args.resume + + if Path(resume_path).is_file(): + print("=> loading checkpoint '{}'".format(resume_path)) + checkpoint = torch.load(resume_path, map_location="cpu") + model.load_state_dict(checkpoint['state_dict']) + print("=> loaded checkpoint '{}'".format(resume_path)) + else: + raise ValueError("=> no checkpoint found at '{}'".format(resume_path)) + + if device == "cuda": + cudnn.benchmark = True + + test_dataset = FaceDataset(args.data_dir, "test", img_size=cfg.MODEL.IMG_SIZE, augment=False) + test_loader = DataLoader(test_dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, + num_workers=cfg.TRAIN.WORKERS, drop_last=False) + + print("=> start testing") + _, _, test_mae = validate(test_loader, model, None, 0, device) + print(f"test mae: {test_mae:.3f}") + + +if __name__ == '__main__': + main() diff --git a/train.py b/train.py new file mode 100644 index 00000000..6f390e84 --- /dev/null +++ b/train.py @@ -0,0 +1,244 @@ +import argparse +import better_exceptions +from pathlib import Path +from collections import OrderedDict +from tqdm import tqdm +import numpy as np +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +from torch.optim.lr_scheduler import StepLR +import torch.utils.data +from torch.utils.data import DataLoader +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +import pretrainedmodels +import pretrainedmodels.utils +from model import get_model +from dataset import FaceDataset +from defaults import _C as cfg + + +def get_args(): + model_names = sorted(name for name in pretrainedmodels.__dict__ + if not name.startswith("__") + and name.islower() + and callable(pretrainedmodels.__dict__[name])) + parser = argparse.ArgumentParser(description=f"available models: {model_names}", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--data_dir", type=str, required=True, help="Data root directory") + parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint if any") + parser.add_argument("--checkpoint", type=str, default="checkpoint", help="Checkpoint directory") + parser.add_argument("--tensorboard", type=str, default=None, help="Tensorboard log directory") + parser.add_argument('--multi_gpu', action="store_true", help="Use multi GPUs (data parallel)") + parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, + help="Modify config options using the command-line") + args = parser.parse_args() + return args + + +class AverageMeter(object): + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val + self.count += n + self.avg = self.sum / self.count + + +def train(train_loader, model, criterion, optimizer, epoch, device): + model.train() + loss_monitor = AverageMeter() + accuracy_monitor = AverageMeter() + + with tqdm(train_loader) as _tqdm: + for x, y in _tqdm: + x = x.to(device) + y = y.to(device) + + # compute output + outputs = model(x) + + # calc loss + loss = criterion(outputs, y) + cur_loss = loss.item() + + # calc accuracy + _, predicted = outputs.max(1) + correct_num = predicted.eq(y).sum().item() + + # measure accuracy and record loss + sample_num = x.size(0) + loss_monitor.update(cur_loss, sample_num) + accuracy_monitor.update(correct_num, sample_num) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + _tqdm.set_postfix(OrderedDict(stage="train", epoch=epoch, loss=loss_monitor.avg), + acc=accuracy_monitor.avg, correct=correct_num, sample_num=sample_num) + + return loss_monitor.avg, accuracy_monitor.avg + + +def validate(validate_loader, model, criterion, epoch, device): + model.eval() + loss_monitor = AverageMeter() + accuracy_monitor = AverageMeter() + preds = [] + gt = [] + + with torch.no_grad(): + with tqdm(validate_loader) as _tqdm: + for i, (x, y) in enumerate(_tqdm): + x = x.to(device) + y = y.to(device) + + # compute output + outputs = model(x) + preds.append(F.softmax(outputs, dim=-1).cpu().numpy()) + gt.append(y.cpu().numpy()) + + # valid for validation, not used for test + if criterion is not None: + # calc loss + loss = criterion(outputs, y) + cur_loss = loss.item() + + # calc accuracy + _, predicted = outputs.max(1) + correct_num = predicted.eq(y).sum().item() + + # measure accuracy and record loss + sample_num = x.size(0) + loss_monitor.update(cur_loss, sample_num) + accuracy_monitor.update(correct_num, sample_num) + _tqdm.set_postfix(OrderedDict(stage="val", epoch=epoch, loss=loss_monitor.avg), + acc=accuracy_monitor.avg, correct=correct_num, sample_num=sample_num) + + preds = np.concatenate(preds, axis=0) + gt = np.concatenate(gt, axis=0) + ages = np.arange(0, 101) + ave_preds = (preds * ages).sum(axis=-1) + diff = ave_preds - gt + mae = np.abs(diff).mean() + + return loss_monitor.avg, accuracy_monitor.avg, mae + + +def main(): + args = get_args() + + if args.opts: + cfg.merge_from_list(args.opts) + + cfg.freeze() + start_epoch = 0 + checkpoint_dir = Path(args.checkpoint) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # create model + print("=> creating model '{}'".format(cfg.MODEL.ARCH)) + model = get_model(model_name=cfg.MODEL.ARCH) + + if cfg.TRAIN.OPT == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=cfg.TRAIN.LR, + momentum=cfg.TRAIN.MOMENTUM, + weight_decay=cfg.TRAIN.WEIGHT_DECAY) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=cfg.TRAIN.LR) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + + # optionally resume from a checkpoint + resume_path = args.resume + + if resume_path: + if Path(resume_path).is_file(): + print("=> loading checkpoint '{}'".format(resume_path)) + checkpoint = torch.load(resume_path, map_location="cpu") + start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(resume_path, checkpoint['epoch'])) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + print("=> no checkpoint found at '{}'".format(resume_path)) + + if args.multi_gpu: + model = nn.DataParallel(model) + + if device == "cuda": + cudnn.benchmark = True + + criterion = nn.CrossEntropyLoss().to(device) + train_dataset = FaceDataset(args.data_dir, "train", img_size=cfg.MODEL.IMG_SIZE, augment=True, + age_stddev=cfg.TRAIN.AGE_STDDEV) + train_loader = DataLoader(train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, + num_workers=cfg.TRAIN.WORKERS, drop_last=True) + + val_dataset = FaceDataset(args.data_dir, "valid", img_size=cfg.MODEL.IMG_SIZE, augment=False) + val_loader = DataLoader(val_dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, + num_workers=cfg.TRAIN.WORKERS, drop_last=False) + + scheduler = StepLR(optimizer, step_size=cfg.TRAIN.LR_DECAY_STEP, gamma=cfg.TRAIN.LR_DECAY_RATE, + last_epoch=start_epoch - 1) + best_val_mae = 10000.0 + train_writer = None + + if args.tensorboard is not None: + opts_prefix = "_".join(args.opts) + train_writer = SummaryWriter(log_dir=args.tensorboard + "/" + opts_prefix + "_train") + val_writer = SummaryWriter(log_dir=args.tensorboard + "/" + opts_prefix + "_val") + + for epoch in range(start_epoch, cfg.TRAIN.EPOCHS): + # adjust learning rate + scheduler.step() + + # train + train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, device) + + # validate + val_loss, val_acc, val_mae = validate(val_loader, model, criterion, epoch, device) + + if args.tensorboard is not None: + train_writer.add_scalar("loss", train_loss, epoch) + train_writer.add_scalar("acc", train_acc, epoch) + val_writer.add_scalar("loss", val_loss, epoch) + val_writer.add_scalar("acc", val_acc, epoch) + val_writer.add_scalar("mae", val_mae, epoch) + + # checkpoint + if val_mae < best_val_mae: + print(f"=> [epoch {epoch:03d}] best val mae was improved from {best_val_mae:.3f} to {val_mae:.3f}") + model_state_dict = model.module.state_dict() if args.multi_gpu else model.state_dict() + torch.save( + { + 'epoch': epoch + 1, + 'arch': cfg.MODEL.ARCH, + 'state_dict': model_state_dict, + 'optimizer_state_dict': optimizer.state_dict() + }, + str(checkpoint_dir.joinpath("epoch{:03d}_{:.5f}_{:.4f}.pth".format(epoch, val_loss, val_mae))) + ) + best_val_mae = val_mae + else: + print(f"=> [epoch {epoch:03d}] best val mae was not improved from {best_val_mae:.3f} ({val_mae:.3f})") + + print("=> training finished") + print(f"additional opts: {args.opts}") + print(f"best val mae: {best_val_mae:.3f}") + + +if __name__ == '__main__': + main()