-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsetUpKTRuns2.m
107 lines (82 loc) · 2.04 KB
/
setUpKTRuns2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
function [avgLogP sampCovs] = setUpKTRuns2(sigPr,a,b,T,N,M,dirName,sigCov,its)
%SETUPKTRUNS2 Sample covariances from a prior, score them, and save the median run.
%   Written by Joe Austerweil (2012) for comparing the inductive bias of NN
%   and graph priors over Gaussian processes.
%
%   [avgLogP, sampCovs] = setUpKTRuns2(sigPr,a,b,T,N,M,dirName,sigCov,its)
%   repeats the following experiment ITS times:
%     1. draw T covariance matrices curSig = sigPr([thet bet sigCov N]);
%     2. for each, draw M samples from a zero-mean Gaussian with that
%        covariance (stored as an N-by-M matrix);
%     3. score the T covariances with
%        wishartinessLP(sampCovs, a, b, sigCov*eye(N)).
%   The run whose average log probability (logP/T) is closest to the
%   median over all runs is written to disk. Each of its T data matrices
%   goes to its own file (variables 'data' and 'objcount'), and a final
%   '<prefix>allData' file holds the remaining workspace.
%
%   Inputs (defaults are used for omitted trailing arguments):
%     sigPr   - function handle (required). sigPr(args) must return an
%               N-by-N covariance matrix; sigPr(-1) must return a char
%               row used as the output filename prefix.
%     a, b    - parameters forwarded to wishartinessLP (default 1000 each).
%     T       - covariances sampled per run (default 100).
%     N       - dimension of each covariance (default 10).
%     M       - samples drawn per covariance (default 100).
%     dirName - output directory relative to pwd, with leading and
%               trailing '/' (default '/runAt<cputime>/').
%     sigCov  - scale of the base matrix sigCov*eye(N) (default 1/1250).
%     its     - number of repeated runs (default 101).
%
%   Outputs:
%     avgLogP  - median over runs of the per-covariance average log prob.
%     sampCovs - N-by-N-by-T covariances of the saved (median) run.
%
%   Example settings used in earlier experiments:
%     T = 1e5; N = 10; a = 2*N; b = 1.5*N; (or b = 1000;)
%     sigPr = @(x)genRandWishSig(x,a, eye(N));
%     sigCov = 1; (or sigCov = 1/1350;)

% Per-argument defaults, so that caller-supplied values are never overwritten.
if nargin < 1
    error('setUpKTRuns2:missingPrior', ...
        'sigPr (a function handle) is required.');
end
if nargin < 2
    a = 1000;
end
if nargin < 3
    b = 1000;
end
if nargin < 4
    T = 100;
end
if nargin < 5
    N = 10;
end
if nargin < 6
    M = 100;
end
if nargin < 7
    % NOTE(review): cputime is process CPU time, not wall-clock time, so
    % names may repeat across sessions -- consider a timestamp instead.
    curTime = cputime;
    dirName = ['/runAt' num2str(curTime) '/'];
end
if nargin < 8
    sigCov = 1/1250;
end
if nargin < 9
    its = 101;
end

% Create the output directory once; exist() avoids the
% "directory already exists" warning from a second mkdir.
outDir = [pwd dirName];
if ~exist(outDir, 'dir')
    mkdir(outDir);
end

% Loop-invariant quantities.
sigPrParam = sigCov*eye(N);     % base matrix passed to wishartinessLP
thet = exp(-3);
bet = 0.4;
args = [thet bet sigCov N];     % argument vector handed to sigPr on every draw
muPr = zeros(N,1);              % zero mean for the Gaussian samples
objcount = N;                   % saved alongside every data file

sampCovsIts = cell(its,1);
XsIts = cell(its,1);
logPs = zeros(its,1);
avgLogPs = zeros(its,1);
muPrs = cell(its,1);
for it = 1:its
    if mod(it,10) == 0
        disp(['cur t:' num2str(it) ' out of ' num2str(its)]);
    end
    sampCovs = zeros(N,N,T);
    Xs = zeros(N,M,T);
    for t = 1:T
        curSig = sigPr(args);
        sampCovs(:,:,t) = curSig;
        % mvnrnd returns M-by-N; store it transposed as N-by-M.
        Xs(:,:,t) = mvnrnd(muPr,curSig,M)';
    end
    logP = wishartinessLP(sampCovs,a,b,sigPrParam);
    muPrs{it} = muPr;
    sampCovsIts{it} = sampCovs;
    XsIts{it} = Xs;
    logPs(it) = logP;
    avgLogPs(it) = logP/T;
end
% hist(avgLogPs); % uncomment to see the wishartiness distribution

% Select the run closest to the median. An exact equality test finds no
% run when its is even, and finds several when values tie; min() always
% yields exactly one index.
avgLogP = median(avgLogPs);
[~, ind] = min(abs(avgLogPs - avgLogP));

% Save the selected run: one file per covariance, named <prefix><t>.
filePrefix = [outDir sigPr(-1)];
sampCovs = sampCovsIts{ind};
Xs = XsIts{ind};
for t = 1:T
    data = Xs(:,:,t);
    save([filePrefix num2str(t)], 'data', 'objcount');
end
logPs = median(logPs);
% Drop the per-run caches so the workspace file stays small.
clear sampCovsIts XsIts;
save([filePrefix 'allData']);