function demos_toybrain_567_djm()
% Shows how simple searchlight, wholebrain, and ROI decodings run on
% simulated 3D toy data. The toy data are Matlab matrices, not real fMRI or
% EEG data.
%
% The script creates multiple volumes in the shape of an ellipsoid where
% all entries are filled with random Gaussian noise. For one class, one
% central slice will read "TDT" which constitutes the effect.
%
% Usage:
%   demos_toybrain_567_djm()   % no inputs, no outputs
%
% What it does, in order:
%   1. Simulates n_runs x n_files_per_run noise volumes and adds a signal
%      (at a range of SNRs) to the 'tdt' voxels of class-1 volumes.
%   2. Plots the difference of class means and a voxelwise two-sample
%      t-test (uncorrected and FDR) for one SNR.
%   3. Runs a searchlight decoding (demo 5), a wholebrain RFE and a
%      wholebrain SVM weight map (demo 6), and an ROI decoding across all
%      SNRs (demo 7), plotting each into one shared figure.
%
% Side effects:
%   - Closes all figures, enables 'dbstop if error', and adds directories
%     to the Matlab path (these persist after the function returns).
%   - Temporarily sets the default axes font size to 7; the previous
%     default is restored when the function exits.
%   - Stops at several 'keyboard' prompts so that the figure can be
%     inspected; type 'dbcont' to continue.
%
% Requirements (not defined in this file; must be on the path):
%   decoding_defaults, fill_passed_data, make_design_cv,
%   make_design_alldata, plot_design, decoding, transform_vol,
%   fitplots2, fdr_bh, ttest2, and the image file 'tdt.bmp'.
%
% Martin, 2014/10/28
%
% Modified by djm. 1/3/17, 5/9/22
% Based on demos 5-7 of The Decoding Toolbox

close all
dbstop if error % in case something goes wrong

% Use a small default axes font for the multi-panel figure. The previous
% default is restored when this function exits (normal return, or dbquit
% from a keyboard/error prompt) so the session default is not left changed.
% Axes already created keep the font size they were created with.
prev_fontsize = get(0, 'DefaultAxesFontSize');
restore_fontsize = onCleanup(@() set(0, 'DefaultAxesFontSize', prev_fontsize)); %#ok<NASGU>
set(0, 'DefaultAxesFontSize', 7);

%% set seed for random number generator (to get reproducible output)
rng(99);

%% add paths to required files
fpath = fileparts(which(mfilename)); % location of current file
addpath(fpath) % add this directory to path
addpath(fileparts(fileparts(fpath))) % for TDT
addpath(fileparts(fileparts(fileparts(fileparts(fpath))))); % for plotSVM_djm.m & fitplots2.m

%% choose which things to demo
dosearchlight=1;
dowholebrain=1;
doroi=1;

%% initialize TDT & cfg, and set some options
cfg = decoding_defaults;
cfg.fighandles.plot_design = 2;
cfg.results.write = 0; % no results are written to disk

%% set design parameters
n_runs = 6; % 6 runs
n_files_per_run = 8; % 8 samples per run (4 from each class)
n_files = n_files_per_run * n_runs; % 6 x 8 samples in total

% Labels per run: first half of the samples are class 1, second half class -1
label = repmat(kron([1 -1],ones(1,n_files_per_run/2)),1,n_runs)';

chunk_single = ones(1,n_files)'; % single chunk to use all data when estimating weight map
chunk_perrun = kron(1:n_runs,ones(1,n_files_per_run))'; % and for normal leave-one-run-out CV

%% Create brain mask (ellipsoid)
sz = [64 64 9]; % please only change the z-dimension

[x,y,z] = ndgrid(linspace(-1,1,sz(1)),linspace(-1,1,sz(2)),linspace(-1,1,sz(3)));
mask = (x.^2+y.^2+z.^2)<=1;
mask_index = find(mask);

%% create signal ('tdt') region
% Signal voxels are the pixels of tdt.bmp that are 0, placed in the central slice
tdt = false(sz);
tdt(:,:,round(sz(3)/2)) = ~(double(imread('tdt.bmp'))/255);

%% Start with noise everywhere
data_noise = 1*randn([sz n_files]);

%% combine signal and noise at a range of SNRs
% The same noise volumes are reused for every SNR; only the signal differs.
snr = 0.1 : 0.1 : 0.9;
data=cell(1,length(snr)); % data{iter}: [n_files x n_mask_voxels] matrix per SNR
data_orig=nan([size(data_noise) length(snr)]); % 5D: x, y, z, volume, SNR
for iter = 1:length(snr)
    
    signal = snr(iter)*randn(sum(tdt(:)),1); % create signal for tdt region
    
    % Mask noise by mask and add signal in all volumes with label 1 at position of tdt
    data_orig(:,:,:,:,iter) = data_noise;
    for i_vol = 1:n_files
        cdat = data_orig(:,:,:,i_vol,iter);
        if label(i_vol) == 1 % add signal only to one label
            cdat(tdt) = cdat(tdt)+signal;
        end
        cdat(~mask) = NaN; % set all voxels outside of the mask to NaN
        data_orig(:,:,:,i_vol,iter) = cdat;
    end
    
    % Convert data to 2D matrix and mask
    data{iter} = reshape(data_orig(:,:,:,:,iter),[prod(sz) n_files])';
    data{iter} = data{iter}(:,mask_index);
end

%% Plot univariate original data
whichsnr=8; % index into snr of the dataset used for the plots and the searchlight/wholebrain demos
resfig=1; figure(resfig); clf(resfig);
pos=fitplots2([2 4],'',10); % pos{row,col} is used below as the axes position of each panel
set(resfig,'WindowStyle','normal', 'WindowState','maximized','name', 'Demo 5 - toy data, searchlight, wholebrain & ROI'); % Maximize figure.

subplot('position',pos{1,1});
diff_vol = mean(data_orig(:,:,:,label==1,whichsnr),4) - mean(data_orig(:,:,:,label~=1,whichsnr),4);
imagesc(transform_vol(diff_vol)); % transform_vol tiles the volume as slices
colormap(parula(64))
axis off square
title(sprintf('Original data (difference of means; SNR=%.1f)',snr(whichsnr)))
colorbar('location', 'southoutside')
caxis([-max(abs(caxis)) max(abs(caxis))]); % make the colour range symmetric around zero

%% plot ttest on original data
% Volumes are permuted to the first dimension so that ttest2 tests each voxel across volumes
[h, p, ci, stats]=ttest2(permute(data_orig(:,:,:,label==1,whichsnr),[4 1 2 3]),permute(data_orig(:,:,:,label~=1,whichsnr),[4 1 2 3]));
h_fdr=fdr_bh(squeeze(p),0.05);
ax=subplot('position',pos{2,1});
% Sum of the two significance maps: 1 = uncorrected only, 2 = also FDR
imagesc(transform_vol(squeeze(h)+h_fdr),[0 3]);
colormap(ax,'hot')
colorbar('location', 'southoutside')
title('p<0.05. Red - Uncorrected; yellow - FDR')
%imagesc((transform_vol(squeeze(stats.tstat)))); title('univariate t stat')
axis off square

keyboard % pause to view data

%% set data scaling
% Enable scaling min0max1 (otherwise libsvm can get very slow)
% if you dont need model parameters, and if you use libsvm, use:
cfg.scale.method = 'min0max1';
cfg.scale.estimation = 'all'; % scaling across all data is equivalent to no scaling (i.e. will yield the same results), it only changes the data range which allows libsvm to compute faster

%% Fill passed_data
passed_data.data = data{whichsnr}; % use same SNR as plotted above
passed_data.dim = sz;
passed_data.mask_index = mask_index;
[passed_data,cfg] = fill_passed_data(passed_data,cfg,label,chunk_perrun);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% 'searchlight' examples (demo 5)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% These three settings are made outside the 'if' and so also apply to the
% later sections unless those sections overwrite them.
cfg.plot_selected_voxels = 100; % Show searchlight every nth step
cfg.decoding.method = 'classification_kernel'; % fractionally quicker than 'classification'
cfg.results.output = {'accuracy_minus_chance','ninputdim'}; % To see searchlight size, add 'ninputdim';

if dosearchlight
    
    %% Set options for searchlight decoding
    cfg.analysis = 'searchlight';
    cfg.searchlight.radius = 2; % set searchlight size in voxels
    
    %% Make and display design
    cfg.design = make_design_cv(cfg);
    plot_design(cfg);
    
    %% Run searchlight decoding
    results = decoding(cfg,passed_data);
    
    % Convert to results volume(s)
    resvol = nan(sz);
    resvol(mask_index) = results.accuracy_minus_chance.output;
    
    slsizevol = nan(sz);
    slsizevol(mask_index) = results.ninputdim.output;
    
    %% Plot decoding results
    figure(resfig);
    
    subplot('position',pos{1,2});
    imagesc(transform_vol(slsizevol),[0 max(slsizevol(:))]); % Transforms 3D volume to set of axial slices
    axis off square
    title('Searchlight sizes')
    colorbar('location', 'southoutside')
    
    ax=subplot('position',pos{2,2});
    imagesc(transform_vol(resvol),[-50 50]); % Transforms 3D volume to set of axial slices
    colormap(ax,'parula')
    axis off square
    title('Searchlight results (accuracy-chance)')
    colorbar('location', 'southoutside')
    
    keyboard % pause to view searchlight results
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% 'wholebrain' (demo 6) example 1 - recursive feature elimination
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if dowholebrain
    
    %% Set general parameters for wholebrain decoding
    cfg.analysis = 'wholebrain'; % alternatives: 'searchlight', 'wholebrain' ('ROI' does not make sense here);
    cfg.plot_selected_voxels = 0; % all x steps, set 0 for not plotting, 1 for each step, 2 for each 2nd, etc
    
    %% Set parameters for first example - Recursive Feature Elimination
    cfg.results.output = {'accuracy_minus_chance'};
    cfg.design = make_design_cv(cfg);
    
    cfg.feature_selection.estimation='across'; % estimation of optimal features is done on training data, only
    cfg.feature_selection.method = 'embedded';
    cfg.feature_selection.embedded = 'RFE'; % Recursive Feature Elimination; only embedded method currently offered
    % Recursively trains classifier and eliminates feature subset n with the lowest weight until criterion is reached (see Guyon et al., 2002).
    cfg.feature_selection.direction = 'backward'; % must be backward for Recursive Feature Elimination
    cfg.feature_selection.n_vox = [5 10 20 40 80 150 300]; % possible number of voxels to finally select
    cfg.feature_selection.nested_n_vox = [5 10 20 40 80 150 300 600 1000]; % possible number of voxels to eliminate per iteration
    %cfg.feature_selection.decoding.method = 'classification';
    
    %% Run wholebrain RFE analysis
    results_RFE = decoding(cfg,passed_data);
    
    % Convert to results volume (and set of slices)
    % Here we plot, for each voxel, on how many folds (out of 6) it is
    % selected to be in the final classifier. 
    resvol_RFE = zeros(sz);
    
    fs_index = results_RFE.feature_selection.fs_index;
    for i = 1:length(fs_index)
        % fs_index{i} indexes into the masked voxels, so map it back to volume indices
        resvol_RFE(mask_index(fs_index{i})) = resvol_RFE(mask_index(fs_index{i}))+1;
    end
    % (I am not sure why results_RFE.accuracy_minus_chance.output is at
    % chance level?)
    
    %% Plot wholebrain RFE analysis
    figure(resfig);
    subplot('position',pos{1,3});
    imagesc(transform_vol(resvol_RFE),[0 n_runs]);
    axis off square
    title('Selected voxels in "Wholebrain" RFE (# of CVs)')
    colorbar('location', 'southoutside')
    
    keyboard % pause to view RFE results
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %% 'wholebrain' example 2 - create weight map
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    %% Set parameters for second example - getting SVM weights
    cfg.feature_selection.method = 'none'; % turn feature selection off
    cfg.decoding.method = 'classification';  % can't get weights when using 'classification_kernel'
    cfg.results.output = {'SVM_weights'}; % output weights instead of accuracy
    
    % Disable scaling min0max1 to allow estimating model_parameters
    % (note: scaling stays disabled for the ROI demo below when this section runs)
    cfg.scale.method = 'none';
    cfg.scale.estimation = 'none';
    cfg.scale.IKnowThatLibsvmCanBeSlowWithoutScaling = 1;
    
    %% Make (and display) new design: use all data (don't split into folds)
    cfg.files.chunk = chunk_single; % Update chunk definition
    cfg.design = make_design_alldata(cfg);
    plot_design(cfg);
    
    %% Run second wholebrain analysis
    results_SVMweights = decoding(cfg,passed_data);
    
    % Convert to results volume
    resvol_SVMweights = nan(sz);
    resvol_SVMweights(mask_index) = results_SVMweights.SVM_weights.output{1}{1};
    
    %% Plot wholebrain SVM weight analysis
    figure(resfig);
    subplot('position',pos{2,3});
    imagesc(transform_vol(resvol_SVMweights));
    axis off square
    caxis([-max(abs(caxis)) max(abs(caxis))]); % make the colour range symmetric around zero
    title('"Wholebrain" SVM weights')
    colorbar('location', 'southoutside')
    
    keyboard % pause to view weight map results
    
    %{
 To interpret classifier weights, see:
 Haufe, S. et al. 2014. On the interpretation of weight vectors of linear models in multivariate
 neuroimaging. Neuroimage 87, 96-110.
    %}
    
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% ROI example (demo 7)
% Two ROIs are created: One overlapping more with TDT, the other less. The
% SNR is varied and the analysis is repeated and finally plotted.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if doroi
    
    %% Create two ROI masks
    % x indexes dimension 1 and y dimension 2 of the volume (see the ndgrid
    % and sub2ind calls below), so each range is scaled by its own dimension
    % size. For the default sz (sz(1)==sz(2)) this gives the same ROIs as
    % scaling by the other dimension.
    xx = round([1/3 2/3] *sz(1));
    zz = [ceil(2/5 * sz(3)) ceil(3/5 * sz(3))];
    
    y1 = round([1 sz(2)/3]); % 1st ROI covers 1/3 of signal
    [x,y,z] = ndgrid(xx(1):xx(2),y1(1):y1(2),zz(1):zz(2));
    roi1_index = sub2ind(sz,x(:),y(:),z(:));
    
    y2 = round([sz(2)/3+2 sz(2)]); % 2nd ROI covers 2/3 of signal
    [x,y,z] = ndgrid(xx(1):xx(2),y2(1):y2(2),zz(1):zz(2));
    roi2_index = sub2ind(sz,x(:),y(:),z(:));
    
    % put them in passed_data (excluding out-of-brain voxels)
    passed_data.mask_index_each{1} = intersect(roi1_index,mask_index);
    passed_data.mask_index_each{2} = intersect(roi2_index,mask_index);
    
    % add a third ROI that covers everything
    passed_data.mask_index_each{3} = mask_index;
    
    %% plot ROIs
    % Colour codes: 0 = outside mask, 1 = 2nd ROI, 2 = rest of mask, 3 = 1st ROI
    vol = zeros(sz);
    vol(mask_index) = 2;
    vol(passed_data.mask_index_each{1}) = 3;
    vol(passed_data.mask_index_each{2}) = 1;
    
    figure(resfig);
    subplot('position',pos{1,4});
    imagesc(transform_vol(vol),[0 4]);
    axis off square
    title('ROIs')
    
    keyboard % pause to view ROIs
    
    %% Set parameters for ROI decoding
    cfg.analysis = 'roi'; 
    cfg.plot_selected_voxels = 1; % all x steps, set 0 for not plotting, 1 for each step, 2 for each 2nd, etc
    cfg.results.output = {'accuracy'};
    cfg.decoding.method = 'classification_kernel';
    
    %% change back to leave-one-run out cross-validation design
    cfg.files.chunk = chunk_perrun; % Update chunk definition
    cfg.design = make_design_cv(cfg);
    plot_design(cfg);
    
    %% Run ROI decoding and store results for each SNR
    allresults=nan(length(passed_data.mask_index_each),length(snr)); % rows: ROIs, columns: SNRs
    for iter = 1:length(snr)
        passed_data.data = data{iter};
        results = decoding(cfg,passed_data);
        allresults(:,iter) = results.(cfg.results.output{1}).output;
    end
    
    %% plot decoding accuracy
    figure(resfig);
    ax=subplot('position',pos{2,4}); cla(ax)
    
    hold on
    colors = [201 143  54; 29 175 226; 100 200 100] ./255;
    ph=nan(1,size(allresults,1)); % line handles, one per ROI, for the legend
    for r=1:size(allresults,1)
        ph(r) = plot(allresults(r,:),'-o','linewidth',3,'color',colors(r,:));
    end
    
    set(ax,'xtick',1:1:length(snr),'xticklabel',snr(1:1:end))
    axis square on
    axis([0 length(snr)+1 0 105])
    plot(xlim,[50 50],'color',[.7 .7 .7]) % grey reference line at 50% (chance for two classes)
    
    legend(ph(1:3),{sprintf('Small ROI (%d voxels)',length(results.mask_index_each{1})),...
        sprintf('Large ROI (%d voxels)',length(results.mask_index_each{2})),...
        sprintf('Large noisy ROI (%d voxels)',length(results.mask_index_each{3}))},...
        'location','southeast')
    title(' ROI decoding accuracy')
    xlabel('SNR');
    
    keyboard % pause to view ROI results
    
    %%%%% Q) What do we conclude from this?
    
    %%%%% Q) What happens if the large ROI contains more non-signal voxels?
    
    %%%%% Q) How might results depend on choice of classifier?
    
    %%%%% Q) How might results depend on the number of samples?
    
    %%%%% Q) How could you combine RFE with ROIs?
    
end

return
