Yuxi-Zhang/A-classifier-with-PyTorch

 
 

A classifier with PyTorch

This code uses the models in Torchvision, and the classification network architectures are taken from Torchvision. If you have any problems, you can email me at [email protected] or open an issue in Issues.

from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from . import segmentation
from . import detection
from . import video

# Several classification frameworks are available
AlexNet
densenet121, densenet169, densenet201, densenet161
GoogLeNet
Inception3
mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
MobileNetV2
resnet18, resnet34, resnet50, resnet101, resnet152
resnext50_32x4d, resnext101_32x8d
wide_resnet50_2, wide_resnet101_2
vgg11, vgg13, vgg16, vgg19
vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
...

The networks above are the classic architectures available in the models package; only the classification networks are used here. This code supports transfer learning: download the ImageNet pre-trained model, then fine-tune it in your own code. You can either freeze the convolutional layers and train only the fully connected layers, or train the whole network. We keep only the convolutional layers of the classic network and feed their output to our lightweight classifier.

Train on our datasets

We used this classifier to predict the sex of chickens, and compared vgg16, vgg16_bn, vgg19, vgg19_bn, resnet18, resnet34, and densenet101. You can get our dataset here (it is hosted on Google Drive, so users in mainland China need a VPN to access it; if you cannot access it, or if you need the data, you can email me).

Some sample images from our dataset: image info

Train on Custom Dataset

-your datasets
 |--train
 |   |--label_1
 |   |--label_2
 |   |--label_n
 |--test or Val
     |--label_1
     |--label_2
     |--label_n

Your dataset needs to follow the file structure above. If your task is not binary classification, change the final output dimension from 2 to n, the number of labels. Then run the following command:

python train.py --data_directory=your dataset --arch=vgg16

If you want to train with ResNet, DenseNet, or another network, change --arch=vgg16 to, for example, --arch=resnet34 or --arch=densenet101.

Visualization of Training Process

Use TensorBoard for visualization. After training, run the following command:

tensorboard --logdir=logs

Then open the address printed on the command line, and the following image will appear:

image info

On the TensorBoard page, download the corresponding CSV, then plot the training process with csv_plot.py: image info

You can adjust the parameters to make the plot of the training process look better.
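As a rough illustration of this step, the sketch below reads a scalar CSV exported from TensorBoard (columns `Wall time`, `Step`, `Value`) and plots it with optional smoothing. It is not the repository's csv_plot.py; the function names, the smoothing weight, and the output file name are assumptions made for the example.

```python
import csv


def read_scalars(csv_path):
    """Read a TensorBoard scalar export with columns 'Wall time', 'Step', 'Value'."""
    steps, values = [], []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            steps.append(int(float(row["Step"])))
            values.append(float(row["Value"]))
    return steps, values


def smooth(values, weight=0.6):
    """Simple exponential moving average; weight=0 returns the raw curve."""
    smoothed, last = [], None
    for v in values:
        last = v if last is None else weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed


def plot_curve(csv_path, out_path="curve.png", weight=0.6, label="value"):
    """Plot the raw and smoothed curve from one exported CSV and save it as an image."""
    import matplotlib
    matplotlib.use("Agg")  # render to a file, no display needed
    import matplotlib.pyplot as plt

    steps, values = read_scalars(csv_path)
    plt.figure(figsize=(6, 4))
    plt.plot(steps, values, alpha=0.3, label=f"{label} (raw)")
    plt.plot(steps, smooth(values, weight), label=f"{label} (smoothed)")
    plt.xlabel("step")
    plt.ylabel(label)
    plt.legend()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()
```

A higher `weight` gives a smoother curve at the cost of lagging behind the raw values, which is one of the parameters worth adjusting.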
