Commit b73d91c by Yusuke Uchida (Jun 30, 2019): initial commit with 0 parents, showing 11 changed files with 838 additions and 0 deletions.
# Age Estimation PyTorch
PyTorch-based CNN implementation for estimating age from face images.
Currently only the APPA-REAL dataset is supported.
A similar Keras-based project can be found [here](https://github.com/yu4u/age-gender-estimation).

<img src="misc/example.png" width="800px">

## Requirements

```bash
pip install -r requirements.txt
```

## Demo
A webcam is required.
See `python demo.py -h` for detailed options.

```bash
python demo.py
```

With the `--img_dir` argument, images in that directory are used as input:

```bash
python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY]
```

If the `--output_dir` argument is also given, the resulting images are saved to that directory instead of being displayed in a window:

```bash
python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY] --output_dir [PATH/TO/OUTPUT_DIRECTORY]
```

## Train

#### Download Dataset

Download and extract the [APPA-REAL dataset](http://chalearnlap.cvc.uab.es/dataset/26/description/).

> The APPA-REAL database contains 7,591 images with associated real and apparent age labels. The total number of apparent votes is around 250,000. On average we have around 38 votes per each image and this makes the average apparent age very stable (0.3 standard error of the mean).

```bash
wget http://158.109.8.102/AppaRealAge/appa-real-release.zip
unzip appa-real-release.zip
```

#### Train Model
Train a model using the APPA-REAL dataset.
See `python train.py -h` for detailed options.

```bash
python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log
```

Check training progress:

```bash
tensorboard --logdir=tf_log
```

<img src="misc/tfboard.png" width="400px">

#### Training Options
You can change training parameters, including the model architecture, by appending `KEY VALUE` arguments:

```bash
python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log MODEL.ARCH se_resnet50 TRAIN.OPT sgd TRAIN.LR 0.1
```

All default parameters defined in [defaults.py](defaults.py) can be changed in this style, as the sketch below illustrates.
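
Under the hood, these `KEY VALUE` pairs are merged into the yacs config node defined in `defaults.py`; `demo.py` (included in this commit) applies its `opts` argument the same way via `merge_from_list`. A minimal sketch, not part of the commit itself:

```python
# Sketch: applying KEY VALUE overrides to the yacs defaults.
# Assumes defaults.py (shown later in this commit) is importable.
from defaults import _C as cfg

opts = ["MODEL.ARCH", "se_resnet50", "TRAIN.OPT", "sgd", "TRAIN.LR", "0.1"]
cfg.merge_from_list(opts)  # yacs type-checks and coerces each value
cfg.freeze()               # make the config read-only from here on

print(cfg.MODEL.ARCH, cfg.TRAIN.OPT, cfg.TRAIN.LR)  # se_resnet50 sgd 0.1
```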

#### Test Trained Model
Evaluate the trained model using the APPA-REAL test dataset.

```bash
python test.py --data_dir [PATH/TO/appa-real-release] --resume [PATH/TO/BEST_MODEL.pth]
```

After evaluation, you should see output like this:

```bash
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:08<00:00,  1.28it/s]
test mae: 4.800
```
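
`test.py` itself is not shown in this commit view, but given the expected-value age decoding used in `demo.py`, the reported MAE is presumably the mean absolute difference in years between predicted and apparent ages. A hypothetical sketch:

```python
# Hypothetical sketch of the reported test metric; test.py is not shown here.
import numpy as np

def mean_absolute_error(predicted_ages, true_ages):
    """Mean absolute difference in years between predictions and labels."""
    predicted_ages = np.asarray(predicted_ages, dtype=np.float64)
    true_ages = np.asarray(true_ages, dtype=np.float64)
    return np.abs(predicted_ages - true_ages).mean()

print(mean_absolute_error([25.0, 40.5], [23.0, 44.0]))  # 2.75
```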
## dataset.py

```python
import argparse
import better_exceptions
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import cv2
from torch.utils.data import Dataset
from imgaug import augmenters as iaa


class ImgAugTransform:
    def __init__(self):
        self.aug = iaa.Sequential([
            iaa.OneOf([
                iaa.Sometimes(0.25, iaa.AdditiveGaussianNoise(scale=0.1 * 255)),
                iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 3.0)))
            ]),
            iaa.Affine(
                rotate=(-20, 20), mode="edge",
                scale={"x": (0.95, 1.05), "y": (0.95, 1.05)},
                translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)}
            ),
            iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True),
            iaa.GammaContrast((0.3, 2)),
            iaa.Fliplr(0.5),
        ])

    def __call__(self, img):
        img = np.array(img)
        img = self.aug.augment_image(img)
        return img


class FaceDataset(Dataset):
    def __init__(self, data_dir, data_type, img_size=224, augment=False, age_stddev=1.0):
        assert(data_type in ("train", "valid", "test"))
        csv_path = Path(data_dir).joinpath(f"gt_avg_{data_type}.csv")
        img_dir = Path(data_dir).joinpath(data_type)
        self.img_size = img_size
        self.augment = augment
        self.age_stddev = age_stddev

        if augment:
            self.transform = ImgAugTransform()
        else:
            self.transform = lambda i: i

        self.x = []
        self.y = []
        self.std = []
        df = pd.read_csv(str(csv_path))
        ignore_path = Path(__file__).resolve().parent.joinpath("ignore_list.csv")
        ignore_img_names = list(pd.read_csv(str(ignore_path))["img_name"].values)

        for _, row in df.iterrows():
            img_name = row["file_name"]

            if img_name in ignore_img_names:
                continue

            img_path = img_dir.joinpath(img_name + "_face.jpg")
            assert(img_path.is_file())
            self.x.append(str(img_path))
            self.y.append(row["apparent_age_avg"])
            self.std.append(row["apparent_age_std"])

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        img_path = self.x[idx]
        age = self.y[idx]

        if self.augment:
            age += np.random.randn() * self.std[idx] * self.age_stddev

        img = cv2.imread(str(img_path), 1)
        img = cv2.resize(img, (self.img_size, self.img_size))
        img = self.transform(img).astype(np.float32)
        return torch.from_numpy(np.transpose(img, (2, 0, 1))), np.clip(round(age), 0, 100)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--data_dir", type=str, required=True)
    args = parser.parse_args()
    dataset = FaceDataset(args.data_dir, "train")
    print("train dataset len: {}".format(len(dataset)))
    dataset = FaceDataset(args.data_dir, "valid")
    print("valid dataset len: {}".format(len(dataset)))
    dataset = FaceDataset(args.data_dir, "test")
    print("test dataset len: {}".format(len(dataset)))


if __name__ == '__main__':
    main()
```
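
Two details worth noting: when `augment=True`, the label itself is jittered with Gaussian noise scaled by each image's apparent-age standard deviation, and images are returned as raw BGR float tensors without normalization. A minimal consumption sketch (not part of the commit; the module name `dataset` and extracted data location are assumptions), with batch size and worker count taken from `defaults.py`:

```python
# Sketch: iterating FaceDataset with a DataLoader, assuming the file above is
# saved as dataset.py and appa-real-release/ has been extracted alongside it.
from torch.utils.data import DataLoader
from dataset import FaceDataset

train_dataset = FaceDataset("appa-real-release", "train",
                            img_size=224, augment=True, age_stddev=1.0)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=8, drop_last=True)

for imgs, ages in train_loader:
    # imgs: float32, shape (128, 3, 224, 224), BGR, values in [0, 255]
    # ages: int64, shape (128,), apparent ages clipped to [0, 100]
    break
```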
@@ -0,0 +1,26 @@ | ||
from yacs.config import CfgNode as CN | ||
|
||
_C = CN() | ||
|
||
# Model | ||
_C.MODEL = CN() | ||
_C.MODEL.ARCH = "se_resnext50_32x4d" # check python train.py -h for available models | ||
_C.MODEL.IMG_SIZE = 224 | ||
|
||
# Train | ||
_C.TRAIN = CN() | ||
_C.TRAIN.OPT = "adam" # adam or sgd | ||
_C.TRAIN.WORKERS = 8 | ||
_C.TRAIN.LR = 0.001 | ||
_C.TRAIN.LR_DECAY_STEP = 20 | ||
_C.TRAIN.LR_DECAY_RATE = 0.2 | ||
_C.TRAIN.MOMENTUM = 0.9 | ||
_C.TRAIN.WEIGHT_DECAY = 0.0 | ||
_C.TRAIN.BATCH_SIZE = 128 | ||
_C.TRAIN.EPOCHS = 80 | ||
_C.TRAIN.AGE_STDDEV = 1.0 | ||
|
||
# Test | ||
_C.TEST = CN() | ||
_C.TEST.WORKERS = 8 | ||
_C.TEST.BATCH_SIZE = 128 |
@@ -0,0 +1,163 @@ | ||
import argparse | ||
import better_exceptions | ||
from pathlib import Path | ||
from contextlib import contextmanager | ||
import numpy as np | ||
import cv2 | ||
import dlib | ||
import torch | ||
import torch.nn.parallel | ||
import torch.backends.cudnn as cudnn | ||
import torch.optim | ||
import torch.utils.data | ||
import torch.nn.functional as F | ||
from model import get_model | ||
from defaults import _C as cfg | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser(description="Age estimation demo", | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
parser.add_argument("--resume", type=str, required=True, | ||
help="Model weight to be tested") | ||
parser.add_argument("--margin", type=float, default=0.4, | ||
help="Margin around detected face for age-gender estimation") | ||
parser.add_argument("--img_dir", type=str, default=None, | ||
help="Target image directory; if set, images in image_dir are used instead of webcam") | ||
parser.add_argument("--output_dir", type=str, default=None, | ||
help="Output directory to which resulting images will be stored if set") | ||
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, | ||
help="Modify config options using the command-line") | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX, | ||
font_scale=0.8, thickness=1): | ||
size = cv2.getTextSize(label, font, font_scale, thickness)[0] | ||
x, y = point | ||
cv2.rectangle(image, (x, y - size[1]), (x + size[0], y), (255, 0, 0), cv2.FILLED) | ||
cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness, lineType=cv2.LINE_AA) | ||
|
||
|
||
@contextmanager | ||
def video_capture(*args, **kwargs): | ||
cap = cv2.VideoCapture(*args, **kwargs) | ||
try: | ||
yield cap | ||
finally: | ||
cap.release() | ||
|
||
|
||
def yield_images(): | ||
with video_capture(0) as cap: | ||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) | ||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) | ||
|
||
while True: | ||
ret, img = cap.read() | ||
|
||
if not ret: | ||
raise RuntimeError("Failed to capture image") | ||
|
||
yield img, None | ||
|
||
|
||
def yield_images_from_dir(img_dir): | ||
img_dir = Path(img_dir) | ||
|
||
for img_path in img_dir.glob("*.*"): | ||
img = cv2.imread(str(img_path), 1) | ||
|
||
if img is not None: | ||
h, w, _ = img.shape | ||
r = 640 / max(w, h) | ||
yield cv2.resize(img, (int(w * r), int(h * r))), img_path.name | ||
|
||
|
||
def main(): | ||
args = get_args() | ||
|
||
if args.opts: | ||
cfg.merge_from_list(args.opts) | ||
|
||
cfg.freeze() | ||
|
||
if args.output_dir is not None: | ||
if args.img_dir is None: | ||
raise ValueError("=> --img_dir argument is required if --output_dir is used") | ||
|
||
output_dir = Path(args.output_dir) | ||
output_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
# create model | ||
print("=> creating model '{}'".format(cfg.MODEL.ARCH)) | ||
model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None) | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
model = model.to(device) | ||
|
||
# load checkpoint | ||
resume_path = args.resume | ||
|
||
if Path(resume_path).is_file(): | ||
print("=> loading checkpoint '{}'".format(resume_path)) | ||
checkpoint = torch.load(resume_path, map_location="cpu") | ||
model.load_state_dict(checkpoint['state_dict']) | ||
print("=> loaded checkpoint '{}'".format(resume_path)) | ||
else: | ||
raise ValueError("=> no checkpoint found at '{}'".format(resume_path)) | ||
|
||
if device == "cuda": | ||
cudnn.benchmark = True | ||
|
||
model.eval() | ||
margin = args.margin | ||
img_dir = args.img_dir | ||
detector = dlib.get_frontal_face_detector() | ||
img_size = cfg.MODEL.IMG_SIZE | ||
image_generator = yield_images_from_dir(img_dir) if img_dir else yield_images() | ||
|
||
with torch.no_grad(): | ||
for img, name in image_generator: | ||
input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
img_h, img_w, _ = np.shape(input_img) | ||
|
||
# detect faces using dlib detector | ||
detected = detector(input_img, 1) | ||
faces = np.empty((len(detected), img_size, img_size, 3)) | ||
|
||
if len(detected) > 0: | ||
for i, d in enumerate(detected): | ||
x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height() | ||
xw1 = max(int(x1 - margin * w), 0) | ||
yw1 = max(int(y1 - margin * h), 0) | ||
xw2 = min(int(x2 + margin * w), img_w - 1) | ||
yw2 = min(int(y2 + margin * h), img_h - 1) | ||
cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2) | ||
cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2) | ||
faces[i] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1], (img_size, img_size)) | ||
|
||
# predict ages | ||
inputs = torch.from_numpy(np.transpose(faces.astype(np.float32), (0, 3, 1, 2))).to(device) | ||
outputs = F.softmax(model(inputs), dim=-1).cpu().numpy() | ||
ages = np.arange(0, 101) | ||
predicted_ages = (outputs * ages).sum(axis=-1) | ||
|
||
# draw results | ||
for i, d in enumerate(detected): | ||
label = "{}".format(int(predicted_ages[i])) | ||
draw_label(img, (d.left(), d.top()), label) | ||
|
||
if args.output_dir is not None: | ||
output_path = output_dir.joinpath(name) | ||
cv2.imwrite(str(output_path), img) | ||
else: | ||
cv2.imshow("result", img) | ||
key = cv2.waitKey(-1) if img_dir else cv2.waitKey(30) | ||
|
||
if key == 27: # ESC | ||
break | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
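
The prediction step treats the classifier's softmax output as a distribution over ages 0 to 100 and reports its expectation rather than the argmax. A tiny self-contained illustration of that decoding (not part of the commit):

```python
# Sketch: expected-value age decoding from a softmax distribution.
import numpy as np

probs = np.zeros(101)
probs[[28, 29, 30, 31, 32]] = [0.1, 0.2, 0.4, 0.2, 0.1]  # toy softmax output

predicted_age = (probs * np.arange(0, 101)).sum()
print(predicted_age)  # 30.0, the mean of the distribution
```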