function demos_realdata_motion_direction_89_djm_plus()
% DEMOS_REALDATA_MOTION_DIRECTION_89_DJM_PLUS  Compare decoding variants on the TDT demo data.
%
% Runs ROI decodings of up- vs down-motion on the example data that can be
% downloaded from the TDT webpage. The same decoding is repeated with
% different cross-validation schemes, classifiers, scaling and feature
% transformations. Each variant is also rerun with label permutations.
% Figure 100 shows the observed accuracy (minus chance) per ROI next to the
% 5th-95th percentile range of the permutation distribution.
%
% Variations on the first decoding of 'demos_realdata_motion_direction_89_djm'.
%
% Requires SPM and TDT (The Decoding Toolbox) on the path. It also calls
% fitplots2, which is not defined in this file and is assumed to be on the path.
%
% Based on demo 8
% Modified by djm. 1/3/17, 5/9/22

dbstop if error % in case something goes wrong
set(0, 'DefaultAxesFontSize', 7);
rng(1); % fixed seed so that reruns are reproducible

%% Check that SPM and TDT are available on the path
fpath = fileparts(which(mfilename));
addpath(fpath) % add this directory to path
addpath(fileparts(fileparts(fpath))) % for TDT

% Use the exact (lower-case) name of spm.m so that the check also works
% where function lookup is case-sensitive.
if isempty(which('spm'))
    try
        % No leading '\' in the folder name: fullfile inserts the correct
        % separator on every platform.
        addpath(fullfile(fpath,'..','..','..','..','spm12_fil_r7771'));
        spm('fmri')
    catch
        error('Please add SPM to the path and restart'),
    end
end

if isempty(which('decoding_defaults')), error('Please add TDT to the path and restart'), end
decoding_defaults; % add all important directories to the path

%% Locate data directory
databasedir = fullfile(fpath, '..', 'TDTdemo8data');
check_subdirs = {'sub01_firstlevel_reducedResolution/sub01_GLM_3x3x3mm'; 'sub01_firstlevel/sub01_GLM'};
for c_ind = 1:length(check_subdirs)
    d = fullfile(databasedir, check_subdirs{c_ind});
    if exist(d, 'dir')
        beta_loc = d; break
    end
end
if ~exist('beta_loc','var')
    beta_loc = uigetdir('', 'Select the sub01_GLM* directory from the demo data (inside sub01_firstlevel*)');
    if isequal(beta_loc, 0) % uigetdir returns 0 when the dialog is cancelled
        error('No data directory selected');
    end
end
dispv(1, 'Located demodata in %s, starting analysis', beta_loc);

%% General settings

cfg = decoding_defaults; % add all important directories to the path, and set defaults
cfg.analysis = 'roi'; % 'searchlight', 'wholebrain', 'ROI' (if ROI, set one or multiple ROI images as mask files below instead of the mask)
cfg.decoding.method='classification';
cfg.decoding.software='libsvm';
cfg.decoding.train.classification.model_parameters = '-s 0 -t 0 -c 1 -b 0 -q';
cfg.design.function.name='make_design_cv';
cfg.cv='cvAll2fold'; % not a TDT field! Will be used to set up cross-validation scheme below
cfg.scale.method='min0max1'; cfg.scale.estimation='all';
cfg.feature_transformation.n_vox='';
cfg.results.overwrite = 1;
cfg.results.setwise=1;
cfg.results.write=0; % the meaning of this changes across versions!! Write nothing, .mat files only, or .mat files and .nii files
cfg.fighandles.plot_design = 10;
cfg.plot_selected_voxels = 0;
cfg.verbose = 1;

%% conditions
directions = {'up','down'};
regressor_names = design_from_spm(beta_loc);
cfg = decoding_describe_data(cfg,directions,[-1 1],regressor_names,beta_loc);
runs = cfg.files.chunk; % original run number of each file; all cv chunkings are derived from it

cfg.results.output = {'accuracy_minus_chance'};

%% permutation settings
nperms = 128;         % permutations per cv scheme (128 unique permutations for 8-fold cross-validation)
nperms_per_split = 7; % permutations per split half for 'cvAll2fold' (there are many split halves)

%% Set ROIs
if exist(fullfile(beta_loc, '..', 'sub01_ROI_3x3x3mm'), 'dir')
    cfg.files.mask = fullfile(beta_loc, '..', 'sub01_ROI_3x3x3mm', {'v1.img', 'v4_both.img','mt_both.img','m1_left.img'}); % reduce data
elseif exist(fullfile(beta_loc, '..', 'sub01_ROI'), 'dir')
    cfg.files.mask = fullfile(beta_loc, '..', 'sub01_ROI', {'v1.img', 'v4_both.img', 'mt_both.img','m1_left.img'}); % reduce data
else
    % uigetfile returns the file name(s) and the folder separately, so join
    % them to get full paths. Several ROIs may be selected at once.
    [roi_files, roi_path] = uigetfile({'*.img;*.nii', 'ROI images (*.img, *.nii)'}, ...
        'Could not automatically find ROI folder, please select which ROIs to use', 'MultiSelect', 'on');
    if isequal(roi_files, 0) % dialog cancelled
        error('No ROI files selected');
    end
    cfg.files.mask = fullfile(roi_path, roi_files);
end

%% analyses
cfg.extralabel='';
analysis(1:12)=cfg;

% test different cv splits:
analysis(1).cv='cv2fold'; % as previous demo
analysis(2).cv='cv8fold';
analysis(3).cv='cvAll2fold';

% test different classifiers:
analysis(4).decoding.software='liblinear'; % settings below will implement L2-regularised SVM (L1, logistic regression, regression also available)
analysis(5).decoding.software='newton'; % a different SVM implementation
analysis(6).decoding.software='lda'; analysis(6).feature_transformation=struct('method','PCA','estimation','all','n_vox',20,'scale',cfg.scale); % note PCA added for speed
analysis(7).decoding.software='correlation_classifier';
analysis(8).decoding.train.classification.model_parameters = '-s 0 -t 2 -c 1 -b 0 -q'; analysis(8).extralabel='RB';

% test different data transformations etc.:
analysis(9).scale=struct('method','none','estimation','none','IKnowThatLibsvmCanBeSlowWithoutScaling',true);
analysis(10).feature_transformation=struct('method','PCA','estimation','all','n_vox','all','scale',cfg.scale);
analysis(11).feature_transformation=struct('method','PCA','estimation','all','n_vox',20,'scale',cfg.scale);
analysis(12).feature_transformation=struct('method','PCA','estimation','all','n_vox',5,'scale',cfg.scale);

%% run and plot each analysis
figure(100); clf(100);
pos = fitplots2([length(analysis)+1 2]);
roi_colors = [0 1 .5; 0 .75 .75; 0 .5 1; 0 0 0]; % bar colours, cycled if there are more ROIs
% ROI labels: the part of each mask filename that differs between masks
% (field .m of the spm_str_manip 'C' output). This is the same for every analysis.
[~, roi_names] = spm_str_manip(cfg.files.mask,'C');

for a = 1:length(analysis)

    % Settings for this analysis:
    name = sprintf('%s_%s%s_scale%s_trans%s%s',analysis(a).cv,...
        analysis(a).decoding.software, analysis(a).extralabel, analysis(a).scale.method,...
        analysis(a).feature_transformation.method, num2str(analysis(a).feature_transformation.n_vox));
    analysis(a).results.dir = fullfile(beta_loc, 'results', 'motion_up_vs_down', name);

    analysis(a) = add_cv_design(analysis(a), runs);

    switch analysis(a).decoding.software
        case 'libsvm'
            % parameters specified above
        case 'liblinear'
            analysis(a).decoding.train.classification.model_parameters = '-s 7 -c 1 -q';
            %{
            0 -- L2-regularized logistic regression (primal)
            1 -- L2-regularized L2-loss support vector classification (dual)
            2 -- L2-regularized L2-loss support vector classification (primal)
            3 -- L2-regularized L1-loss support vector classification (dual)
            4 -- support vector classification by Crammer and Singer
            5 -- L1-regularized L2-loss support vector classification
            6 -- L1-regularized logistic regression
            7 -- L2-regularized logistic regression (dual)
            %}
        case 'newton'
            analysis(a).decoding.train.classification.model_parameters = -1;
        case 'lda'
            analysis(a).decoding.train.classification.model_parameters=struct('shrinkage','oas'); % pinv v slow; oas faster than lw and lw2
    end

    plot_design(analysis(a));
    %keyboard

    tic
    % run main cv decoding:
    results = decoding(analysis(a));

    % rerun with label permutations:
    permcfg = make_permutation_cfg(analysis(a), runs, nperms, nperms_per_split);
    [results_perm, ~] = decoding(permcfg);
    elapsed = toc; % decoding time only (excludes plotting)

    % plot
    figure(100)
    subplot('position',pos{a,2}); cla
    null_acc = [results_perm.accuracy_minus_chance.set.output]; % ROIs x permutation sets
    n_rois = results.n_decodings;
    for roi = 1:n_rois
        roi_color = roi_colors(mod(roi-1, size(roi_colors,1)) + 1, :);
        barh(roi,results.accuracy_minus_chance.output(roi),'facecolor',roi_color); hold on
        plot(prctile(null_acc(roi,:),[5 95],2),[roi roi],'r-') % 5th-95th percentile of the permutation results
    end
    axis tight on; box off; xlim([-50 50])
    set(gca,'ytick',1:n_rois,'yticklabel',regexprep(roi_names.m,'_',' '));
    if a<length(analysis)
        set(gca,'xtick',[]);
    end

    text(max(xlim),mean(ylim),sprintf('Took %.0f s',elapsed));

    axes('position',pos{a,1});
    axis off
    t=text(mean(xlim),mean(ylim),name,'interpreter','none');
    t.FontSize=8;
    if a==3, t.Color='r'; end % the comparison analysis (identical to the default cfg)

    drawnow
end % next analysis

%%% Other things that might be instructive:
%{
1) Set up an unbalanced and/or confounded data set, then test ways to
   remedy this at the level of:
   - design (e.g. make_design_boot?)
   - classifier (e.g. 'ensemble_balance'?)
   - results (e.g. dprime? AUC_minus_chance? balanced_accuracy_minus_chance?)

2) Representational similarity analysis (RSA).
%}

keyboard

return

function acfg = add_cv_design(acfg, runs)
% ADD_CV_DESIGN  Set files.chunk and design of ACFG for the scheme named in ACFG.cv.
%   acfg - analysis cfg; acfg.cv is one of 'cv2fold', 'cv4fold',
%          'cv4foldgrouped', 'cv8fold', 'cvNfold' or 'cvAll2fold'
%   runs - original run number of each file
% For schemes that combine several designs, every cfg passed to
% combine_designs keeps the original chunks (runs). The scheme-specific
% chunks are used only to build each design.
switch acfg.cv
    case 'cv2fold'
        acfg.files.chunk = 2 - mod(runs, 2); % odd runs -> chunk 1, even runs -> chunk 2
        acfg.design = make_design_cv(acfg);
    case 'cv4fold'
        acfg.files.chunk = mod(runs-1,4)+1;
        acfg.design = make_design_cv(acfg);
    case 'cv4foldgrouped'
        acfg.files.chunk = ceil(runs/2);
        acfg.design = make_design_cv(acfg);
    case 'cv8fold' % these data have 8 runs, so this is leave-one-run-out
        acfg.files.chunk = runs;
        acfg.design = make_design_cv(acfg);
    case 'cvNfold' % the combination of all 2-fold, 4-fold and 8-fold classifications
        chunkings = nfold_chunkings(runs);
        parts = cell(size(chunkings));
        for k = 1:numel(chunkings)
            parts{k} = cv_design_for_chunks(acfg, chunkings{k});
        end
        acfg = combine_designs(parts{:});
    case 'cvAll2fold' % all possible split halves
        splits = half_splits(runs);
        parts = cell(1, size(splits,1));
        for s = 1:size(splits,1)
            parts{s} = cv_design_for_chunks(acfg, ismember(runs, splits(s,:)));
        end
        acfg = combine_designs(parts);
    otherwise
        error('Unknown cross-validation scheme ''%s''', acfg.cv);
end

function out = cv_design_for_chunks(acfg, chunk)
% CV_DESIGN_FOR_CHUNKS  Copy of ACFG whose design is make_design_cv built with files.chunk = CHUNK.
% The files.chunk of the returned cfg is left unchanged.
chunked = acfg;
chunked.files.chunk = chunk;
out = acfg;
out.design = make_design_cv(chunked);

function chunkings = nfold_chunkings(runs)
% NFOLD_CHUNKINGS  Chunk vectors for 2-fold (odd/even runs), 4-fold and leave-one-run-out cv.
chunkings = {2 - mod(runs, 2), mod(runs-1,4)+1, runs};

function splits = half_splits(runs)
% HALF_SPLITS  All ways of choosing half of the runs; one split per row.
% NOTE(review): each split and its complement (e.g. [1 2 3 4] and
% [5 6 7 8]) define the same two folds, so every 2-fold partition is
% evaluated twice. The mean is unaffected, but the computation is doubled.
run_ids = unique(runs(:))';
splits = nchoosek(run_ids, floor(numel(run_ids)/2));

function permcfg = make_permutation_cfg(acfg, runs, nperms, nperms_per_split)
% MAKE_PERMUTATION_CFG  Build a cfg that reruns analysis ACFG with label-permutation designs.
%   acfg             - analysis cfg that already has its cv design (see add_cv_design)
%   runs             - original run number of each file
%   nperms           - permutations per cv scheme
%   nperms_per_split - permutations per split half when acfg.cv is 'cvAll2fold'
% Result files are prefixed with 'perm'.
permcfg = acfg;
if isfield(permcfg,'design'), permcfg = rmfield(permcfg,'design'); end
if isfield(permcfg.results,'resultsname'), permcfg.results = rmfield(permcfg.results, 'resultsname'); end
permcfg.results.filestart = 'perm';
permcfg.design.function.name = acfg.design.function.name;
switch acfg.cv
    case 'cvNfold'
        chunkings = nfold_chunkings(runs);
        parts = cell(size(chunkings));
        setnums = cell(size(chunkings));
        for k = 1:numel(chunkings)
            parts{k} = permutation_design_for_chunks(permcfg, chunkings{k}, nperms);
            % one set per permutation, covering all folds (one per chunk) of this scheme
            setnums{k} = kron(1:nperms, ones(1, numel(unique(chunkings{k}))));
        end
        permcfg = combine_designs(parts{:});
        % Permutation k of every scheme is put in set k. This assumes
        % combine_designs concatenates the steps in argument order.
        permcfg.design.set = [setnums{:}];
    case 'cvAll2fold'
        % nperms_per_split permutations for each split half. With 8 runs
        % that is 70 splits x 7 permutations x 2 folds = 980 train/test steps.
        splits = half_splits(runs);
        parts = cell(1, size(splits,1));
        for s = 1:size(splits,1)
            parts{s} = permutation_design_for_chunks(permcfg, ismember(runs, splits(s,:)), nperms_per_split);
        end
        permcfg = combine_designs(parts);
    otherwise
        permcfg.design = sort_design(make_design_permutation(permcfg, nperms, true));
end

function pcfg = permutation_design_for_chunks(permcfg, chunk, nperms)
% PERMUTATION_DESIGN_FOR_CHUNKS  Copy of PERMCFG with files.chunk = CHUNK and a sorted permutation design.
% A cv design is built first, because make_design_permutation is given the cfg holding it.
pcfg = permcfg;
pcfg.files.chunk = chunk;
pcfg.design = make_design_cv(pcfg);
pcfg.design = sort_design(make_design_permutation(pcfg, nperms, true));
