% main.m - the entry point script for the project.
% Clear the workspace, close any open figures, and clear the command window.
clear;
close all;
clc;
% Read the .png image files in the grayscale folder into a cell array of
% file names. (glob is Octave-specific; the closest MATLAB equivalent is dir.)
imageFilesList = glob('grayscale/*.png');
% Initialize the full data matrix XAll and the label vector yAll.
XAll = [];
mAll = length(imageFilesList);
yAll = zeros(mAll,1);
% Randomize the file list.
randVec = randperm(mAll);
% Image scale factor - determines by how much each image is scaled down when loaded.
image_scale = 0.1;
for i = 1:mAll
  fileName = imageFilesList{randVec(i)};
  % Replace backslashes with forward slashes so the path splits consistently.
  fileName(fileName == "\\") = "/";
  % Load the image into a row vector of pixel features.
  XAll(i,:) = loadImageVector(fileName, image_scale);
  % Determine from the file name whether the image is of a male or female egg.
  pathParts = strsplit(fileName, '/');
  fname = pathParts{end};
  % Female eggs (file names starting with 'F') are labelled 1, all others 0 :)
  if (~isempty(regexp(fname, "^F", "once")))
    yAll(i) = 1;
  else
    yAll(i) = 0;
  endif
end
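% Note: loadImageVector is defined elsewhere in this repo. It is assumed to
% read the image, scale it by image_scale, and flatten the pixels into a row
% vector. A minimal sketch of what it presumably does (hypothetical):
%   img = imread(fileName);              % read the grayscale PNG
%   img = imresize(img, image_scale);    % shrink to reduce the feature count
%   vec = double(img(:)') / 255;         % flatten and normalize to [0, 1]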
% Get the number of input features
n = size(XAll, 2);
% We now have the input matrix XAll and output vector yAll.
% We need to split these into a training set (X_train, y_train),
% cross-validation set (X_cv, y_cv), and test set (X_test, y_test).
% The ratios are 60:20:20 respectively. floor keeps the split indices
% integral when mAll is not divisible by 5; any remainder goes to the test set.
m_train = floor(0.6 * mAll);
m_cv = floor(0.2 * mAll);
m_test = mAll - m_train - m_cv;
% The input matrices.
X_train = XAll(1:m_train,:);
X_cv = XAll((m_train+1):(m_train+m_cv),:);
X_test = XAll((m_train+m_cv+1):(m_train+m_cv+m_test),:);
% The output vectors.
y_train = yAll(1:m_train,:);
y_cv = yAll((m_train+1):(m_train+m_cv),:);
y_test = yAll((m_train+m_cv+1):(m_train+m_cv+m_test),:);
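% Optional sanity check: the three subsets should together cover all mAll examples.
assert(m_train + m_cv + m_test == mAll);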
fprintf(['All training data, cross-validation data, '...
         'and test data are loaded\n'...
         '\nPress any key to proceed\n']);
pause;
% Network layer sizes used throughout the rest of the script.
input_layer_size = n;
hidden_layer_size = 30;
final_layer_size = 1; % the output will be a 0 or a 1
% We now initialize Theta1 and Theta2.
% The network has three layers: an input layer, one hidden layer,
% and an output layer with a single unit.
init_theta1 = initializeWeights(input_layer_size, hidden_layer_size);
init_theta2 = initializeWeights(hidden_layer_size, final_layer_size);
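% Note: initializeWeights is defined elsewhere in this repo. It is assumed to
% return an (L_out x (L_in + 1)) matrix of small random values, breaking the
% symmetry between units. A common approach looks like this (hypothetical):
%   epsilon = sqrt(6) / sqrt(L_in + L_out);
%   W = rand(L_out, L_in + 1) * 2 * epsilon - epsilon;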
% Unroll both weight matrices into a single parameter vector.
init_weights = [init_theta1(:); init_theta2(:)];
fprintf('Random weights initialized\n\nPress any key to continue.\n');
pause;
% Regularization parameter (0 disables regularization).
lambda = 0;
% The cost with the initial random weights.
J = costFunctionReg(init_weights,...
input_layer_size,...
hidden_layer_size,...
final_layer_size,...
X_train, y_train, lambda);
fprintf(['\nInitial cost is found to be: %f'...
'\n\nPress any key to continue\n'], J);
pause;
% The cost function to be minimized, as a function of the unrolled parameters only.
costFunction = @(input_params) costFunctionReg(input_params,...
input_layer_size,...
hidden_layer_size,...
final_layer_size,...
X_train, y_train, lambda);
% Run the optimizer for at most 50 iterations.
options = optimset('MaxIter', 50);
[params, cost] = fmincg(costFunction, init_weights, options);
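% Note: fmincg is the conjugate-gradient minimizer distributed with the
% Coursera Machine Learning exercises. It is assumed to call
% costFunction(params) repeatedly, expecting it to return both the cost and
% the gradient vector with respect to the unrolled parameters.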
% Reshape the optimized parameter vector back into the two weight matrices.
Theta1 = reshape(params(1:(hidden_layer_size * (input_layer_size + 1))),...
hidden_layer_size,...
(input_layer_size + 1));
Theta2 = reshape(params(((hidden_layer_size * (input_layer_size + 1)) + 1):end),...
final_layer_size,...
(hidden_layer_size + 1));
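% Optional sanity check: params should hold exactly the entries of both
% weight matrices.
assert(numel(params) == hidden_layer_size * (input_layer_size + 1)...
       + final_layer_size * (hidden_layer_size + 1));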
% Predict the outcome against the training set
prediction = predict(Theta1, Theta2, X_train);
fprintf('\nTraining Set Accuracy: %f\n', mean(double(prediction == y_train)) * 100);
% Predict the outcome against the cross-validation set
prediction_cv = predict(Theta1, Theta2, X_cv);
fprintf('\nCross Validation Set Accuracy: %f\n', mean(double(prediction_cv == y_cv)) * 100);
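% Note: predict is defined elsewhere in this repo. It is assumed to run
% forward propagation and threshold the single output unit at 0.5. A minimal
% sketch (hypothetical):
%   sigmoid = @(z) 1 ./ (1 + exp(-z));
%   a2 = sigmoid([ones(size(X, 1), 1) X] * Theta1');
%   a3 = sigmoid([ones(size(a2, 1), 1) a2] * Theta2');
%   p = (a3 >= 0.5);
% Once the model and lambda are finalized, the held-out test set
% (X_test, y_test) can be evaluated the same way.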