-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassificationperformance.m
56 lines (52 loc) · 1.88 KB
/
classificationperformance.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
function [ROCperf,PrecRecallPerf] = classificationperformance(labels,predictions,probs,positiveclass,modelname,titlestr)
%% Plots the confusion matrix, ROC curve and precision-recall curve, and returns their metrics
% Description: Takes input from Main_hip_OA_trainer
% % Inputs:
%   labels        - true class labels
%   predictions   - predicted class labels (same type as labels)
%   probs         - prediction scores/probabilities for the positive class
%   positiveclass - label of the positive class, passed to perfcurve
%   modelname     - model name shown in the plot legends
%   titlestr      - title prefix for the ROC and PR plots
%
% % Outputs:
%   ROCperf        - struct with fields fpr, tpr, threshold, ROCAUC
%   PrecRecallPerf - struct with fields
%       RecallCurve, PrecisionCurve - PR curve points returned by perfcurve
%       threshold                   - thresholds matching the curve points
%       PRAUC                       - area under the PR curve
%       Recall, Precision           - scalar means of the curve values
%                                     (NaN entries ignored)
%       FScore                      - F1 computed from those two means
%
% Side effects: opens three figures (confusion chart, ROC, PR).
%
% (C) Robel K. Gebre
% Medical Imaging, Physics and Technology (MIPT)
% University of Oulu, Oulu, Finland
% 2021
%%
ROCperf = struct();
PrecRecallPerf = struct();
%% 1.1 Confusion Matrix
figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
cm = confusionchart(labels,predictions);
cm.Title = 'Confusion Matrix';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
%% 1.2 AUCROC curve
[X,Y,T,AUC] = perfcurve(labels,probs,positiveclass);
figure;
plot(X,Y);
xlabel('False positive rate')
ylabel('True positive rate')
% sprintf is used instead of strcat: strcat drops trailing whitespace of
% char inputs, which turned '(AUC = ' into '(AUC =' in the legend text.
legend(sprintf('%s (AUC = %s)',modelname,num2str(round(AUC,2))),'Location',"southeast")
title(strcat(titlestr,', ROC Curve'))
ROCperf.fpr = X;
ROCperf.tpr = Y;
ROCperf.threshold = T;
ROCperf.ROCAUC = AUC;
%% 1.3 Precision-recall curve
[Xpr,Ypr,Tpr,AUCpr] = perfcurve(labels,probs,positiveclass,'xCrit', 'reca', 'yCrit', 'prec');
figure;
plot(Xpr,Ypr)
xlabel('Recall')
ylabel('Precision')
legend(sprintf('%s (AUC = %s)',modelname,num2str(round(AUCpr,2))),'Location',"northeast")
title(strcat(titlestr,', Precision-Recall Curve'))
% Scalar summaries: mean over all curve points, skipping NaN entries.
Precision = mean(Ypr,'all','omitnan');
Recall = mean(Xpr,'all','omitnan');
% NOTE(review): this F1 is derived from threshold-averaged precision and
% recall, not from a single operating point - confirm this is intended.
F1_score = 2 * (Precision*Recall/(Precision+Recall));
% The curve vectors get their own fields; previously they were stored in
% .Recall/.Precision and then immediately overwritten by the scalars below.
PrecRecallPerf.RecallCurve = Xpr;
PrecRecallPerf.PrecisionCurve = Ypr;
PrecRecallPerf.threshold = Tpr;
PrecRecallPerf.PRAUC = AUCpr;
% Scalar fields keep their original names so existing callers still work.
PrecRecallPerf.Precision = Precision;
PrecRecallPerf.Recall = Recall;
PrecRecallPerf.FScore = F1_score;
end