Skip to content

Commit

Permalink
Adapt multitaper svd and plot results nicely
Browse files Browse the repository at this point in the history
  • Loading branch information
Proulx-S committed Sep 5, 2023
1 parent ac44ed7 commit 621752d
Show file tree
Hide file tree
Showing 3 changed files with 326 additions and 20 deletions.
75 changes: 55 additions & 20 deletions runMTsvd.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function svdStruct = runMTsvd(anaType,funTs,fpass,W,mask,normFact)
function svdStruct = runMTsvd(anaType,funTs,fpass,W,K,mask,vecNorm)
% Similar to Mitra 1997. A single svd is run on data tapered for
% sensitivity over user-defined frequency band (fpass).
tsMean = mean(funTs.vol,4);
Expand All @@ -10,9 +10,9 @@
end

%% Apply timeseries normalization
% normFact based on psd so we need to use its square root here
if exist('normFact','var') && ~isempty(normFact)
funTs.vec = funTs.vec ./ sqrt(normFact(logical(mask(funTs.vol2vec))));
% normFact (vector) based on psd so we need to use its square root here
if exist('normFact','var') && ~isempty(vecNorm)
funTs.vec = funTs.vec ./ sqrt(vecNorm(logical(mask(funTs.vol2vec))));
end

%% Detrend time series (detrend up to order-2 polynomial, since this is the highest order not fitting a sinwave)
Expand Down Expand Up @@ -104,25 +104,59 @@


case 'svdKlein'
error('double-check all that')
% if exist('W','var') && ~isempty(W)
% anaType = 'svdKlein';
T = tr.*nFrame;
TW = T*W;
K = round(TW*2-1);
TW = (K+1)/2;
param.tapers = [TW K];
% % error('double-check all that')
% % if exist('W','var') && ~isempty(W)
% % anaType = 'svdKlein';
% T = tr.*nFrame;
% TW = T*W;
% K = round(TW*2-1);
% TW = (K+1)/2;
% param.tapers = [TW K];
% if ~isempty(fpass)
% param.fpass = fpass;
% end
% mdkp = [];
% [~,f] = mtspectrumc(funTs.vec(:,1), param);
% %%% Display actual half-widht used
% Wreal = TW/T;
% display(['w (halfwidth) requested : ' num2str(W,'%0.5f ')])
% display(['w (halfwidth) used : ' num2str(Wreal,'%0.5f ')])
% display(['tw (time-halfwidth) used : ' num2str(TW)])
% display(['k (number of tapers) used: ' num2str(K)])




Wflag = exist('W','var') && ~isempty(W);
Kflag = exist('K','var') && ~isempty(K);
if Wflag && Kflag
error('Cannot specify both W and K');
elseif Wflag
T = tr.*funTs.nframes;
TW = T*W;
K = round(TW*2-1);
TW = (K+1)/2;
param.tapers = [TW K];
Wreal = TW/T;
display(['w (halfwidth) requested : ' num2str(W,'%0.5f ')])
display(['w (halfwidth) used : ' num2str(Wreal,'%0.5f ')])
display(['tw (time-halfwidth) used : ' num2str(TW)])
display(['k (number of tapers) used: ' num2str(K)])
elseif Kflag
TW = (K+1)/2;
T = tr.*funTs.nframes;
W = TW/T; Wreal = W;
param.tapers = [TW K];
display(['k (number of tapers) requested : ' num2str(K)])
display(['w (halfwidth) used : ' num2str(W,'%0.5f ')])
display(['tw (time-halfwidth) used : ' num2str(TW)])
end
if ~isempty(fpass)
param.fpass = fpass;
end
mdkp = [];
[~,f] = mtspectrumc(funTs.vec(:,1), param);
%%% Display actual half-widht used
Wreal = TW/T;
display(['w (halfwidth) requested : ' num2str(W,'%0.5f ')])
display(['w (halfwidth) used : ' num2str(Wreal,'%0.5f ')])
display(['tw (time-halfwidth) used : ' num2str(TW)])
display(['k (number of tapers) used: ' num2str(K)])

case 'svdMitra'
% else
% anaType = 'svdMitra';
Expand Down Expand Up @@ -150,7 +184,8 @@
tic
% param.fpass = fpass;
% [u,s,v,f,bandV] = spsvd2(funTs,param); sp = []; sv = []; fm = [];
[sv,sp,fm,u,s,v,a,proj] = spsvd(funTs.vec,param,mdkp);
% [sv,sp,fm,u,s,v,a,proj] = spsvd(funTs.vec,param,mdkp);
[sv,sp,fm] = spsvd(funTs.vec,param,mdkp);
toc

% %% Reconstruct reduced psd
Expand Down Expand Up @@ -259,7 +294,7 @@

%% Output
svdStruct.mask = mask;
svdStruct.normFact = normFact;
svdStruct.normFact = vecNorm;
svdStruct.tsMean = tsMean;
svdStruct.dim = strjoin({'space/taper' 'freq' 'modes'},' X ');
svdStruct.sv = permute(sv,[3 1 2]);
Expand Down
119 changes: 119 additions & 0 deletions viewMTsvdKlein.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
function axIm = viewMTsvdKlein(svdStruct,funPsd,peakFreq,imType,axIm)
if ~exist("imType",'var') || isempty(imType)
imType = 'sv'; % sv or psd
end

tmp = strsplit(funPsd.fspec,filesep); tmp = strsplit(tmp{end-1},'_');
sub = replace(tmp{contains(tmp,'sub-')},'sub-','');
ses = replace(tmp{contains(tmp,'ses-')},'ses-','');
run = replace(tmp{contains(tmp,'run-')},'run-','');
fpass = svdStruct.param.fpass;
K = svdStruct.param.tapers(2);

%%% plot power and eigen spectra
figure('WindowStyle','docked');
hTile = tiledlayout(3,3);
hTile.Padding = 'tight'; hTile.TileSpacing = 'tight';
nexttile([1 3])
yyaxis left
plot(funPsd.roi.f,funPsd.roi.psd); hold on
axis tight
ylabel('psd')
yyaxis right
plot(svdStruct.f,svdStruct.c); hold on
ylim([1/K 1])
ylabel('coherence')
ax1 = gca;
ax1.YAxis(2).Scale = 'linear';
ax1.YAxis(1).Scale = 'log';
xlim(fpass)
grid minor
title(['sub-' sub '; ses-' ses '; run-' run ' K=' num2str(K)])
drawnow

% figure('WindowStyle','docked');
% hTile = tiledlayout(3,3);
% hTile.Padding = 'tight'; hTile.TileSpacing = 'tight';
% nexttile([1 3])
% yyaxis left
% plot(funPsd.roi.f,funPsd.roi.psd); hold on
% axis tight
% ylabel('psd')
% ax = gca;
% ax.YScale = 'log';
% xlim(fpass)
% grid minor
% title(['sub-' sub '; ses-' ses '; run-' run ' K=' num2str(K)])
% drawnow
% yyaxis right
% ylabel(' ')

plot(([1 1].*peakFreq')',repmat(ylim,length(peakFreq),1)','-k')
yLim = ylim;
plot((peakFreq'+[-1 1].*svdStruct.w)',yLim(1).*ones(size(peakFreq,2),2)','-g','LineWidth',3)
text(peakFreq+0.005,ones(size(peakFreq)).*yLim(2),cellstr(num2str(peakFreq','%.3fHz')),'Rotation',-90,'VerticalAlignment','baseline')
xlabel('Hz')


if ~exist('axIm','var') || isempty(axIm)
axIm = {};
end
axIm{end+1} = nexttile;
imagesc(funPsd.tMean); hold on
maskC = getMaskOutline(funPsd.roi.mask,5);
hMask = plot(maskC); hMask.FaceColor = 'none'; hMask.EdgeColor = 'r';
axIm{end}.PlotBoxAspectRatio = [1 1 1]; axIm{end}.DataAspectRatio = [1 1 1]; axIm{end}.XAxis.Visible = 'off'; axIm{end}.YAxis.Visible = 'off'; axIm{end}.Colormap = gray;
title('timeseries mean')
switch imType
case 'psd'
%%% power spectra maps
for fInd = 1:length(peakFreq)
if fInd>5; break; end
axIm{end+1} = nexttile;
[~,b] = min(abs(svdStruct.f - peakFreq(fInd)));
pwrSpace = vec2vol(funPsd);
cLim = permute(pwrSpace.vol,[4 3 1 2]);
cLim = cLim(b,1,logical(svdStruct.mask));
cLim = [1 max(cLim(:))];
hIm = imagesc(pwrSpace.vol(:,:,:,b),cLim);
axIm{end}.ColorScale = 'log';
axIm{end}.PlotBoxAspectRatio = [1 1 1]; axIm{end}.DataAspectRatio = [1 1 1]; axIm{end}.XAxis.Visible = 'off'; axIm{end}.YAxis.Visible = 'off'; axIm{end}.Colormap = jet;
title([num2str(peakFreq(fInd),'%.3f') 'Hz'])
end
linkaxes([axIm{:}])
cb = colorbar;
cbTicks = cb.Ticks; if cbTicks(1) ~= cLim(1); cbTicks = [cLim(1) cbTicks]; end; if cbTicks(end) ~= cLim(2); cbTicks = [cbTicks cLim(2)]; end; cb.Ticks = cbTicks; cb.TickLabels(2:end-1) = {''};
cb.TickLabelInterpreter = 'latex';
cb.TickLabels{1} = ['$$\begin{array}{c}' 'noise' '\\' 'floor' '\\' '\end{array}$$'];
cb.TickLabels{end} = 'max';
hYlabel = ylabel(cb,'psd');
hYlabel.Units = 'normalized';
hYlabel.Position(1) = 1;
case 'sv'
%%% spatial sv
for fInd = 1:length(peakFreq)
if fInd>5; break; end
axIm{end+1} = nexttile;
[~,b] = min(abs(svdStruct.f - peakFreq(fInd)));
svSpace = nan(size(svdStruct.mask));
svSpace(logical(svdStruct.mask)) = svdStruct.sp(:,b,1);
imagesc(abs(svSpace));
axIm{end}.PlotBoxAspectRatio = [1 1 1]; axIm{end}.DataAspectRatio = [1 1 1]; axIm{end}.XAxis.Visible = 'off'; axIm{end}.YAxis.Visible = 'off'; axIm{end}.Colormap = jet;
title([num2str(peakFreq(fInd),'%.3f') 'Hz'])
cLim = abs(svSpace(logical(svdStruct.mask)));
cLim = [min(cLim(:)) max(cLim(:))];
axIm{end}.CLim = cLim;
end
cb = colorbar;
cbTicks = cb.Ticks; if cbTicks(1) ~= cLim(1); cbTicks = [cLim(1) cbTicks]; end; if cbTicks(end) ~= cLim(2); cbTicks = [cbTicks cLim(2)]; end; cb.Ticks = cbTicks; cb.TickLabels(2:end-1) = {''};
cb.TickLabels{1} = 'min'; cb.TickLabels{end} = 'max';
hYlabel = ylabel(cb,'mag of sv weigths');
hYlabel.Units = 'normalized';
hYlabel.Position(1) = 1;
end

%% Adjust some stuff
cb.Location = 'manual';
drawnow
cb.Position(1) = sum(axIm{end}.Position([1 3]));
linkaxes([axIm{:}])
152 changes: 152 additions & 0 deletions viewMTsvdKlein2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
function axIm = viewMTsvdKlein2(svdStruct,funPsd,peakFreq,imType,axIm)

if ~exist("imType",'var') || isempty(imType)
imType = 'psd'; % svMag, svPhase or psd
end

tmp = strsplit(funPsd.fspec,filesep); tmp = strsplit(tmp{end-1},'_');
sub = replace(tmp{contains(tmp,'sub-')},'sub-','');
ses = replace(tmp{contains(tmp,'ses-')},'ses-','');
run = replace(tmp{contains(tmp,'run-')},'run-','');
fpass = svdStruct.param.fpass;
fLim = [0 2.5];
K = svdStruct.param.tapers(2);
f = svdStruct.f;

%% Plot timeseries mean (underlay) only of not already plotted in axIm
if ~exist('axIm','var') || isempty(axIm)
axIm = {};
tmp = []; tmp.String = '';
else
tmp = [axIm{:}]; tmp = [tmp.Title];
end
if ~any(contains({tmp.String},'timeseries mean'))
figure('WindowStyle','docked');
hTile = tiledlayout(6,1); hTile.Padding = 'tight'; hTile.TileSpacing = "tight";
nexttile([5 1]);
imagesc(funPsd.tMean); hold on
maskC = getMaskOutline(funPsd.roi.mask,5);
hMask = plot(maskC); hMask.FaceColor = 'none'; hMask.EdgeColor = 'r';
axIm{end+1} = gca;
%%% Crop image
cropPrc = 75;
xLim = [find(any(funPsd.tMean>prctile(funPsd.tMean(:),cropPrc),1),1,'first') find(any(funPsd.tMean>prctile(funPsd.tMean(:),cropPrc),1),1,'last')];
yLim = [find(any(funPsd.tMean>prctile(funPsd.tMean(:),cropPrc),2),1,'first') find(any(funPsd.tMean>prctile(funPsd.tMean(:),cropPrc),2),1,'last')];
w = max([diff(xLim) diff(yLim)]);
xLim = mean(xLim)+[-1 1].*w/2;
yLim = mean(yLim)+[-1 1].*w/2;
axis(axIm{end},[xLim yLim]);
axIm{end}.PlotBoxAspectRatio = [1 1 1]; axIm{end}.DataAspectRatio = [1 1 1];
axIm{end}.XTick = []; axIm{end}.YTick = [];
axIm{end}.Colormap = gray;
cb = colorbar;
ylabel(cb,'MR intensity')
title(['sub-' sub '; ses-' ses '; run-' run ' K=' num2str(K) '; timeseries mean'])

%%% Plot spectrum
nexttile([1 1]);
yyaxis left
plot(funPsd.roi.f,funPsd.roi.psd); hold on
axis tight
ylabel('psd')
yyaxis right
plot(f,svdStruct.c); hold on
ylim([1/K 1])
ylabel('coherence')
ax1 = gca;
ax1.YAxis(2).Scale = 'linear';
ax1.YAxis(1).Scale = 'log';
xlim(fLim)
grid minor

%%% Plot spectrum peaks
if ~isnan(peakFreq)
plot(([1 1].*peakFreq')',repmat(ylim,length(peakFreq),1)','-k')
yLim = ylim;
plot((peakFreq'+[-1 1].*svdStruct.w)',yLim(1).*ones(size(peakFreq,2),2)','-g','LineWidth',3)
xlabel('Hz')
end

%%% Adjust
cb.Location = 'manual';
drawnow
cb.Position = [sum(axIm{end}.Position([1 3])) axIm{end}.Position(2) 0.02 axIm{end}.Position(4)];
cb.Visible = 'off';
end

peakFreq = peakFreq(~isnan(peakFreq));
%% Plot each frequency
for fInd = 1:length(peakFreq)
[~,b] = min(abs(f - peakFreq(fInd)));
figure('WindowStyle','docked');
%%% Plot map
axIm{end+1} = axes(gcf,'Units',axIm{end}.Units,'Position',axIm{end}.Position);
switch imType
case 'psd'
curIm = vec2vol(funPsd); curIm = curIm.vol(:,:,:,b);
hIm = imagesc(curIm);
axIm{end}.Colormap = jet;
cb = colorbar;
ylabel(cb,'psd')
axIm{end}.ColorScale = 'log';
cLim = curIm(logical(funPsd.roi.mask)); cLim = [1 max(cLim)];
axIm{end}.CLim = cLim;
case 'svMag'
curIm = nan(size(svdStruct.mask));
curIm(logical(svdStruct.mask)) = svdStruct.sp(:,b,1);
hIm = imagesc(abs(curIm));
axIm{end}.Colormap = jet;
cb = colorbar;
ylabel(cb,'spatial sv weigth mag')
cLim = abs(curIm(logical(funPsd.roi.mask))); cLim = [min(cLim) max(cLim)];
axIm{end}.CLim = cLim;
cbTicks = cb.Ticks; if cbTicks(1) ~= cLim(1); cbTicks = [cLim(1) cbTicks]; end; if cbTicks(end) ~= cLim(2); cbTicks = [cbTicks cLim(2)]; end; cb.Ticks = cbTicks; cb.TickLabels(2:end-1) = {''};
cb.TickLabels{1} = 'min'; cb.TickLabels{end} = 'max';
case 'svPhase'
curIm = nan(size(svdStruct.mask));
curIm(logical(svdStruct.mask)) = svdStruct.sp(:,b,1);
hIm = imagesc(wrapToPi(angle(curIm) - angle(mean(curIm(:),'omitnan'))));
hIm.AlphaData = ( abs(curIm) - min(abs(curIm(:))) ) ./ max(abs(curIm(:)));
axIm{end}.Color = [1 1 1].*0.5;
axIm{end}.Colormap = hsv;
cb = colorbar;
hYl = ylabel(cb,'spatial sv weigth phase');
cLim = [-pi pi];
axIm{end}.CLim = cLim;
cb.Ticks = [-pi -pi/2 0 pi/2 pi];
cb.TickLabels = {'-pi' '-pi/2' '0' 'pi/2' 'pi'};
end
axIm{end}.PlotBoxAspectRatio = [1 1 1]; axIm{end}.DataAspectRatio = [1 1 1];
axIm{end}.XTick = []; axIm{end}.YTick = [];
title(['sub-' sub '; ses-' ses '; run-' run ' K=' num2str(K) '; ' num2str(peakFreq(fInd),'%.3f') 'Hz'])
axis([axIm{end-1}.XLim axIm{end-1}.YLim])

%% Plot spectra
axes(gcf,'Units',axIm{end-1}.Parent.Children(1).Units,'Position',axIm{end-1}.Parent.Children(1).Position);
yyaxis left
plot(funPsd.roi.f,funPsd.roi.psd); hold on
axis tight
ylabel('psd')
ytickformat('%2.0f')
yyaxis right
plot(f,svdStruct.c); hold on
ylim([1/K 1])
ylabel('coherence')
ax = gca; ax.YAxis(2).Scale = 'linear'; ax.YAxis(1).Scale = 'log';
xlim(fLim)
grid minor

plot(([1 1].*f(b)')',ylim,'-k')
yLim = ylim;
plot((f(b)'+[-1 1].*svdStruct.w)',yLim(1).*[1 1]','-g','LineWidth',3)
xlabel('Hz')

%% Adjust
% cb.Location = 'manual';
% drawnow
axIm{end}.Position = axIm{end-1}.Position;
% cb.Position = [sum(axIm{end}.Position([1 3])) axIm{end}.Position(2) 0.02 axIm{end}.Position(4)];
end
linkaxes([axIm{:}])


0 comments on commit 621752d

Please sign in to comment.