This code trains a convolutional neural network (CNN) to classify images of 10 different animals.
The data consists of images of 10 animal classes:
- Butterfly
- Cat
- Chicken
- Cow
- Dog
- Elephant
- Horse
- Sheep
- Spider
- Squirrel
The data is split into a training set and a test set. The images are stored in the `data/animals` folder, with a separate subfolder for each class.
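A class-per-subfolder layout like this is typically turned into integer labels by sorting the folder names, which is the convention `torchvision.datasets.ImageFolder` follows. Below is a minimal standard-library sketch of that mapping. It assumes each class directory sits directly under a root such as `data/animals/train` (the exact train/test subfolder names are an assumption), and it assumes the set of image extensions. The function name `scan_image_folder` is hypothetical and not part of this repository.

```python
from pathlib import Path

# File extensions treated as images (an assumption; adjust to the dataset).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def scan_image_folder(root):
    """Return (class_to_idx, samples) for a root/<class_name>/<image> layout.

    Class indices follow the sorted order of the subfolder names, the same
    convention torchvision.datasets.ImageFolder uses.
    """
    root = Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    samples = []
    for name in class_names:
        for path in sorted((root / name).iterdir()):
            # Skip non-image files such as notes or metadata.
            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS:
                samples.append((str(path), class_to_idx[name]))
    return class_to_idx, samples
```

Because indices come from sorted folder names, the label for each animal depends on how its subfolder is spelled, not on the order of the class list above.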
Two model architectures are implemented:
- `SimpleCNN`: A simple CNN with 2 convolutional layers and 2 fully connected layers
- `AdvancedCNN`: A more advanced CNN with 5 convolutional layers and 3 fully connected layers
The models are defined in `models.py`.
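The layer counts above are all this README specifies about the two architectures. The following is a minimal PyTorch sketch of what classes with those counts could look like; the channel widths, kernel sizes, hidden sizes, BatchNorm, dropout, and adaptive pooling are assumptions, not the contents of `models.py`.

```python
from torch import nn


class SimpleCNN(nn.Module):
    """Hypothetical sketch: 2 convolutional layers + 2 fully connected layers."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Pool to a fixed 4x4 grid so the classifier does not depend on image size.
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


def conv_block(in_ch, out_ch):
    """Conv -> BatchNorm -> ReLU -> 2x2 max-pool, used by the deeper sketch."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class AdvancedCNN(nn.Module):
    """Hypothetical sketch: 5 convolutional layers + 3 fully connected layers."""

    def __init__(self, num_classes=10):
        super().__init__()
        channels = [3, 16, 32, 64, 128, 256]
        self.features = nn.Sequential(
            *[conv_block(channels[i], channels[i + 1]) for i in range(5)],
            # Collapse the remaining spatial grid to 1x1 per channel.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```

The adaptive pooling layer is a design choice that lets the first fully connected layer have a fixed input size regardless of the image resolution. Note that the `AdvancedCNN` sketch applies five 2x2 max-pools, so its inputs must be at least 32x32 pixels.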
The main training script is `animal_train.py`. It loads the data, initializes the model, defines the optimizer and loss function, trains for multiple epochs, and evaluates on the test set.
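The train-then-evaluate cycle described above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the code in `animal_train.py`: the README does not name the optimizer or loss, so `Adam` with `lr=1e-3` and `CrossEntropyLoss` are assumptions, and the helper names `train_one_epoch`, `evaluate`, and `fit` are hypothetical.

```python
import torch
from torch import nn


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one pass over the training data and return the mean loss per sample."""
    model.train()
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by batch size so the final average is per sample.
        total_loss += loss.item() * images.size(0)
        total_samples += images.size(0)
    return total_loss / total_samples


@torch.no_grad()
def evaluate(model, loader, device):
    """Return classification accuracy (fraction of correct predictions)."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
    return correct / total


def fit(model, train_loader, test_loader, epochs, lr=1e-3):
    """Train for several epochs, evaluating test accuracy after each one."""
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Adam and lr=1e-3 are assumptions; the README does not name the optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        accuracy = evaluate(model, test_loader, device)
        history.append((loss, accuracy))
    return history
```

The README also mentions progress bars and TensorBoard logging; those parts are left out here to keep the loop itself readable.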
Key parameters:
- `--batch-size`: Batch size for training
- `--epochs`: Number of epochs to train for
- `--log_path`: Path to save TensorBoard logs
- `--save_path`: Path to save trained model checkpoints
Run `python animal_train.py --help` to see all available arguments.
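The options listed above map naturally onto Python's `argparse`. The sketch below keeps the flag spellings exactly as the README gives them. The defaults for `--epochs` and `--save_path` come from the default run described in this README; the `--batch-size` default (32) and the `--log_path` default (`tensorboard/animal`) are hypothetical placeholders, and `build_parser` is not a function from the repository.

```python
import argparse


def build_parser():
    """Build a parser for the key options listed in the README.

    Flag spellings follow the README. The --batch-size and --log_path
    defaults are hypothetical placeholders.
    """
    parser = argparse.ArgumentParser(description="Train a CNN on the animal dataset.")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size for training (default is an assumption)")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of epochs to train for")
    parser.add_argument("--log_path", type=str, default="tensorboard/animal",
                        help="Path to save TensorBoard logs (default is an assumption)")
    parser.add_argument("--save_path", type=str, default="trained_models/animal",
                        help="Path to save trained model checkpoints")
    return parser
```

`argparse` turns `--batch-size` into the attribute `args.batch_size`. Because the flags mix hyphens and underscores, the spelling on the command line matters: with a parser like this, `python animal_train.py --epochs 20 --batch-size 64` is accepted, while `--batch_size 64` is rejected as an unrecognized argument.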
The script uses PyTorch and runs on the GPU when one is available. Progress bars and TensorBoard logging are used to monitor training.
Model accuracy on the test set is evaluated at the end of each epoch, and the best-performing model checkpoint is saved.
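Keeping only the best checkpoint comes down to remembering the best score seen so far and saving when a new epoch beats it. The README does not say exactly how "best" is judged, so this minimal sketch assumes test accuracy with a strict improvement rule (a tie keeps the earlier checkpoint). The class name `BestCheckpointTracker` is hypothetical.

```python
class BestCheckpointTracker:
    """Track the best test accuracy seen so far.

    update() returns True when the new accuracy is strictly higher than
    every previous epoch, which is when a training loop would overwrite
    the "best" checkpoint. Ties keep the earlier checkpoint.
    """

    def __init__(self):
        self.best_accuracy = float("-inf")
        self.best_epoch = None

    def update(self, epoch, accuracy):
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_epoch = epoch
            return True
        return False
```

In a PyTorch loop, a `True` result would typically be followed by `torch.save(model.state_dict(), path)`, where `path` is a file under the `--save_path` directory (the file name is up to the script).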
To train a model with the default settings, run `python animal_train.py`. This trains the `AdvancedCNN` model for 100 epochs and saves checkpoints to `trained_models/animal`.
You can customize the model, hyperparameters, and output paths through the command-line arguments.
The code requires the following packages:
- PyTorch
- torchvision
- tensorboard
- scikit-learn (`sklearn`)
- tqdm
Install them with `pip install -r requirements.txt`.
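After installation, a quick way to confirm the environment is complete is to check that each package can be found by the import system. This is a minimal standard-library sketch, not a script shipped with the repository; it uses import names, which differ from some package names (PyTorch is imported as `torch`, scikit-learn as `sklearn`).

```python
import importlib.util

# Import names for the required packages listed in this README.
REQUIRED_MODULES = ("torch", "torchvision", "tensorboard", "sklearn", "tqdm")


def missing_modules(modules=REQUIRED_MODULES):
    """Return the names from `modules` that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```

`find_spec` only locates a module without importing it, so the check is fast even for large packages such as PyTorch.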